micro_sam.automatic_segmentation
"""Automatic instance segmentation and tracking with Segment Anything, plus the command line entry point."""

import os
import warnings
from glob import glob
from tqdm import tqdm
from pathlib import Path
from functools import partial
from typing import Dict, List, Optional, Union, Tuple, Literal

import numpy as np
import imageio.v3 as imageio

from torch_em.data.datasets.util import split_kwargs

from . import util
from .instance_segmentation import (
    get_instance_segmentation_generator, get_decoder, AMGBase,
    AutomaticMaskGenerator, TiledAutomaticMaskGenerator,
    AutomaticPromptGenerator, TiledAutomaticPromptGenerator,
    InstanceSegmentationWithDecoder, TiledInstanceSegmentationWithDecoder,
    DEFAULT_SEGMENTATION_MODE_WITH_DECODER,
)
from .multi_dimensional_segmentation import automatic_3d_segmentation, automatic_tracking_implementation


def get_predictor_and_segmenter(
    model_type: str,
    checkpoint: Optional[Union[os.PathLike, str]] = None,
    device: str = None,
    segmentation_mode: Optional[Literal["amg", "ais", "apg"]] = None,
    is_tiled: bool = False,
    predictor=None,
    state=None,
    **kwargs,
) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]:
    """Get the Segment Anything model and class for automatic instance segmentation.

    Args:
        model_type: The Segment Anything model choice.
        checkpoint: The filepath to the stored model checkpoints.
        device: The torch device. By default, automatically chooses the best available device.
        segmentation_mode: The segmentation mode. One of 'amg', 'ais', or 'apg' (case-insensitive).
            By default (or if 'auto' is passed), the mode `DEFAULT_SEGMENTATION_MODE_WITH_DECODER`
            is used if the model contains a decoder, otherwise 'amg' is used.
        is_tiled: Whether to return segmenter for performing segmentation in tiling window style.
            By default, set to 'False'.
        predictor: The pre-loaded predictor (optional). Has to be passed together with 'state'.
        state: The pre-loaded state (optional).
        kwargs: Keyword arguments for the automatic mask generation class.

    Returns:
        The Segment Anything model.
        The automatic instance segmentation class.

    Raises:
        ValueError: If a pre-loaded 'predictor' is passed without its 'state'.
        RuntimeError: If a decoder-based mode is requested but the model does not contain a decoder.
    """
    # Get the predictor and state for Segment Anything Model.
    if predictor is None:
        device = util.get_device(device=device)
        predictor, state = util.get_sam_model(
            model_type=model_type, device=device, checkpoint_path=checkpoint, return_state=True
        )
    elif state is None:
        # The state is needed to check whether the model contains a decoder.
        raise ValueError("The model 'state' has to be passed together with a pre-loaded 'predictor'.")

    if segmentation_mode is None or segmentation_mode.lower() == "auto":
        segmentation_mode = DEFAULT_SEGMENTATION_MODE_WITH_DECODER if "decoder_state" in state else "amg"
    segmentation_mode = segmentation_mode.lower()

    if segmentation_mode == "amg":
        decoder = None
    else:
        if "decoder_state" not in state:
            raise RuntimeError(
                f"You have passed 'segmentation_mode={segmentation_mode}', but your model does not contain a decoder."
            )
        decoder_state = state["decoder_state"]
        decoder = get_decoder(image_encoder=predictor.model.image_encoder, decoder_state=decoder_state, device=device)

    segmenter = get_instance_segmentation_generator(
        predictor=predictor, is_tiled=is_tiled, decoder=decoder, segmentation_mode=segmentation_mode, **kwargs
    )
    return predictor, segmenter


def _add_suffix_to_output_path(output_path: Union[str, os.PathLike], suffix: str) -> str:
    """Insert `suffix` before the file extension of `output_path` (using '.tif' if it has no extension)."""
    fpath = Path(output_path).resolve()
    fext = fpath.suffix if fpath.suffix else ".tif"
    return str(fpath.with_name(f"{fpath.stem}{suffix}{fext}"))


def automatic_tracking(
    predictor: util.SamPredictor,
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    input_path: Union[Union[os.PathLike, str], np.ndarray],
    output_path: Optional[Union[os.PathLike, str]] = None,
    embedding_path: Optional[Union[os.PathLike, str]] = None,
    key: Optional[str] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    annotate: bool = False,
    batch_size: int = 1,
    **generate_kwargs
) -> Tuple[np.ndarray, List[Dict]]:
    """Run automatic tracking for the input timeseries.

    Args:
        predictor: The Segment Anything model.
        segmenter: The automatic instance segmentation class.
        input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
            or a container file (e.g. hdf5 or zarr). An already loaded array is also accepted.
        output_path: The folder where the tracking outputs will be saved in CTC format.
        embedding_path: The path where the embeddings are cached already / will be saved.
        key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
            or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings.
            By default, does not return the embeddings.
        annotate: Whether to activate the annotator to continue the annotation process.
            This is currently not supported for tracking; passing 'True' raises a NotImplementedError.
        batch_size: The batch size to compute image embeddings over tiles / z-planes.
            By default, does it sequentially, i.e. one after the other.
        generate_kwargs: Optional keyword arguments for the generate function of the AMG, APG, or AIS class.
            The tracking parameters 'gap_closing' and 'min_time_extent' may also be given here; they are passed
            to `automatic_tracking_implementation` explicitly and removed from the generate kwargs.

    Returns:
        The tracking result as a timeseries, where each object is labeled by its track id.
        The lineages representing cell divisions, stored as a dictionary.
        The image embeddings, only if 'return_embeddings' is True.

    Raises:
        NotImplementedError: If 'annotate' is True.
        ValueError: If the input is neither a 3d array nor a 4d array with 3 channels in the last axis.
    """
    # Fail before the (expensive) tracking run instead of after it.
    if annotate:
        # TODO We need to support initialization of the tracking annotator with the tracking result for this.
        raise NotImplementedError("Annotation after running the automated tracking is currently not supported.")

    # Load the input image file.
    # We assume that it has to be read from file if it is a str or pathlike.
    # Otherwise we assume it is a numpy array like object.
    image_data = util.load_image_data(input_path, key) if isinstance(input_path, (str, os.PathLike)) else input_path

    # Accept a grayscale timeseries (T, Y, X) or an RGB timeseries (T, Y, X, 3).
    if not (image_data.ndim == 3 or (image_data.ndim == 4 and image_data.shape[-1] == 3)):
        raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}")

    # Pop (not get) the tracking parameters: leaving them in 'generate_kwargs' would pass them twice below.
    gap_closing = generate_kwargs.pop("gap_closing", None)
    min_time_extent = generate_kwargs.pop("min_time_extent", None)
    segmentation, lineage, image_embeddings = automatic_tracking_implementation(
        image_data,
        predictor,
        segmenter,
        embedding_path=embedding_path,
        gap_closing=gap_closing,
        min_time_extent=min_time_extent,
        tile_shape=tile_shape,
        halo=halo,
        verbose=verbose,
        batch_size=batch_size,
        return_embeddings=True,
        output_folder=output_path,
        **generate_kwargs,
    )

    if return_embeddings:
        return segmentation, lineage, image_embeddings
    else:
        return segmentation, lineage


def automatic_instance_segmentation(
    predictor: util.SamPredictor,
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    input_path: Union[Union[os.PathLike, str], np.ndarray],
    output_path: Optional[Union[os.PathLike, str]] = None,
    embedding_path: Optional[Union[os.PathLike, str]] = None,
    mask_path: Optional[Union[Union[os.PathLike, str], np.ndarray]] = None,
    key: Optional[str] = None,
    mask_key: Optional[str] = None,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    annotate: bool = False,
    batch_size: int = 1,
    **generate_kwargs
) -> np.ndarray:
    """Run automatic segmentation for the input image.

    Args:
        predictor: The Segment Anything model.
        segmenter: The automatic instance segmentation class.
        input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
            or a container file (e.g. hdf5 or zarr). An already loaded array is also accepted.
        output_path: The output path where the instance segmentations will be saved.
            The extension is always replaced by '.tif'.
        embedding_path: The path where the embeddings are cached already / will be saved.
        mask_path: The path to an optional foreground mask. Areas outside of the foreground will not be processed.
            Only supported for 2d segmentation.
        key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
            or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
        mask_key: The key to the (optional) foreground mask.
        ndim: The dimensionality of the data. By default the dimensionality of the data will be used.
            If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings.
            By default, does not return the embeddings.
        annotate: Whether to activate the annotator to continue the annotation process.
            By default, does not activate the annotator.
        batch_size: The batch size to compute image embeddings over tiles / z-planes.
            By default, does it sequentially, i.e. one after the other.
        generate_kwargs: Optional keyword arguments for the generate function of the AMG, APG, or AIS class.

    Returns:
        The segmentation result, or None if a segmentation is already stored at 'output_path'.
        If 'return_embeddings' is True, a tuple of the segmentation and the image embeddings.

    Raises:
        ValueError: If the input shape does not match the (given or inferred) dimensionality.
        NotImplementedError: If a foreground mask is passed for 3d segmentation.
    """
    # Avoid overwriting already stored segmentations.
    if output_path is not None:
        output_path = Path(output_path).with_suffix(".tif")
        if os.path.exists(output_path):
            print(f"The segmentation results are already stored at '{os.path.abspath(output_path)}'.")
            return

    # We assume that it has to be read from file if it is a str or pathlike.
    # Otherwise we assume it is a numpy array like object.
    image_data = util.load_image_data(input_path, key) if isinstance(input_path, (str, os.PathLike)) else input_path

    ndim = image_data.ndim if ndim is None else ndim

    # Load the mask defining foreground if it was given.
    if mask_path is None:
        mask = None
    else:
        mask = util.load_image_data(mask_path, mask_key) if isinstance(mask_path, (str, os.PathLike)) else mask_path

    if ndim == 2:
        # Accept a grayscale image (Y, X) or an RGB image (Y, X, 3).
        if not (image_data.ndim == 2 or (image_data.ndim == 3 and image_data.shape[-1] == 3)):
            raise ValueError(f"The inputs does not match the shape expectation of 2d inputs: {image_data.shape}")

        # Precompute the image embeddings.
        image_embeddings = util.precompute_image_embeddings(
            predictor=predictor,
            input_=image_data,
            save_path=embedding_path,
            ndim=ndim,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
            batch_size=batch_size,
            mask=mask,
        )
        initialize_kwargs = dict(image=image_data, image_embeddings=image_embeddings, verbose=verbose)
        if mask is not None:
            initialize_kwargs["mask"] = mask

        # If we run AIS with tiling then we use the same tile shape for the watershed postprocessing.
        # In this case, we also add the batch size to the initialize kwargs,
        # so that the segmentation decoder can be applied in a batched fashion.
        if isinstance(segmenter, InstanceSegmentationWithDecoder) and tile_shape is not None:
            if not isinstance(segmenter, TiledAutomaticPromptGenerator):
                generate_kwargs.update({"tile_shape": tile_shape, "halo": halo})
            initialize_kwargs["batch_size"] = batch_size

        segmenter.initialize(**initialize_kwargs)
        instances = segmenter.generate(**generate_kwargs)

    else:
        # Accept a grayscale volume (Z, Y, X) or an RGB volume (Z, Y, X, 3).
        if not (image_data.ndim == 3 or (image_data.ndim == 4 and image_data.shape[-1] == 3)):
            raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}")
        if mask is not None:
            raise NotImplementedError("Using a foreground mask is currently only supported for 2d segmentation.")

        instances, image_embeddings = automatic_3d_segmentation(
            volume=image_data,
            predictor=predictor,
            segmentor=segmenter,
            embedding_path=embedding_path,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
            return_embeddings=True,
            batch_size=batch_size,
            **generate_kwargs
        )

    # Before starting to annotate, if at all desired, store the automatic segmentations in the first stage.
    if output_path is not None:
        _output_path = _add_suffix_to_output_path(output_path, "_automatic") if annotate else output_path
        imageio.imwrite(_output_path, instances, compression="zlib")
        if verbose:
            print(f"The automatic segmentation results are stored at '{os.path.abspath(_output_path)}'.")

    # Allow opening the automatic segmentation in the annotator for further annotation, if desired.
    if annotate:
        from micro_sam.sam_annotator import annotator_2d, annotator_3d
        annotator_function = annotator_2d if ndim == 2 else annotator_3d

        viewer = annotator_function(
            image=image_data,
            model_type=predictor.model_name,
            embedding_path=image_embeddings,  # Providing the precomputed image embeddings.
            segmentation_result=instances,  # Initializes the automatic segmentation to the annotator.
            tile_shape=tile_shape,
            halo=halo,
            return_viewer=True,  # Returns the viewer, which allows the user to store the updated segmentations.
        )

        # Start the GUI here
        import napari
        napari.run()

        # We extract the segmentation in "committed_objects" layer, where the user either:
        # a) Performed interactive segmentation / corrections and committed them, OR
        # b) Did not do anything and closed the annotator, i.e. keeps the segmentations as it is.
        instances = viewer.layers["committed_objects"].data

        # Save the instance segmentation, if 'output_path' provided.
        if output_path is not None:
            imageio.imwrite(output_path, instances, compression="zlib")
            if verbose:
                print(f"The final segmentation results are stored at '{os.path.abspath(output_path)}'.")

    if return_embeddings:
        return instances, image_embeddings
    else:
        return instances


def _get_inputs_from_paths(paths, pattern):
    """Collect the input filepaths from files and/or directories (searched with `pattern`), in sorted order.

    Raises:
        ValueError: If a directory is given but no pattern to search it with.
    """
    if isinstance(paths, (str, os.PathLike)):
        paths = [paths]

    fpaths = []
    for path in paths:
        if os.path.isfile(path):  # It is just one filepath.
            fpaths.append(path)
        else:  # Otherwise, if the path is a directory, fetch all inputs provided with a pattern.
            if pattern is None:
                raise ValueError(
                    f"You must provide a pattern to search for files in the directory: '{os.path.abspath(path)}'."
                )
            # 'glob' returns files in arbitrary order; sort them so that processing is deterministic.
            fpaths.extend(sorted(glob(os.path.join(path, pattern))))

    return fpaths


def _convert_argval(value: str) -> Union[int, float, bool, str]:
    """Convert a command line value to int, float or bool where possible, otherwise keep it as a string."""
    for convert in (int, float):
        try:
            return convert(value)
        except ValueError:
            pass
    # Map boolean strings explicitly: passing the string "False" on would be truthy.
    if value.lower() in ("true", "false"):
        return value.lower() == "true"
    return value


def _parse_extra_kwargs(parser, parameter_args: List[str]) -> Dict:
    """Turn the unknown '--name value' command line pairs into keyword arguments (exits via parser.error if malformed)."""
    names, values = parameter_args[::2], parameter_args[1::2]
    if len(names) != len(values) or not all(name.startswith("-") for name in names):
        parser.error(
            f"Could not parse the additional arguments {parameter_args}. "
            "They have to be passed as pairs of '--name value'."
        )
    return {name.lstrip("-"): _convert_argval(value) for name, value in zip(names, values)}


def _get_output_and_embedding_paths(
    input_path: str, output_path: str, embedding_path: Optional[str], has_one_input: bool, tracking: bool,
) -> Tuple[str, Optional[str]]:
    """Derive where to store the result and the embeddings for one input, creating the needed folders."""
    if has_one_input:  # When we have only one image / volume.
        # The embedding path is used as is, it works both as a folder and as a zarr file.
        output_stem, output_ext = os.path.splitext(output_path)
        if tracking:
            # For tracking, we ensure that the output path is a folder, i.e. does not have an extension.
            if output_ext:
                warnings.warn(
                    f"The output folder has an extension '{output_ext}'. "
                    "We remove it and treat it as a folder to store tracking outputs in CTC format."
                )
            os.makedirs(output_stem, exist_ok=True)
            return output_stem, embedding_path

        # Segmentations are stored in a single tif file, so only its parent folder has to exist.
        # (Creating a folder named like the file stem would leave an empty directory behind.)
        parent_folder = os.path.dirname(output_stem)
        if parent_folder:
            os.makedirs(parent_folder, exist_ok=True)
        return f"{output_stem}.tif", embedding_path

    # When we have multiple images, the output and embedding paths are treated as folders.
    input_name = str(Path(input_path).stem)

    if embedding_path is None:  # Embeddings are computed on-the-fly and not stored.
        embedding_fpath = None
    else:  # Otherwise, store each embedding inside a folder.
        embedding_folder = os.path.splitext(embedding_path)[0]
        os.makedirs(embedding_folder, exist_ok=True)
        embedding_fpath = os.path.join(embedding_folder, f"{input_name}.zarr")

    output_folder = os.path.splitext(output_path)[0]
    os.makedirs(output_folder, exist_ok=True)

    # For tracking, we store CTC outputs in subfolders named after the input; segmentations as tif files.
    output_fpath = os.path.join(output_folder, input_name if tracking else f"{input_name}.tif")
    return output_fpath, embedding_fpath


def main():
    """@private"""
    import argparse

    available_models = ", ".join(util.get_model_names())

    parser = argparse.ArgumentParser(
        description="Run automatic segmentation or tracking for 2d, 3d or timeseries data.\n"
        "Either a single input file or multiple input files are supported. You can specify multiple files "
        "by either providing multiple filepaths to the '-i/--input_path' argument, or by providing an argument "
        "to '--pattern' to use a wildcard pattern ('*') for selecting multiple files.\n"
        "NOTE: for automatic 3d segmentation or tracking the data has to be stored as volume / timeseries, "
        "stacking individual tif images is not supported.\n"
        "Segmentation is performed using one of the three modes supported by micro_sam: \n"
        "automatic instance segmentation (AIS), automatic prompt generation (APG) or automatic mask generation (AMG).\n"
        "In addition to the options listed below, "
        "you can also pass additional arguments for these three segmentation modes:\n"
        "For AIS: '--center_distance_threshold', '--boundary_distance_threshold' and other arguments of `InstanceSegmentationWithDecoder.generate`.\n"  # noqa
        "For APG: '--center_distance_threshold', '--boundary_distance_threshold' and other arguments of `AutomaticPromptGenerator.generate`.\n"  # noqa
        "For AMG: '--pred_iou_thresh', '--stability_score_thresh' and other arguments of `AutomaticMaskGenerator.generate`."  # noqa
    )
    parser.add_argument(
        "-i", "--input_path", required=True, type=str, nargs="+",
        help="The filepath(s) to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "  # noqa
        "or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter."
    )
    parser.add_argument(
        "-o", "--output_path", required=True, type=str,
        help="The filepath to store the results. If multiple inputs are provided, "
        "this should be a folder. For a single image, you should provide the path to a tif file for the output segmentation. "  # noqa
        "NOTE: Segmentation results are stored as tif files, tracking results in the CTC file format."
    )
    parser.add_argument(
        "-e", "--embedding_path", default=None, type=str,
        help="An optional path where the embeddings will be saved. If multiple inputs are provided, "
        "this should be a folder. Otherwise you can store embeddings in single zarr file."
    )
    parser.add_argument(
        "--pattern", type=str, help="Pattern / wildcard for selecting files in a folder. To select all files use '*'."
    )
    parser.add_argument(
        "-k", "--key", default=None, type=str,
        help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, "
        "for an image stack it is a wild-card, e.g. '*.png' and for mrc it is 'data'."
    )
    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL, type=str,
        help=f"The segment anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None, type=str, help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, help="The halo for using tiled prediction.", default=None
    )
    parser.add_argument(
        "-n", "--ndim", default=None, type=int,
        help="The number of spatial dimensions in the data. Please specify this if your data has a channel dimension."
    )
    parser.add_argument(
        "--mode", default="auto", type=str,
        help="The choice of automatic segmentation mode. Either 'auto', 'amg', 'apg', or 'ais'."
    )
    parser.add_argument(
        "--annotate", action="store_true",
        help="Whether to continue annotation after the automatic segmentation is generated."
    )
    parser.add_argument(
        "-d", "--device", default=None, type=str,
        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only MAC). "
        "By default the most performant available device will be selected."
    )
    parser.add_argument(
        "--batch_size", type=int, default=1,
        help="The batch size for computing image embeddings over tiles or z-plane. "
        "By default, computes the image embeddings for one tile / z-plane at a time."
    )
    parser.add_argument(
        "--tracking", action="store_true", help="Run automatic tracking instead of instance segmentation. "
        "NOTE: It is only supported for timeseries inputs."
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="Whether to allow verbosity of outputs."
    )

    args, parameter_args = parser.parse_known_args()

    # NOTE: the unknown arguments correspond to the automatic segmentation post-processing parameters
    # (e.g. 'center_distance_threshold' in AIS), given as '--name value' pairs.
    extra_kwargs = _parse_extra_kwargs(parser, parameter_args)

    # Resolve the device once, so that the predictor and the decoder end up on the same device.
    device = util.get_device(device=args.device)

    mode = args.mode.lower()
    if mode == "auto":
        # We have to load the state to see if we have a decoder in this case.
        predictor, state = util.get_sam_model(
            model_type=args.model_type, device=device, checkpoint_path=args.checkpoint, return_state=True
        )
        mode = DEFAULT_SEGMENTATION_MODE_WITH_DECODER if "decoder_state" in state else "amg"
    else:
        predictor, state = None, None
    mode = mode.lower()

    # Separate extra arguments as per where they should be passed in the automatic segmentation class,
    # e.g. for AMG, 'points_per_side' is expected by '__init__' and 'stability_score_thresh' by 'generate'.
    segmenter_classes = {
        "amg": (AutomaticMaskGenerator, TiledAutomaticMaskGenerator),
        "ais": (InstanceSegmentationWithDecoder, TiledInstanceSegmentationWithDecoder),
        "apg": (AutomaticPromptGenerator, TiledAutomaticPromptGenerator),
    }
    if mode not in segmenter_classes:
        raise ValueError(f"Invalid segmentation_mode: {mode}. Choose one of 'amg', 'ais', or 'apg'.")
    is_tiled = args.tile_shape is not None
    untiled_class, tiled_class = segmenter_classes[mode]
    segmenter_class = tiled_class if is_tiled else untiled_class
    init_kwargs, generate_kwargs = split_kwargs(segmenter_class, **extra_kwargs)

    predictor, segmenter = get_predictor_and_segmenter(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        device=device,
        segmentation_mode=mode,
        is_tiled=is_tiled,
        predictor=predictor,
        state=state,
        **init_kwargs,
    )

    # Get the filepaths to the input images.
    input_paths = _get_inputs_from_paths(args.input_path, args.pattern)
    if not input_paths:
        parser.error(f"'micro-sam' could not find any image data for the input path(s) {args.input_path}.")
    has_one_input = len(input_paths) == 1

    instance_seg_function = automatic_tracking if args.tracking else partial(
        automatic_instance_segmentation, ndim=args.ndim
    )

    # Run automatic segmentation per image.
    for input_path in tqdm(input_paths, desc="Run automatic " + ("tracking" if args.tracking else "segmentation")):
        _output_fpath, _embedding_fpath = _get_output_and_embedding_paths(
            input_path,
            output_path=args.output_path,
            embedding_path=args.embedding_path,
            has_one_input=has_one_input,
            tracking=args.tracking,
        )

        instance_seg_function(
            predictor=predictor,
            segmenter=segmenter,
            input_path=input_path,
            output_path=_output_fpath,
            embedding_path=_embedding_fpath,
            key=args.key,
            tile_shape=args.tile_shape,
            halo=args.halo,
            annotate=args.annotate,
            verbose=args.verbose,
            batch_size=args.batch_size,
            **generate_kwargs,
        )
def
get_predictor_and_segmenter( model_type: str, checkpoint: Union[str, os.PathLike, NoneType] = None, device: str = None, segmentation_mode: Optional[Literal['amg', 'ais', 'apg']] = None, is_tiled: bool = False, predictor=None, state=None, **kwargs) -> Tuple[mobile_sam.predictor.SamPredictor, Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder]]:
def get_predictor_and_segmenter(
    model_type: str,
    checkpoint: Optional[Union[os.PathLike, str]] = None,
    device: str = None,
    segmentation_mode: Optional[Literal["amg", "ais", "apg"]] = None,
    is_tiled: bool = False,
    predictor=None,
    state=None,
    **kwargs,
) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]:
    """Get the Segment Anything model and class for automatic instance segmentation.

    Args:
        model_type: The Segment Anything model choice.
        checkpoint: The filepath to the stored model checkpoints.
        device: The torch device. By default, automatically chooses the best available device.
        segmentation_mode: The segmentation mode. One of 'amg', 'ais', or 'apg' (case-insensitive).
            By default (or if 'auto' is passed), the mode `DEFAULT_SEGMENTATION_MODE_WITH_DECODER`
            is used if the model contains a decoder, otherwise 'amg' is used.
        is_tiled: Whether to return segmenter for performing segmentation in tiling window style.
            By default, set to 'False'.
        predictor: The pre-loaded predictor (optional). Has to be passed together with 'state'.
        state: The pre-loaded state (optional).
        kwargs: Keyword arguments for the automatic mask generation class.

    Returns:
        The Segment Anything model.
        The automatic instance segmentation class.

    Raises:
        ValueError: If a pre-loaded 'predictor' is passed without its 'state'.
        RuntimeError: If a decoder-based mode is requested but the model does not contain a decoder.
    """
    # Get the predictor and state for Segment Anything Model.
    if predictor is None:
        device = util.get_device(device=device)
        predictor, state = util.get_sam_model(
            model_type=model_type, device=device, checkpoint_path=checkpoint, return_state=True
        )
    elif state is None:
        # The state is needed to check whether the model contains a decoder.
        raise ValueError("The model 'state' has to be passed together with a pre-loaded 'predictor'.")

    if segmentation_mode is None or segmentation_mode.lower() == "auto":
        segmentation_mode = DEFAULT_SEGMENTATION_MODE_WITH_DECODER if "decoder_state" in state else "amg"
    segmentation_mode = segmentation_mode.lower()

    if segmentation_mode == "amg":
        decoder = None
    else:
        if "decoder_state" not in state:
            raise RuntimeError(
                f"You have passed 'segmentation_mode={segmentation_mode}', but your model does not contain a decoder."
            )
        decoder_state = state["decoder_state"]
        decoder = get_decoder(image_encoder=predictor.model.image_encoder, decoder_state=decoder_state, device=device)

    segmenter = get_instance_segmentation_generator(
        predictor=predictor, is_tiled=is_tiled, decoder=decoder, segmentation_mode=segmentation_mode, **kwargs
    )
    return predictor, segmenter
def
automatic_tracking( predictor: mobile_sam.predictor.SamPredictor, segmenter: Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder], input_path: Union[os.PathLike, str, numpy.ndarray], output_path: Union[str, os.PathLike, NoneType] = None, embedding_path: Union[str, os.PathLike, NoneType] = None, key: Optional[str] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = True, return_embeddings: bool = False, annotate: bool = False, batch_size: int = 1, **generate_kwargs) -> Tuple[numpy.ndarray, List[Dict]]:
def automatic_tracking(
    predictor: util.SamPredictor,
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    input_path: Union[Union[os.PathLike, str], np.ndarray],
    output_path: Optional[Union[os.PathLike, str]] = None,
    embedding_path: Optional[Union[os.PathLike, str]] = None,
    key: Optional[str] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    annotate: bool = False,
    batch_size: int = 1,
    **generate_kwargs
) -> Tuple[np.ndarray, List[Dict]]:
    """Run automatic tracking for the input timeseries.

    Args:
        predictor: The Segment Anything model.
        segmenter: The automatic instance segmentation class.
        input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
            or a container file (e.g. hdf5 or zarr). An already loaded array is also accepted.
        output_path: The folder where the tracking outputs will be saved in CTC format.
        embedding_path: The path where the embeddings are cached already / will be saved.
        key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
            or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings.
            By default, does not return the embeddings.
        annotate: Whether to activate the annotator to continue the annotation process.
            This is currently not supported for tracking; passing 'True' raises a NotImplementedError.
        batch_size: The batch size to compute image embeddings over tiles / z-planes.
            By default, does it sequentially, i.e. one after the other.
        generate_kwargs: Optional keyword arguments for the generate function of the AMG, APG, or AIS class.
            The tracking parameters 'gap_closing' and 'min_time_extent' may also be given here; they are passed
            to `automatic_tracking_implementation` explicitly and removed from the generate kwargs.

    Returns:
        The tracking result as a timeseries, where each object is labeled by its track id.
        The lineages representing cell divisions, stored as a dictionary.
        The image embeddings, only if 'return_embeddings' is True.

    Raises:
        NotImplementedError: If 'annotate' is True.
        ValueError: If the input is neither a 3d array nor a 4d array with 3 channels in the last axis.
    """
    # Fail before the (expensive) tracking run instead of after it.
    if annotate:
        # TODO We need to support initialization of the tracking annotator with the tracking result for this.
        raise NotImplementedError("Annotation after running the automated tracking is currently not supported.")

    # Load the input image file.
    # We assume that it has to be read from file if it is a str or pathlike.
    # Otherwise we assume it is a numpy array like object.
    image_data = util.load_image_data(input_path, key) if isinstance(input_path, (str, os.PathLike)) else input_path

    # Accept a grayscale timeseries (T, Y, X) or an RGB timeseries (T, Y, X, 3).
    if not (image_data.ndim == 3 or (image_data.ndim == 4 and image_data.shape[-1] == 3)):
        raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}")

    # Pop (not get) the tracking parameters: leaving them in 'generate_kwargs' would pass them twice below.
    gap_closing = generate_kwargs.pop("gap_closing", None)
    min_time_extent = generate_kwargs.pop("min_time_extent", None)
    segmentation, lineage, image_embeddings = automatic_tracking_implementation(
        image_data,
        predictor,
        segmenter,
        embedding_path=embedding_path,
        gap_closing=gap_closing,
        min_time_extent=min_time_extent,
        tile_shape=tile_shape,
        halo=halo,
        verbose=verbose,
        batch_size=batch_size,
        return_embeddings=True,
        output_folder=output_path,
        **generate_kwargs,
    )

    if return_embeddings:
        return segmentation, lineage, image_embeddings
    else:
        return segmentation, lineage
Run automatic tracking for the input timeseries.
Arguments:
- predictor: The Segment Anything model.
- segmenter: The automatic instance segmentation class.
- input_path: The input image file(s). Can either be a single image file (e.g. tif or png), or a container file (e.g. hdf5 or zarr).
- output_path: The folder where the tracking outputs will be saved in CTC format.
- embedding_path: The path where the embeddings are cached already / will be saved.
- key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr) or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
- tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
- halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
- verbose: Verbosity flag. By default, set to 'True'.
- return_embeddings: Whether to return the precomputed image embeddings. By default, does not return the embeddings.
- annotate: Whether to activate the annotator to continue the annotation process. By default, does not activate the annotator. Annotation after tracking is currently not supported and raises a NotImplementedError.
- batch_size: The batch size to compute image embeddings over tiles / z-planes. By default, does it sequentially, i.e. one after the other.
- generate_kwargs: optional keyword arguments for the generate function of the AMG, APG, or AIS class.
Returns:
The tracking result as a timeseries, where each object is labeled by its track id, and the lineages representing cell divisions, stored as a dictionary. If return_embeddings is set, the image embeddings are returned as a third value.
def
automatic_instance_segmentation( predictor: mobile_sam.predictor.SamPredictor, segmenter: Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder], input_path: Union[os.PathLike, str, numpy.ndarray], output_path: Union[str, os.PathLike, NoneType] = None, embedding_path: Union[str, os.PathLike, NoneType] = None, mask_path: Union[os.PathLike, str, numpy.ndarray, NoneType] = None, key: Optional[str] = None, mask_key: Optional[str] = None, ndim: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = True, return_embeddings: bool = False, annotate: bool = False, batch_size: int = 1, **generate_kwargs) -> numpy.ndarray:
def automatic_instance_segmentation(
    predictor: util.SamPredictor,
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    input_path: Union[Union[os.PathLike, str], np.ndarray],
    output_path: Optional[Union[os.PathLike, str]] = None,
    embedding_path: Optional[Union[os.PathLike, str]] = None,
    mask_path: Optional[Union[Union[os.PathLike, str], np.ndarray]] = None,
    key: Optional[str] = None,
    mask_key: Optional[str] = None,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    annotate: bool = False,
    batch_size: int = 1,
    **generate_kwargs
) -> np.ndarray:
    """Run automatic segmentation for the input image.

    Args:
        predictor: The Segment Anything model.
        segmenter: The automatic instance segmentation class.
        input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
            or a container file (e.g. hdf5 or zarr). A numpy array is used as-is.
        output_path: The output path where the instance segmentations will be saved.
            The suffix is always replaced by '.tif'.
        embedding_path: The path where the embeddings are cached already / will be saved.
        mask_path: The path to an optional foreground mask. Areas outside of the foreground will not be processed.
            Only supported for 2d segmentation.
        key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
            or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
        mask_key: The key to the (optional) foreground mask.
        ndim: The dimensionality of the data. By default the dimensionality of the data will be used.
            If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings.
            By default, does not return the embeddings.
        annotate: Whether to open the annotator afterwards to continue the annotation process.
            By default, does not activate the annotator.
        batch_size: The batch size to compute image embeddings over tiles / z-planes.
            By default, does it sequentially, i.e. one after the other.
        generate_kwargs: Optional keyword arguments for the generate function of the AMG, APG, or AIS class.

    Returns:
        The segmentation result, or a tuple of the segmentation and the image embeddings
        if 'return_embeddings' is set. If a segmentation already exists at 'output_path',
        nothing is computed and None is returned.

    Raises:
        ValueError: If the input shape does not match the requested dimensionality.
        NotImplementedError: If a foreground mask is passed for 3d segmentation.
    """
    # Avoid overwriting already stored segmentations.
    if output_path is not None:
        output_path = Path(output_path).with_suffix(".tif")
        if os.path.exists(output_path):
            print(f"The segmentation results are already stored at '{os.path.abspath(output_path)}'.")
            return

    # We assume that it has to be read from file if it is a str or pathlike.
    # Otherwise we assume it is a numpy array like object.
    image_data = util.load_image_data(input_path, key) if isinstance(input_path, (str, os.PathLike)) else input_path

    ndim = image_data.ndim if ndim is None else ndim

    # Load the mask defining foreground if it was given.
    if mask_path is None:
        mask = None
    else:
        mask = util.load_image_data(mask_path, mask_key) if isinstance(mask_path, (str, os.PathLike)) else mask_path

    if ndim == 2:
        # Accept a 2d image, or a 3d array with 3 (RGB) channels in the last axis.
        if image_data.ndim != 2 and not (image_data.ndim == 3 and image_data.shape[-1] == 3):
            raise ValueError(f"The inputs does not match the shape expectation of 2d inputs: {image_data.shape}")

        # Precompute the image embeddings.
        image_embeddings = util.precompute_image_embeddings(
            predictor=predictor,
            input_=image_data,
            save_path=embedding_path,
            ndim=ndim,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
            batch_size=batch_size,
            mask=mask,
        )
        initialize_kwargs = dict(image=image_data, image_embeddings=image_embeddings, verbose=verbose)
        if mask is not None:
            initialize_kwargs["mask"] = mask

        # If we run AIS with tiling then we use the same tile shape for the watershed postprocessing.
        # In this case, we also add the batch size to the initialize kwargs,
        # so that the segmentation decoder can be applied in a batched fashion.
        if isinstance(segmenter, InstanceSegmentationWithDecoder) and tile_shape is not None:
            if not isinstance(segmenter, TiledAutomaticPromptGenerator):
                generate_kwargs.update({"tile_shape": tile_shape, "halo": halo})
            initialize_kwargs["batch_size"] = batch_size

        segmenter.initialize(**initialize_kwargs)
        instances = segmenter.generate(**generate_kwargs)

    else:
        # Accept a 3d volume, or a 4d array with 3 (RGB) channels in the last axis.
        if image_data.ndim != 3 and not (image_data.ndim == 4 and image_data.shape[-1] == 3):
            raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}")
        if mask is not None:
            raise NotImplementedError("A foreground mask is currently not supported for 3d segmentation.")

        instances, image_embeddings = automatic_3d_segmentation(
            volume=image_data,
            predictor=predictor,
            segmentor=segmenter,
            embedding_path=embedding_path,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
            return_embeddings=True,
            batch_size=batch_size,
            **generate_kwargs
        )

    # Before starting to annotate, if at all desired, store the automatic segmentations in the first stage.
    if output_path is not None:
        _output_path = _add_suffix_to_output_path(output_path, "_automatic") if annotate else output_path
        imageio.imwrite(_output_path, instances, compression="zlib")
        if verbose:
            print(f"The automatic segmentation results are stored at '{os.path.abspath(_output_path)}'.")

    # Allow opening the automatic segmentation in the annotator for further annotation, if desired.
    if annotate:
        from micro_sam.sam_annotator import annotator_2d, annotator_3d
        annotator_function = annotator_2d if ndim == 2 else annotator_3d

        viewer = annotator_function(
            image=image_data,
            model_type=predictor.model_name,
            embedding_path=image_embeddings,  # Providing the precomputed image embeddings.
            segmentation_result=instances,  # Initializes the automatic segmentation to the annotator.
            tile_shape=tile_shape,
            halo=halo,
            return_viewer=True,  # Returns the viewer, which allows the user to store the updated segmentations.
        )

        # Start the GUI here; this blocks until the viewer is closed.
        import napari
        napari.run()

        # We extract the segmentation in "committed_objects" layer, where the user either:
        # a) Performed interactive segmentation / corrections and committed them, OR
        # b) Did not do anything and closed the annotator, i.e. keeps the segmentations as it is.
        instances = viewer.layers["committed_objects"].data

        # Save the instance segmentation, if 'output_path' provided.
        if output_path is not None:
            imageio.imwrite(output_path, instances, compression="zlib")
            if verbose:
                print(f"The final segmentation results are stored at '{os.path.abspath(output_path)}'.")

    if return_embeddings:
        return instances, image_embeddings
    return instances
Run automatic segmentation for the input image.
Arguments:
- predictor: The Segment Anything model.
- segmenter: The automatic instance segmentation class.
- input_path: The input image file(s). Can either be a single image file (e.g. tif or png), or a container file (e.g. hdf5 or zarr).
- output_path: The output path where the instance segmentations will be saved.
- embedding_path: The path where the embeddings are cached already / will be saved.
- mask_path: The path to an optional foreground mask. Areas outside of the foreground will not be processed.
- key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr) or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
- mask_key: The key to the (optional) foreground mask.
- ndim: The dimensionality of the data. By default the dimensionality of the data will be used. If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
- tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
- halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
- verbose: Verbosity flag. By default, set to 'True'.
- return_embeddings: Whether to return the precomputed image embeddings. By default, does not return the embeddings.
- annotate: Whether to open the annotator afterwards to continue the annotation process. By default, does not activate the annotator.
- batch_size: The batch size to compute image embeddings over tiles / z-planes. By default, does it sequentially, i.e. one after the other.
- generate_kwargs: Optional keyword arguments for the generate function of the AMG, APG, or AIS class.
Returns:
The segmentation result. If return_embeddings is set, the image embeddings are returned as well.