micro_sam.automatic_segmentation
import os
import warnings
from glob import glob
from tqdm import tqdm
from pathlib import Path
from functools import partial
from typing import Dict, List, Optional, Union, Tuple

import numpy as np
import imageio.v3 as imageio

from torch_em.data.datasets.util import split_kwargs

from . import util
from .instance_segmentation import (
    get_amg, get_decoder, mask_data_to_segmentation, InstanceSegmentationWithDecoder,
    AMGBase, AutomaticMaskGenerator, TiledAutomaticMaskGenerator
)
from .multi_dimensional_segmentation import automatic_3d_segmentation, automatic_tracking_implementation


def get_predictor_and_segmenter(
    model_type: str,
    checkpoint: Optional[Union[os.PathLike, str]] = None,
    device: str = None,
    amg: Optional[bool] = None,
    is_tiled: bool = False,
    **kwargs,
) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]:
    """Get the Segment Anything model and class for automatic instance segmentation.

    Args:
        model_type: The Segment Anything model choice.
        checkpoint: The filepath to the stored model checkpoint.
        device: The torch device. By default, automatically chooses the best available device.
        amg: Whether to perform automatic segmentation in AMG mode.
            Otherwise AIS will be used, which requires a special segmentation decoder.
            If not specified, AIS will be used if it is available and otherwise AMG will be used.
        is_tiled: Whether to return a segmenter for tiled segmentation. By default, set to 'False'.
        kwargs: Keyword arguments for the automatic mask generation class.

    Returns:
        The Segment Anything model.
        The automatic instance segmentation class.
    """
    # Get the device.
    device = util.get_device(device=device)

    # Get the predictor and state for the Segment Anything model.
    predictor, state = util.get_sam_model(
        model_type=model_type, device=device, checkpoint_path=checkpoint, return_state=True,
    )

    if amg is None:
        amg = "decoder_state" not in state

    if amg:
        decoder = None
    else:
        if "decoder_state" not in state:
            raise RuntimeError("You have passed 'amg=False', but your model does not contain a segmentation decoder.")
        decoder_state = state["decoder_state"]
        decoder = get_decoder(image_encoder=predictor.model.image_encoder, decoder_state=decoder_state, device=device)

    segmenter = get_amg(predictor=predictor, is_tiled=is_tiled, decoder=decoder, **kwargs)

    return predictor, segmenter


def _add_suffix_to_output_path(output_path: Union[str, os.PathLike], suffix: str) -> str:
    fpath = Path(output_path).resolve()
    fext = fpath.suffix if fpath.suffix else ".tif"
    return str(fpath.with_name(f"{fpath.stem}{suffix}{fext}"))


def automatic_tracking(
    predictor: util.SamPredictor,
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    input_path: Union[Union[os.PathLike, str], np.ndarray],
    output_path: Optional[Union[os.PathLike, str]] = None,
    embedding_path: Optional[Union[os.PathLike, str]] = None,
    key: Optional[str] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    annotate: bool = False,
    batch_size: int = 1,
    **generate_kwargs
) -> Tuple[np.ndarray, List[Dict]]:
    """Run automatic tracking for the input timeseries.

    Args:
        predictor: The Segment Anything model.
        segmenter: The automatic instance segmentation class.
        input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
            or a container file (e.g. hdf5 or zarr).
        output_path: The folder where the tracking outputs will be saved in CTC format.
        embedding_path: The path where the embeddings are cached already / will be saved.
        key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
            or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings.
            By default, does not return the embeddings.
        annotate: Whether to activate the annotator to continue the annotation process.
            By default, does not activate the annotator.
        batch_size: The batch size to compute image embeddings over tiles / z-planes.
            By default, does it sequentially, i.e. one after the other.
        generate_kwargs: Optional keyword arguments for the generate function of the AMG or AIS class.

    Returns:
        The tracking result as a timeseries, where each object is labeled by its track id.
        The lineages representing cell divisions, stored as a dictionary.
    """
    # Load the input image file.
    if isinstance(input_path, np.ndarray):
        image_data = input_path
    else:
        image_data = util.load_image_data(input_path, key)

    # We perform additional post-processing for AMG only.
    # Otherwise, we ignore additional post-processing for AIS.
    if isinstance(segmenter, InstanceSegmentationWithDecoder):
        generate_kwargs["output_mode"] = None

    if (image_data.ndim != 3) and (image_data.ndim != 4 and image_data.shape[-1] != 3):
        raise ValueError(f"The input does not match the shape expectation of 3d inputs: {image_data.shape}")

    gap_closing, min_time_extent = generate_kwargs.get("gap_closing"), generate_kwargs.get("min_time_extent")
    segmentation, lineage, image_embeddings = automatic_tracking_implementation(
        image_data,
        predictor,
        segmenter,
        embedding_path=embedding_path,
        gap_closing=gap_closing,
        min_time_extent=min_time_extent,
        tile_shape=tile_shape,
        halo=halo,
        verbose=verbose,
        batch_size=batch_size,
        return_embeddings=True,
        output_folder=output_path,
        **generate_kwargs,
    )

    if annotate:
        # TODO We need to support initialization of the tracking annotator with the tracking result for this.
        raise NotImplementedError("Annotation after running the automated tracking is currently not supported.")

    if return_embeddings:
        return segmentation, lineage, image_embeddings
    else:
        return segmentation, lineage


def automatic_instance_segmentation(
    predictor: util.SamPredictor,
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    input_path: Union[Union[os.PathLike, str], np.ndarray],
    output_path: Optional[Union[os.PathLike, str]] = None,
    embedding_path: Optional[Union[os.PathLike, str]] = None,
    key: Optional[str] = None,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    annotate: bool = False,
    batch_size: int = 1,
    **generate_kwargs
) -> np.ndarray:
    """Run automatic segmentation for the input image.

    Args:
        predictor: The Segment Anything model.
        segmenter: The automatic instance segmentation class.
        input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
            or a container file (e.g. hdf5 or zarr).
        output_path: The output path where the instance segmentations will be saved.
        embedding_path: The path where the embeddings are cached already / will be saved.
        key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
            or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
        ndim: The dimensionality of the data. By default the dimensionality of the data will be used.
            If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings.
            By default, does not return the embeddings.
        annotate: Whether to activate the annotator to continue the annotation process.
            By default, does not activate the annotator.
        batch_size: The batch size to compute image embeddings over tiles / z-planes.
            By default, does it sequentially, i.e. one after the other.
        generate_kwargs: Optional keyword arguments for the generate function of the AMG or AIS class.

    Returns:
        The segmentation result.
    """
    # Avoid overwriting already stored segmentations.
    if output_path is not None:
        output_path = Path(output_path).with_suffix(".tif")
        if os.path.exists(output_path):
            print(f"The segmentation results are already stored at '{os.path.abspath(output_path)}'.")
            return

    # Load the input image file.
    if isinstance(input_path, np.ndarray):
        image_data = input_path
    else:
        image_data = util.load_image_data(input_path, key)

    ndim = image_data.ndim if ndim is None else ndim

    # We perform additional post-processing for AMG only.
    # Otherwise, we ignore additional post-processing for AIS.
    if isinstance(segmenter, InstanceSegmentationWithDecoder):
        generate_kwargs["output_mode"] = None

    if ndim == 2:
        if (image_data.ndim != 2) and (image_data.ndim != 3 and image_data.shape[-1] != 3):
            raise ValueError(f"The input does not match the shape expectation of 2d inputs: {image_data.shape}")

        # Precompute the image embeddings.
        image_embeddings = util.precompute_image_embeddings(
            predictor=predictor,
            input_=image_data,
            save_path=embedding_path,
            ndim=ndim,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
            batch_size=batch_size,
        )
        initialize_kwargs = dict(image=image_data, image_embeddings=image_embeddings, verbose=verbose)

        # If we run AIS with tiling then we use the same tile shape for the watershed postprocessing.
        # In this case, we also add the batch size to the initialize kwargs,
        # so that the segmentation decoder can be applied in a batched fashion.
        if isinstance(segmenter, InstanceSegmentationWithDecoder) and tile_shape is not None:
            generate_kwargs.update({"tile_shape": tile_shape, "halo": halo})
            initialize_kwargs["batch_size"] = batch_size

        segmenter.initialize(**initialize_kwargs)
        masks = segmenter.generate(**generate_kwargs)

        if isinstance(masks, list):
            # The predictions from 'generate' are a list of dicts,
            # which contain additional info required for post-processing, e.g. the area per object.
            if len(masks) == 0:
                instances = np.zeros(image_data.shape[:2], dtype="uint32")
            else:
                instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0)
        else:
            # If (raw) predictions are provided, store them as they are, without further post-processing.
            instances = masks

    else:
        if (image_data.ndim != 3) and (image_data.ndim != 4 and image_data.shape[-1] != 3):
            raise ValueError(f"The input does not match the shape expectation of 3d inputs: {image_data.shape}")

        instances, image_embeddings = automatic_3d_segmentation(
            volume=image_data,
            predictor=predictor,
            segmentor=segmenter,
            embedding_path=embedding_path,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
            return_embeddings=True,
            batch_size=batch_size,
            **generate_kwargs
        )

    # Before starting to annotate, if at all desired, store the automatic segmentation of the first stage.
    if output_path is not None:
        _output_path = _add_suffix_to_output_path(output_path, "_automatic") if annotate else output_path
        imageio.imwrite(_output_path, instances, compression="zlib")
        if verbose:
            print(f"The automatic segmentation results are stored at '{os.path.abspath(_output_path)}'.")

    # Allow opening the automatic segmentation in the annotator for further annotation, if desired.
    if annotate:
        from micro_sam.sam_annotator import annotator_2d, annotator_3d
        annotator_function = annotator_2d if ndim == 2 else annotator_3d

        viewer = annotator_function(
            image=image_data,
            model_type=predictor.model_name,
            embedding_path=image_embeddings,  # Providing the precomputed image embeddings.
            segmentation_result=instances,  # Initializes the automatic segmentation for the annotator.
            tile_shape=tile_shape,
            halo=halo,
            return_viewer=True,  # Returns the viewer, which allows the user to store the updated segmentations.
        )

        # Start the GUI here.
        import napari
        napari.run()

        # We extract the segmentation from the "committed_objects" layer, where the user either:
        # a) performed interactive segmentation / corrections and committed them, OR
        # b) did not do anything and closed the annotator, i.e. kept the segmentation as it is.
        instances = viewer.layers["committed_objects"].data

        # Save the instance segmentation, if 'output_path' is provided.
        if output_path is not None:
            imageio.imwrite(output_path, instances, compression="zlib")
            if verbose:
                print(f"The final segmentation results are stored at '{os.path.abspath(output_path)}'.")

    if return_embeddings:
        return instances, image_embeddings
    else:
        return instances


def _get_inputs_from_paths(paths, pattern):
    """Function to get all filepaths in a directory."""

    if isinstance(paths, str):
        paths = [paths]

    fpaths = []
    for path in paths:
        if os.path.isfile(path):  # It is just one filepath.
            fpaths.append(path)
        else:  # Otherwise, if the path is a directory, fetch all inputs provided with a pattern.
            assert pattern is not None, \
                f"You must provide a pattern to search for files in the directory: '{os.path.abspath(path)}'."
            fpaths.extend(glob(os.path.join(path, pattern)))

    return fpaths


def main():
    """@private"""
    import argparse

    available_models = list(util.get_model_names())
    available_models = ", ".join(available_models)

    parser = argparse.ArgumentParser(
        description="Run automatic segmentation or tracking for 2d, 3d or timeseries data.\n"
        "Either a single input file or multiple input files are supported. You can specify multiple files "
        "by either providing multiple filepaths to the '-i/--input_path' argument, or by providing an argument "
        "to '--pattern' to use a wildcard pattern ('*') for selecting multiple files.\n"
        "NOTE: for automatic 3d segmentation or tracking the data has to be stored as volume / timeseries, "
        "stacking individual tif images is not supported.\n"
        "Segmentation is performed using one of the two modes supported by micro_sam:\n"
        "automatic instance segmentation (AIS) or automatic mask generation (AMG).\n"
        "In addition to the options listed below, "
        "you can also pass additional arguments for these two segmentation modes:\n"
        "For AIS: '--center_distance_threshold', '--boundary_distance_threshold' and other arguments of `InstanceSegmentationWithDecoder.generate`.\n"  # noqa
        "For AMG: '--pred_iou_thresh', '--stability_score_thresh' and other arguments of `AutomaticMaskGenerator.generate`."  # noqa
    )
    parser.add_argument(
        "-i", "--input_path", required=True, type=str, nargs="+",
        help="The filepath(s) to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "  # noqa
        "or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter."
    )
    parser.add_argument(
        "-o", "--output_path", required=True, type=str,
        help="The filepath to store the results. If multiple inputs are provided, "
        "this should be a folder. For a single image, you should provide the path to a tif file for the output segmentation. "  # noqa
        "NOTE: Segmentation results are stored as tif files, tracking results in the CTC file format."
    )
    parser.add_argument(
        "-e", "--embedding_path", default=None, type=str,
        help="An optional path where the embeddings will be saved. If multiple inputs are provided, "
        "this should be a folder. Otherwise you can store the embeddings in a single zarr file."
    )
    parser.add_argument(
        "--pattern", type=str, help="Pattern / wildcard for selecting files in a folder. To select all files use '*'."
    )
    parser.add_argument(
        "-k", "--key", default=None, type=str,
        help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, "
        "for an image stack it is a wild-card, e.g. '*.png' and for mrc it is 'data'."
    )
    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL, type=str,
        help=f"The Segment Anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None, type=str, help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, help="The halo for using tiled prediction.", default=None
    )
    parser.add_argument(
        "-n", "--ndim", default=None, type=int,
        help="The number of spatial dimensions in the data. Please specify this if your data has a channel dimension."
    )
    parser.add_argument(
        "--mode", default="auto", type=str,
        help="The choice of automatic segmentation with the Segment Anything models. Either 'auto', 'amg' or 'ais'."
    )
    parser.add_argument(
        "--annotate", action="store_true",
        help="Whether to continue annotation after the automatic segmentation is generated."
    )
    parser.add_argument(
        "-d", "--device", default=None, type=str,
        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only Mac). "
        "By default the most performant available device will be selected."
    )
    parser.add_argument(
        "--batch_size", type=int, default=1,
        help="The batch size for computing image embeddings over tiles or z-planes. "
        "By default, computes the image embeddings for one tile / z-plane at a time."
    )
    parser.add_argument(
        "--tracking", action="store_true", help="Run automatic tracking instead of instance segmentation. "
        "NOTE: It is only supported for timeseries inputs."
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="Whether to allow verbosity of outputs."
    )

    args, parameter_args = parser.parse_known_args()

    def _convert_argval(value):
        # The values for the parsed arguments need to be in the expected input structure as provided,
        # i.e. integers and floats should be in their original types.
        try:
            return int(value)
        except ValueError:
            return float(value)

    # NOTE: the script below allows the possibility to catch additional parsed arguments which correspond to
    # the automatic segmentation post-processing parameters (e.g. 'center_distance_threshold' in AIS).
    extra_kwargs = {
        parameter_args[i].lstrip("--"): _convert_argval(parameter_args[i + 1]) for i in range(0, len(parameter_args), 2)
    }

    # Separate the extra arguments depending on where they should be passed in the automatic segmentation class.
    # This is done to ensure the extra arguments are allocated to the desired location,
    # e.g. for AMG, 'points_per_side' is expected by '__init__',
    # and 'stability_score_thresh' is expected by the 'generate' method.
    amg_class = AutomaticMaskGenerator if args.tile_shape is None else TiledAutomaticMaskGenerator
    amg_kwargs, generate_kwargs = split_kwargs(amg_class, **extra_kwargs)

    # Validate the expected automatic segmentation mode.
    # By default, it is set to 'auto', i.e. searches for the decoder state to prioritize AIS for finetuned models.
    # Otherwise, runs AMG for all models in any case.
    amg = None
    if args.mode != "auto":
        assert args.mode in ["ais", "amg"], \
            f"'{args.mode}' is not a valid automatic segmentation mode. Please choose either 'amg' or 'ais'."
        amg = (args.mode == "amg")

    predictor, segmenter = get_predictor_and_segmenter(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        device=args.device,
        amg=amg,
        is_tiled=args.tile_shape is not None,
        **amg_kwargs,
    )

    # Get the filepaths to the input images (and the other paths for storing outputs, e.g. segmentations and embeddings).
    # Check whether the inputs are as expected, otherwise sort them out.
    input_paths = _get_inputs_from_paths(args.input_path, args.pattern)
    assert len(input_paths) > 0, "'micro-sam' could not extract any image data internally."

    output_path = args.output_path
    embedding_path = args.embedding_path
    has_one_input = len(input_paths) == 1

    instance_seg_function = automatic_tracking if args.tracking else partial(
        automatic_instance_segmentation, ndim=args.ndim
    )

    # Run automatic segmentation per image.
    for input_path in tqdm(input_paths, desc="Run automatic " + ("tracking" if args.tracking else "segmentation")):
        if has_one_input:  # When we have only one image / volume.
            _embedding_fpath = embedding_path  # Either a folder or a zarr file works for both cases.

            output_fdir = os.path.splitext(output_path)[0]
            os.makedirs(output_fdir, exist_ok=True)

            # For tracking, we ensure that the output path is a folder,
            # i.e. does not have an extension. We throw a warning if the user provided an extension.
            if args.tracking:
                if os.path.splitext(output_path)[-1]:
                    warnings.warn(
                        f"The output folder has an extension '{os.path.splitext(output_path)[-1]}'. "
                        "We remove it and treat it as a folder to store tracking outputs in CTC format."
                    )
                _output_fpath = output_fdir
            else:  # Otherwise, we can store the output directly at the provided filepath, ensuring the extension .tif.
                _output_fpath = f"{output_fdir}.tif"

        else:  # When we have multiple images.
            # Get the input filename, without the extension.
            input_name = str(Path(input_path).stem)

            # Let's check the 'embedding_path'.
            if embedding_path is None:  # For computing embeddings on-the-fly, we don't care about the path logic.
                _embedding_fpath = embedding_path
            else:  # Otherwise, store each embedding inside a folder.
                embedding_folder = os.path.splitext(embedding_path)[0]  # Treat the provided embedding path as a folder.
                os.makedirs(embedding_folder, exist_ok=True)
                _embedding_fpath = os.path.join(embedding_folder, f"{input_name}.zarr")  # Create each embedding file.

            # Get the output folder name.
            output_folder = os.path.splitext(output_path)[0]
            os.makedirs(output_folder, exist_ok=True)

            # Next, let's check for the output file to store the segmentation (or tracks).
            if args.tracking:  # For tracking, we store CTC outputs in subfolders, with 'input_name' as folder name.
                _output_fpath = os.path.join(output_folder, input_name)
            else:  # Otherwise, store each result inside a folder.
                _output_fpath = os.path.join(output_folder, f"{input_name}.tif")

        instance_seg_function(
            predictor=predictor,
            segmenter=segmenter,
            input_path=input_path,
            output_path=_output_fpath,
            embedding_path=_embedding_fpath,
            key=args.key,
            tile_shape=args.tile_shape,
            halo=args.halo,
            annotate=args.annotate,
            verbose=args.verbose,
            batch_size=args.batch_size,
            **generate_kwargs,
        )
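The command line interface accepts extra arguments beyond the ones it declares and routes them either to the segmenter's constructor or to its generate method, as described in the comments of main(). The following minimal sketch (not part of the module; the numeric values are made up for illustration) shows this routing with split_kwargs, using 'points_per_side' (a constructor argument) and 'stability_score_thresh' (a generate argument):

from torch_em.data.datasets.util import split_kwargs
from micro_sam.instance_segmentation import AutomaticMaskGenerator

# Extra CLI arguments collected via parse_known_args, already converted to numbers.
extra_kwargs = {"points_per_side": 16, "stability_score_thresh": 0.9}

# 'split_kwargs' assigns each keyword to the given class if it appears in the
# constructor signature; everything else is left over for the 'generate' call.
amg_kwargs, generate_kwargs = split_kwargs(AutomaticMaskGenerator, **extra_kwargs)
print(amg_kwargs)       # expected: {'points_per_side': 16}
print(generate_kwargs)  # expected: {'stability_score_thresh': 0.9}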
def get_predictor_and_segmenter(
    model_type: str,
    checkpoint: Union[os.PathLike, str, NoneType] = None,
    device: str = None,
    amg: Optional[bool] = None,
    is_tiled: bool = False,
    **kwargs
) -> Tuple[segment_anything.predictor.SamPredictor, Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder]]:
Get the Segment Anything model and class for automatic instance segmentation.
Arguments:
- model_type: The Segment Anything model choice.
- checkpoint: The filepath to the stored model checkpoint.
- device: The torch device. By default, automatically chooses the best available device.
- amg: Whether to perform automatic segmentation in AMG mode. Otherwise AIS will be used, which requires a special segmentation decoder. If not specified, AIS will be used if it is available and otherwise AMG will be used.
- is_tiled: Whether to return a segmenter for tiled segmentation. By default, set to 'False'.
- kwargs: Keyword arguments for the automatic mask generation class.
Returns:
The Segment Anything model. The automatic instance segmentation class.
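A minimal usage sketch, assuming the finetuned 'vit_b_lm' model, which ships a segmentation decoder, so AIS is selected automatically; passing amg=True instead forces AMG mode:

from micro_sam.automatic_segmentation import get_predictor_and_segmenter

# Loads the model and picks AIS, because 'vit_b_lm' contains a decoder state
# ('amg=None' means: use AIS if it is available, otherwise AMG).
predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm")

# Force AMG instead, e.g. for models without a segmentation decoder.
predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b", amg=True)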
def automatic_tracking(
    predictor: segment_anything.predictor.SamPredictor,
    segmenter: Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder],
    input_path: Union[os.PathLike, str, numpy.ndarray],
    output_path: Union[os.PathLike, str, NoneType] = None,
    embedding_path: Union[os.PathLike, str, NoneType] = None,
    key: Optional[str] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    annotate: bool = False,
    batch_size: int = 1,
    **generate_kwargs
) -> Tuple[numpy.ndarray, List[Dict]]:
Run automatic tracking for the input timeseries.
Arguments:
- predictor: The Segment Anything model.
- segmenter: The automatic instance segmentation class.
- input_path: The input image file(s). Can either be a single image file (e.g. tif or png), or a container file (e.g. hdf5 or zarr).
- output_path: The folder where the tracking outputs will be saved in CTC format.
- embedding_path: The path where the embeddings are cached already / will be saved.
- key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr) or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
- tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
- halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
- verbose: Verbosity flag. By default, set to 'True'.
- return_embeddings: Whether to return the precomputed image embeddings. By default, does not return the embeddings.
- annotate: Whether to activate the annotator to continue the annotation process. By default, does not activate the annotator.
- batch_size: The batch size to compute image embeddings over tiles / z-planes. By default, does it sequentially, i.e. one after the other.
- generate_kwargs: Optional keyword arguments for the generate function of the AMG or AIS class.
Returns:
The tracking result as a timeseries, where each object is labeled by its track id. The lineages representing cell divisions, stored as a dictionary.
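A hedged example, assuming a timeseries stored at a hypothetical path 'timeseries.tif'; the tracking outputs are written to the given folder in CTC format:

from micro_sam.automatic_segmentation import get_predictor_and_segmenter, automatic_tracking

predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm")
segmentation, lineages = automatic_tracking(
    predictor=predictor,
    segmenter=segmenter,
    input_path="timeseries.tif",     # hypothetical input file
    output_path="tracking_results",  # tracking outputs are stored here in CTC format
)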
def automatic_instance_segmentation(
    predictor: segment_anything.predictor.SamPredictor,
    segmenter: Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder],
    input_path: Union[os.PathLike, str, numpy.ndarray],
    output_path: Union[os.PathLike, str, NoneType] = None,
    embedding_path: Union[os.PathLike, str, NoneType] = None,
    key: Optional[str] = None,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    annotate: bool = False,
    batch_size: int = 1,
    **generate_kwargs
) -> numpy.ndarray:
Run automatic segmentation for the input image.
Arguments:
- predictor: The Segment Anything model.
- segmenter: The automatic instance segmentation class.
- input_path: The input image file(s). Can either be a single image file (e.g. tif or png), or a container file (e.g. hdf5 or zarr).
- output_path: The output path where the instance segmentations will be saved.
- embedding_path: The path where the embeddings are cached already / will be saved.
- key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr) or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
- ndim: The dimensionality of the data. By default the dimensionality of the data will be used. If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
- tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
- halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
- verbose: Verbosity flag. By default, set to 'True'.
- return_embeddings: Whether to return the precomputed image embeddings. By default, does not return the embeddings.
- annotate: Whether to activate the annotator to continue the annotation process. By default, does not activate the annotator.
- batch_size: The batch size to compute image embeddings over tiles / z-planes. By default, does it sequentially, i.e. one after the other.
- generate_kwargs: Optional keyword arguments for the generate function of the AMG or AIS class.
Returns:
The segmentation result.
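A hedged example, assuming a 2d RGB image at a hypothetical path 'image.png', with tiled prediction for large images and cached embeddings:

from micro_sam.automatic_segmentation import get_predictor_and_segmenter, automatic_instance_segmentation

predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm", is_tiled=True)
segmentation = automatic_instance_segmentation(
    predictor=predictor,
    segmenter=segmenter,
    input_path="image.png",            # hypothetical input file
    output_path="segmentation.tif",    # the result is written as a tif file
    embedding_path="embeddings.zarr",  # cache the image embeddings here
    ndim=2,                            # required for RGB data, so it is treated as a 2d image
    tile_shape=(512, 512),
    halo=(64, 64),
)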