micro_sam.automatic_segmentation
```python
import os
import warnings
from glob import glob
from tqdm import tqdm
from pathlib import Path
from functools import partial
from typing import Dict, List, Optional, Union, Tuple

import numpy as np
import imageio.v3 as imageio

from torch_em.data.datasets.util import split_kwargs

from . import util
from .instance_segmentation import (
    get_instance_segmentation_generator, get_decoder, AMGBase,
    AutomaticMaskGenerator, TiledAutomaticMaskGenerator,
    AutomaticPromptGenerator, TiledAutomaticPromptGenerator,
    InstanceSegmentationWithDecoder, TiledInstanceSegmentationWithDecoder,
    DEFAULT_SEGMENTATION_MODE_WITH_DECODER,
)
from .multi_dimensional_segmentation import automatic_3d_segmentation, automatic_tracking_implementation


def get_predictor_and_segmenter(
    model_type: str,
    checkpoint: Optional[Union[os.PathLike, str]] = None,
    device: str = None,
    segmentation_mode: Optional[str] = None,
    is_tiled: bool = False,
    predictor=None,
    state=None,
    **kwargs,
) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]:
    f"""Get the Segment Anything model and class for automatic instance segmentation.

    Args:
        model_type: The Segment Anything model choice.
        checkpoint: The filepath to the stored model checkpoints.
        device: The torch device. By default, automatically chooses the best available device.
        segmentation_mode: The segmentation mode. One of 'amg', 'ais', or 'apg'.
            By default, '{DEFAULT_SEGMENTATION_MODE_WITH_DECODER}' is used
            if a decoder is passed, otherwise 'amg' is used.
        is_tiled: Whether to return the segmenter for tiled segmentation. By default, set to 'False'.
        predictor: The pre-loaded predictor (optional).
        state: The pre-loaded state (optional).
        kwargs: Keyword arguments for the automatic mask generation class.

    Returns:
        The Segment Anything model.
        The automatic instance segmentation class.
    """
    # Get the predictor and state for the Segment Anything model.
    if predictor is None:
        device = util.get_device(device=device)
        predictor, state = util.get_sam_model(
            model_type=model_type, device=device, checkpoint_path=checkpoint, return_state=True
        )
    else:
        assert state is not None

    if segmentation_mode in (None, "auto"):
        segmentation_mode = DEFAULT_SEGMENTATION_MODE_WITH_DECODER if "decoder_state" in state else "amg"

    if segmentation_mode.lower() == "amg":
        decoder = None
    else:
        if "decoder_state" not in state:
            raise RuntimeError(
                f"You have passed 'segmentation_mode={segmentation_mode}', but your model does not contain a decoder."
            )
        decoder_state = state["decoder_state"]
        decoder = get_decoder(image_encoder=predictor.model.image_encoder, decoder_state=decoder_state, device=device)

    segmenter = get_instance_segmentation_generator(
        predictor=predictor, is_tiled=is_tiled, decoder=decoder, segmentation_mode=segmentation_mode, **kwargs
    )
    return predictor, segmenter


def _add_suffix_to_output_path(output_path: Union[str, os.PathLike], suffix: str) -> str:
    fpath = Path(output_path).resolve()
    fext = fpath.suffix if fpath.suffix else ".tif"
    return str(fpath.with_name(f"{fpath.stem}{suffix}{fext}"))


def automatic_tracking(
    predictor: util.SamPredictor,
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    input_path: Union[Union[os.PathLike, str], np.ndarray],
    output_path: Optional[Union[os.PathLike, str]] = None,
    embedding_path: Optional[Union[os.PathLike, str]] = None,
    key: Optional[str] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    annotate: bool = False,
    batch_size: int = 1,
    **generate_kwargs
) -> Tuple[np.ndarray, List[Dict]]:
    """Run automatic tracking for the input timeseries.

    Args:
        predictor: The Segment Anything model.
        segmenter: The automatic instance segmentation class.
        input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
            or a container file (e.g. hdf5 or zarr).
        output_path: The folder where the tracking outputs will be saved in CTC format.
        embedding_path: The path where the embeddings are already cached / will be saved.
        key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
            or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings.
            By default, does not return the embeddings.
        annotate: Whether to activate the annotator for continuing the annotation process.
            By default, does not activate the annotator.
        batch_size: The batch size for computing image embeddings over tiles / z-planes.
            By default, does it sequentially, i.e. one after the other.
        generate_kwargs: Optional keyword arguments for the generate function of the AMG, APG, or AIS class.

    Returns:
        The tracking result as a timeseries, where each object is labeled by its track id.
        The lineages representing cell divisions, stored as a dictionary.
    """
    # Load the input image file.
    if isinstance(input_path, np.ndarray):
        image_data = input_path
    else:
        image_data = util.load_image_data(input_path, key)

    if (image_data.ndim != 3) and (image_data.ndim != 4 and image_data.shape[-1] != 3):
        raise ValueError(f"The input does not match the shape expectation for 3d inputs: {image_data.shape}")

    gap_closing, min_time_extent = generate_kwargs.get("gap_closing"), generate_kwargs.get("min_time_extent")
    segmentation, lineage, image_embeddings = automatic_tracking_implementation(
        image_data,
        predictor,
        segmenter,
        embedding_path=embedding_path,
        gap_closing=gap_closing,
        min_time_extent=min_time_extent,
        tile_shape=tile_shape,
        halo=halo,
        verbose=verbose,
        batch_size=batch_size,
        return_embeddings=True,
        output_folder=output_path,
        **generate_kwargs,
    )

    if annotate:
        # TODO We need to support initialization of the tracking annotator with the tracking result for this.
        raise NotImplementedError("Annotation after running the automated tracking is currently not supported.")

    if return_embeddings:
        return segmentation, lineage, image_embeddings
    else:
        return segmentation, lineage


def automatic_instance_segmentation(
    predictor: util.SamPredictor,
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    input_path: Union[Union[os.PathLike, str], np.ndarray],
    output_path: Optional[Union[os.PathLike, str]] = None,
    embedding_path: Optional[Union[os.PathLike, str]] = None,
    key: Optional[str] = None,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    annotate: bool = False,
    batch_size: int = 1,
    **generate_kwargs
) -> np.ndarray:
    """Run automatic segmentation for the input image.

    Args:
        predictor: The Segment Anything model.
        segmenter: The automatic instance segmentation class.
        input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
            or a container file (e.g. hdf5 or zarr).
        output_path: The output path where the instance segmentations will be saved.
        embedding_path: The path where the embeddings are already cached / will be saved.
        key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
            or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
        ndim: The dimensionality of the data. By default, it is derived from the input data.
            If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings.
            By default, does not return the embeddings.
        annotate: Whether to activate the annotator for continuing the annotation process.
            By default, does not activate the annotator.
        batch_size: The batch size for computing image embeddings over tiles / z-planes.
            By default, does it sequentially, i.e. one after the other.
        generate_kwargs: Optional keyword arguments for the generate function of the AMG or AIS class.

    Returns:
        The segmentation result.
    """
    # Avoid overwriting already stored segmentations.
    if output_path is not None:
        output_path = Path(output_path).with_suffix(".tif")
        if os.path.exists(output_path):
            print(f"The segmentation results are already stored at '{os.path.abspath(output_path)}'.")
            return

    # Load the input image file.
    if isinstance(input_path, np.ndarray):
        image_data = input_path
    else:
        image_data = util.load_image_data(input_path, key)

    ndim = image_data.ndim if ndim is None else ndim

    if ndim == 2:
        if (image_data.ndim != 2) and (image_data.ndim != 3 and image_data.shape[-1] != 3):
            raise ValueError(f"The input does not match the shape expectation for 2d inputs: {image_data.shape}")

        # Precompute the image embeddings.
        image_embeddings = util.precompute_image_embeddings(
            predictor=predictor,
            input_=image_data,
            save_path=embedding_path,
            ndim=ndim,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
            batch_size=batch_size,
        )
        initialize_kwargs = dict(image=image_data, image_embeddings=image_embeddings, verbose=verbose)

        # If we run AIS with tiling then we use the same tile shape for the watershed postprocessing.
        # In this case, we also add the batch size to the initialize kwargs,
        # so that the segmentation decoder can be applied in a batched fashion.
        if isinstance(segmenter, InstanceSegmentationWithDecoder) and tile_shape is not None:
            generate_kwargs.update({"tile_shape": tile_shape, "halo": halo})
            initialize_kwargs["batch_size"] = batch_size

        segmenter.initialize(**initialize_kwargs)
        instances = segmenter.generate(**generate_kwargs)

    else:
        if (image_data.ndim != 3) and (image_data.ndim != 4 and image_data.shape[-1] != 3):
            raise ValueError(f"The input does not match the shape expectation for 3d inputs: {image_data.shape}")

        instances, image_embeddings = automatic_3d_segmentation(
            volume=image_data,
            predictor=predictor,
            segmentor=segmenter,
            embedding_path=embedding_path,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
            return_embeddings=True,
            batch_size=batch_size,
            **generate_kwargs
        )

    # Before starting to annotate, if at all desired, store the automatic segmentations from the first stage.
    if output_path is not None:
        _output_path = _add_suffix_to_output_path(output_path, "_automatic") if annotate else output_path
        imageio.imwrite(_output_path, instances, compression="zlib")
        if verbose:
            print(f"The automatic segmentation results are stored at '{os.path.abspath(_output_path)}'.")

    # Allow opening the automatic segmentation in the annotator for further annotation, if desired.
    if annotate:
        from micro_sam.sam_annotator import annotator_2d, annotator_3d
        annotator_function = annotator_2d if ndim == 2 else annotator_3d

        viewer = annotator_function(
            image=image_data,
            model_type=predictor.model_name,
            embedding_path=image_embeddings,  # Providing the precomputed image embeddings.
            segmentation_result=instances,  # Initializes the automatic segmentation in the annotator.
            tile_shape=tile_shape,
            halo=halo,
            return_viewer=True,  # Returns the viewer, which allows the user to store the updated segmentations.
        )

        # Start the GUI here.
        import napari
        napari.run()

        # We extract the segmentation from the "committed_objects" layer, where the user either:
        # a) performed interactive segmentation / corrections and committed them, OR
        # b) did not do anything and closed the annotator, i.e. kept the segmentation as it is.
        instances = viewer.layers["committed_objects"].data

        # Save the instance segmentation, if 'output_path' is provided.
        if output_path is not None:
            imageio.imwrite(output_path, instances, compression="zlib")
            if verbose:
                print(f"The final segmentation results are stored at '{os.path.abspath(output_path)}'.")

    if return_embeddings:
        return instances, image_embeddings
    else:
        return instances


def _get_inputs_from_paths(paths, pattern):
    """Get all filepaths from the given file and directory paths."""

    if isinstance(paths, str):
        paths = [paths]

    fpaths = []
    for path in paths:
        if os.path.isfile(path):  # It is just one filepath.
            fpaths.append(path)
        else:  # Otherwise, if the path is a directory, fetch all inputs provided with a pattern.
            assert pattern is not None, \
                f"You must provide a pattern to search for files in the directory: '{os.path.abspath(path)}'."
            fpaths.extend(glob(os.path.join(path, pattern)))

    return fpaths


def main():
    """@private"""
    import argparse

    available_models = list(util.get_model_names())
    available_models = ", ".join(available_models)

    parser = argparse.ArgumentParser(
        description="Run automatic segmentation or tracking for 2d, 3d or timeseries data.\n"
        "Either a single input file or multiple input files are supported. You can specify multiple files "
        "by either providing multiple filepaths to the '-i/--input_path' argument, or by providing an argument "
        "to '--pattern' to use a wildcard pattern ('*') for selecting multiple files.\n"
        "NOTE: for automatic 3d segmentation or tracking the data has to be stored as volume / timeseries, "
        "stacking individual tif images is not supported.\n"
        "Segmentation is performed using one of the three modes supported by micro_sam: \n"
        "automatic instance segmentation (AIS), automatic prompt generation (APG) or automatic mask generation (AMG).\n"
        "In addition to the options listed below, "
        "you can also pass additional arguments for these three segmentation modes:\n"
        "For AIS: '--center_distance_threshold', '--boundary_distance_threshold' and other arguments of `InstanceSegmentationWithDecoder.generate`.\n"  # noqa
        "For APG: '--center_distance_threshold', '--boundary_distance_threshold' and other arguments of `AutomaticPromptGenerator.generate`.\n"  # noqa
        "For AMG: '--pred_iou_thresh', '--stability_score_thresh' and other arguments of `AutomaticMaskGenerator.generate`."  # noqa
    )
    parser.add_argument(
        "-i", "--input_path", required=True, type=str, nargs="+",
        help="The filepath(s) to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "  # noqa
        "or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter."
    )
    parser.add_argument(
        "-o", "--output_path", required=True, type=str,
        help="The filepath to store the results. If multiple inputs are provided, "
        "this should be a folder. For a single image, you should provide the path to a tif file for the output segmentation. "  # noqa
        "NOTE: Segmentation results are stored as tif files, tracking results in the CTC file format."
    )
    parser.add_argument(
        "-e", "--embedding_path", default=None, type=str,
        help="An optional path where the embeddings will be saved. If multiple inputs are provided, "
        "this should be a folder. Otherwise you can store the embeddings in a single zarr file."
    )
    parser.add_argument(
        "--pattern", type=str, help="Pattern / wildcard for selecting files in a folder. To select all files use '*'."
    )
    parser.add_argument(
        "-k", "--key", default=None, type=str,
        help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, "
        "for an image stack it is a wild-card, e.g. '*.png', and for mrc it is 'data'."
    )
    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL, type=str,
        help=f"The Segment Anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None, type=str, help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, help="The halo for using tiled prediction.", default=None
    )
    parser.add_argument(
        "-n", "--ndim", default=None, type=int,
        help="The number of spatial dimensions in the data. Please specify this if your data has a channel dimension."
    )
    parser.add_argument(
        "--mode", default="auto", type=str,
        help="The choice of automatic segmentation mode. Either 'auto', 'amg', 'apg', or 'ais'."
    )
    parser.add_argument(
        "--annotate", action="store_true",
        help="Whether to continue annotation after the automatic segmentation is generated."
    )
    parser.add_argument(
        "-d", "--device", default=None, type=str,
        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only Mac). "
        "By default the most performant available device will be selected."
    )
    parser.add_argument(
        "--batch_size", type=int, default=1,
        help="The batch size for computing image embeddings over tiles or z-planes. "
        "By default, computes the image embeddings for one tile / z-plane at a time."
    )
    parser.add_argument(
        "--tracking", action="store_true", help="Run automatic tracking instead of instance segmentation. "
        "NOTE: It is only supported for timeseries inputs."
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="Whether to print verbose output."
    )

    args, parameter_args = parser.parse_known_args()

    def _convert_argval(value):
        # The values for the parsed arguments need to be in the expected input structure as provided,
        # i.e. integers and floats should be in their original types.
        try:
            return int(value)
        except ValueError:
            return float(value)

    # NOTE: the script below allows the possibility to catch additional parsed arguments which correspond to
    # the automatic segmentation post-processing parameters (e.g. 'center_distance_threshold' in AIS).
    extra_kwargs = {
        parameter_args[i].lstrip("--"): _convert_argval(parameter_args[i + 1]) for i in range(0, len(parameter_args), 2)
    }

    # Separate the extra arguments according to where they should be passed in the automatic segmentation class.
    # This is done to ensure the extra arguments are allocated to the desired location,
    # e.g. for AMG, 'points_per_side' is expected by '__init__',
    # and 'stability_score_thresh' is expected by the 'generate' method.
    mode = args.mode
    if mode in ("auto", None):
        # We have to load the state to see if we have a decoder in this case.
        device = util.get_device(device=args.device)
        predictor, state = util.get_sam_model(
            model_type=args.model_type, device=device, checkpoint_path=args.checkpoint, return_state=True
        )
        mode = DEFAULT_SEGMENTATION_MODE_WITH_DECODER if "decoder_state" in state else "amg"
    else:
        predictor, state = None, None

    if mode.lower() == "amg":
        segmenter_class = AutomaticMaskGenerator if args.tile_shape is None else TiledAutomaticMaskGenerator
    elif mode.lower() == "ais":
        segmenter_class = InstanceSegmentationWithDecoder if args.tile_shape is None else\
            TiledInstanceSegmentationWithDecoder
    elif mode.lower() == "apg":
        segmenter_class = AutomaticPromptGenerator if args.tile_shape is None else TiledAutomaticPromptGenerator
    else:
        raise ValueError(f"Invalid segmentation_mode: {mode}. Choose one of 'amg', 'ais', or 'apg'.")
    init_kwargs, generate_kwargs = split_kwargs(segmenter_class, **extra_kwargs)

    predictor, segmenter = get_predictor_and_segmenter(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        device=args.device,
        segmentation_mode=mode,
        is_tiled=args.tile_shape is not None,
        predictor=predictor,
        state=state,
        **init_kwargs,
    )

    # Get the filepaths to the input images (and other paths for storing outputs, e.g. segmentations and embeddings).
    # Check whether the inputs are as expected, otherwise sort them accordingly.
    input_paths = _get_inputs_from_paths(args.input_path, args.pattern)
    assert len(input_paths) > 0, "'micro-sam' could not extract any image data internally."

    output_path = args.output_path
    embedding_path = args.embedding_path
    has_one_input = len(input_paths) == 1

    instance_seg_function = automatic_tracking if args.tracking else partial(
        automatic_instance_segmentation, ndim=args.ndim
    )

    # Run automatic segmentation per image.
    for input_path in tqdm(input_paths, desc="Run automatic " + ("tracking" if args.tracking else "segmentation")):
        if has_one_input:  # When we have only one image / volume.
            _embedding_fpath = embedding_path  # Either a folder or a zarr file works for this case.

            output_fdir = os.path.splitext(output_path)[0]
            os.makedirs(output_fdir, exist_ok=True)

            # For tracking, we ensure that the output path is a folder,
            # i.e. does not have an extension. We throw a warning if the user provided an extension.
            if args.tracking:
                if os.path.splitext(output_path)[-1]:
                    warnings.warn(
                        f"The output folder has an extension '{os.path.splitext(output_path)[-1]}'. "
                        "We remove it and treat it as a folder to store tracking outputs in CTC format."
                    )
                _output_fpath = output_fdir
            else:  # Otherwise, we can store outputs for the user directly at the provided filepath, ensuring the extension .tif.
                _output_fpath = f"{output_fdir}.tif"

        else:  # When we have multiple images.
            # Get the input filename, without the extension.
            input_name = str(Path(input_path).stem)

            # Let's check the 'embedding_path'.
            if embedding_path is None:  # For computing embeddings on-the-fly, we don't care about the path logic.
                _embedding_fpath = embedding_path
            else:  # Otherwise, store each embedding inside a folder.
                embedding_folder = os.path.splitext(embedding_path)[0]  # Treat the provided embedding path as a folder.
                os.makedirs(embedding_folder, exist_ok=True)
                _embedding_fpath = os.path.join(embedding_folder, f"{input_name}.zarr")  # Create each embedding file.

            # Get the output folder name.
            output_folder = os.path.splitext(output_path)[0]
            os.makedirs(output_folder, exist_ok=True)

            # Next, let's check for the output file to store the segmentation (or tracks).
            if args.tracking:  # For tracking, we store CTC outputs in subfolders, with input_name as the folder name.
                _output_fpath = os.path.join(output_folder, input_name)
            else:  # Otherwise, store each result inside a folder.
                _output_fpath = os.path.join(output_folder, f"{input_name}.tif")

        instance_seg_function(
            predictor=predictor,
            segmenter=segmenter,
            input_path=input_path,
            output_path=_output_fpath,
            embedding_path=_embedding_fpath,
            key=args.key,
            tile_shape=args.tile_shape,
            halo=args.halo,
            annotate=args.annotate,
            verbose=args.verbose,
            batch_size=args.batch_size,
            **generate_kwargs,
        )
```
```python
def get_predictor_and_segmenter(
    model_type: str,
    checkpoint: Union[os.PathLike, str, NoneType] = None,
    device: str = None,
    segmentation_mode: Optional[str] = None,
    is_tiled: bool = False,
    predictor=None,
    state=None,
    **kwargs,
) -> Tuple[mobile_sam.predictor.SamPredictor,
           Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder]]:
```

Get the Segment Anything model and class for automatic instance segmentation.

Arguments:
- model_type: The Segment Anything model choice.
- checkpoint: The filepath to the stored model checkpoints.
- device: The torch device. By default, automatically chooses the best available device.
- segmentation_mode: The segmentation mode. One of 'amg', 'ais', or 'apg'. By default, DEFAULT_SEGMENTATION_MODE_WITH_DECODER is used if a decoder is passed, otherwise 'amg' is used.
- is_tiled: Whether to return the segmenter for tiled segmentation. By default, set to 'False'.
- predictor: The pre-loaded predictor (optional).
- state: The pre-loaded state (optional).
- kwargs: Keyword arguments for the automatic mask generation class.

Returns:
The Segment Anything model. The automatic instance segmentation class.
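A minimal usage sketch (not part of the library): load a model once and build a matching segmenter for reuse across several images. The model name "vit_b_lm" and the parameter choices are illustrative assumptions.

```python
# Sketch: build a predictor and segmenter pair. "vit_b_lm" is an assumed
# model name; substitute any model listed by micro_sam.util.get_model_names().
from micro_sam.automatic_segmentation import get_predictor_and_segmenter

predictor, segmenter = get_predictor_and_segmenter(
    model_type="vit_b_lm",     # assumed: a model that ships with a segmentation decoder
    device=None,               # pick the best available device automatically
    segmentation_mode="auto",  # use the decoder-based mode if available, otherwise 'amg'
    is_tiled=False,            # set to True when you plan to segment with tiling
)
```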
```python
def automatic_tracking(
    predictor: mobile_sam.predictor.SamPredictor,
    segmenter: Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder],
    input_path: Union[os.PathLike, str, numpy.ndarray],
    output_path: Union[os.PathLike, str, NoneType] = None,
    embedding_path: Union[os.PathLike, str, NoneType] = None,
    key: Optional[str] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    annotate: bool = False,
    batch_size: int = 1,
    **generate_kwargs
) -> Tuple[numpy.ndarray, List[Dict]]:
```
Run automatic tracking for the input timeseries.
Arguments:
- predictor: The Segment Anything model.
- segmenter: The automatic instance segmentation class.
- input_path: The input image file(s). Can either be a single image file (e.g. tif or png), or a container file (e.g. hdf5 or zarr).
- output_path: The folder where the tracking outputs will be saved in CTC format.
- embedding_path: The path where the embeddings are already cached / will be saved.
- key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr) or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
- tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
- halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
- verbose: Verbosity flag. By default, set to 'True'.
- return_embeddings: Whether to return the precomputed image embeddings. By default, does not return the embeddings.
- annotate: Whether to activate the annotator for continuing the annotation process. By default, does not activate the annotator.
- batch_size: The batch size for computing image embeddings over tiles / z-planes. By default, does it sequentially, i.e. one after the other.
- generate_kwargs: Optional keyword arguments for the generate function of the AMG, APG, or AIS class.
Returns:
The tracking result as a timeseries, where each object is labeled by its track id. The lineages representing cell divisions, stored as a dictionary.
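A hedged usage sketch: run automatic tracking on a timeseries and store the result in CTC format. The file paths and the model name are placeholder assumptions; the input is loaded as a numpy array here, which the function accepts directly.

```python
# Sketch: automatic tracking of a 2d timeseries with shape (T, Y, X).
# "timeseries.tif", "tracking_results" and "embeddings.zarr" are placeholders.
import imageio.v3 as imageio
from micro_sam.automatic_segmentation import get_predictor_and_segmenter, automatic_tracking

predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm")  # assumed model name
timeseries = imageio.imread("timeseries.tif")

segmentation, lineages = automatic_tracking(
    predictor=predictor,
    segmenter=segmenter,
    input_path=timeseries,            # numpy arrays are accepted directly
    output_path="tracking_results",   # folder for the CTC-format tracking output
    embedding_path="embeddings.zarr", # cache the embeddings so reruns are fast
)
```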
```python
def automatic_instance_segmentation(
    predictor: mobile_sam.predictor.SamPredictor,
    segmenter: Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder],
    input_path: Union[os.PathLike, str, numpy.ndarray],
    output_path: Union[os.PathLike, str, NoneType] = None,
    embedding_path: Union[os.PathLike, str, NoneType] = None,
    key: Optional[str] = None,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    annotate: bool = False,
    batch_size: int = 1,
    **generate_kwargs
) -> numpy.ndarray:
```
Run automatic segmentation for the input image.
Arguments:
- predictor: The Segment Anything model.
- segmenter: The automatic instance segmentation class.
- input_path: The input image file(s). Can either be a single image file (e.g. tif or png), or a container file (e.g. hdf5 or zarr).
- output_path: The output path where the instance segmentations will be saved.
- embedding_path: The path where the embeddings are already cached / will be saved.
- key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr) or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
- ndim: The dimensionality of the data. By default, it is derived from the input data. If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
- tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
- halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
- verbose: Verbosity flag. By default, set to 'True'.
- return_embeddings: Whether to return the precomputed image embeddings. By default, does not return the embeddings.
- annotate: Whether to activate the annotator for continuing the annotation process. By default, does not activate the annotator.
- batch_size: The batch size for computing image embeddings over tiles / z-planes. By default, does it sequentially, i.e. one after the other.
- generate_kwargs: Optional keyword arguments for the generate function of the AMG or AIS class.
Returns:
The segmentation result.
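A hedged usage sketch: 2d instance segmentation of a large image with tiling, passing one AIS post-processing parameter through generate_kwargs ('center_distance_threshold' is named above as an argument of `InstanceSegmentationWithDecoder.generate`). All paths, the model name, and the parameter values are placeholder assumptions.

```python
# Sketch: tiled 2d segmentation of a single image. "image.tif",
# "segmentation.tif" and the tile sizes are placeholders.
from micro_sam.automatic_segmentation import get_predictor_and_segmenter, automatic_instance_segmentation

# is_tiled=True so that the returned segmenter matches the tiled embedding computation.
predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm", is_tiled=True)

instances = automatic_instance_segmentation(
    predictor=predictor,
    segmenter=segmenter,
    input_path="image.tif",
    output_path="segmentation.tif",
    ndim=2,                         # required for RGB images, e.g. shape (Y, X, 3)
    tile_shape=(1024, 1024),        # tile size for embedding computation
    halo=(128, 128),                # overlap between adjacent tiles
    center_distance_threshold=0.5,  # forwarded to the AIS generate() function
)
```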