micro_sam.automatic_segmentation
import os
from glob import glob
from tqdm import tqdm
from pathlib import Path
from typing import Optional, Union, Tuple

import numpy as np
import imageio.v3 as imageio

from torch_em.data.datasets.util import split_kwargs

from . import util
from .instance_segmentation import (
    get_amg, get_decoder, mask_data_to_segmentation, InstanceSegmentationWithDecoder,
    AMGBase, AutomaticMaskGenerator, TiledAutomaticMaskGenerator
)
from .multi_dimensional_segmentation import automatic_3d_segmentation


def get_predictor_and_segmenter(
    model_type: str,
    checkpoint: Optional[Union[os.PathLike, str]] = None,
    device: Optional[str] = None,
    amg: Optional[bool] = None,
    is_tiled: bool = False,
    **kwargs,
) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]:
    """Get the Segment Anything model and class for automatic instance segmentation.

    Args:
        model_type: The Segment Anything model choice.
        checkpoint: The filepath to the stored model checkpoints.
        device: The torch device.
        amg: Whether to perform automatic segmentation in AMG mode.
            Otherwise AIS will be used, which requires a special segmentation decoder.
            If not specified, AIS will be used if it is available, and AMG otherwise.
        is_tiled: Whether to return a segmenter that performs the segmentation in a tiled fashion.
        kwargs: Keyword arguments for the automatic mask generation class.

    Returns:
        The Segment Anything model.
        The automatic instance segmentation class.
    """
    # Get the device.
    device = util.get_device(device=device)

    # Get the predictor and state for Segment Anything models.
    predictor, state = util.get_sam_model(
        model_type=model_type, device=device, checkpoint_path=checkpoint, return_state=True,
    )

    if amg is None:
        amg = "decoder_state" not in state

    if amg:
        decoder = None
    else:
        if "decoder_state" not in state:
            raise RuntimeError("You have passed 'amg=False', but your model does not contain a segmentation decoder.")
        decoder_state = state["decoder_state"]
        decoder = get_decoder(image_encoder=predictor.model.image_encoder, decoder_state=decoder_state, device=device)

    segmenter = get_amg(predictor=predictor, is_tiled=is_tiled, decoder=decoder, **kwargs)

    return predictor, segmenter


def _add_suffix_to_output_path(output_path: Union[str, os.PathLike], suffix: str) -> str:
    fpath = Path(output_path).resolve()
    fext = fpath.suffix if fpath.suffix else ".tif"
    return str(fpath.with_name(f"{fpath.stem}{suffix}{fext}"))


def automatic_instance_segmentation(
    predictor: util.SamPredictor,
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    input_path: Union[Union[os.PathLike, str], np.ndarray],
    output_path: Optional[Union[os.PathLike, str]] = None,
    embedding_path: Optional[Union[os.PathLike, str]] = None,
    key: Optional[str] = None,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    annotate: bool = False,
    batch_size: int = 1,
    **generate_kwargs
) -> np.ndarray:
    """Run automatic segmentation for the input image.

    Args:
        predictor: The Segment Anything model.
        segmenter: The automatic instance segmentation class.
        input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
            or a container file (e.g. hdf5 or zarr).
        output_path: The output path where the instance segmentations will be saved.
        embedding_path: The path where the embeddings are cached already / will be saved.
        key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
            or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
        ndim: The dimensionality of the data. By default, it is derived from the input data.
            If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction.
        verbose: Verbosity flag.
        return_embeddings: Whether to return the precomputed image embeddings.
        annotate: Whether to open the annotator to continue the annotation process.
        batch_size: The batch size to compute image embeddings over tiles / z-planes.
            By default, they are computed sequentially, i.e. one after the other.
        generate_kwargs: Optional keyword arguments for the generate function of the AMG or AIS class.

    Returns:
        The segmentation result.
    """
    # Avoid overwriting already stored segmentations.
    if output_path is not None:
        output_path = Path(output_path).with_suffix(".tif")
        if os.path.exists(output_path):
            print(f"The segmentation results are already stored at '{os.path.abspath(output_path)}'.")
            return

    # Load the input image file.
    if isinstance(input_path, np.ndarray):
        image_data = input_path
    else:
        image_data = util.load_image_data(input_path, key)

    ndim = image_data.ndim if ndim is None else ndim

    # The additional post-processing of the output masks is only required for AMG.
    # For AIS we disable it by setting the output mode to None.
    if isinstance(segmenter, InstanceSegmentationWithDecoder):
        generate_kwargs["output_mode"] = None

    if ndim == 2:
        # Valid 2d inputs are single-channel 2d images or 2d RGB images with channels-last.
        if image_data.ndim != 2 and not (image_data.ndim == 3 and image_data.shape[-1] == 3):
            raise ValueError(f"The input does not match the shape expectation of 2d inputs: {image_data.shape}")

        # Precompute the image embeddings.
        image_embeddings = util.precompute_image_embeddings(
            predictor=predictor,
            input_=image_data,
            save_path=embedding_path,
            ndim=ndim,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
            batch_size=batch_size,
        )
        initialize_kwargs = dict(image=image_data, image_embeddings=image_embeddings, verbose=verbose)

        # If we run AIS with tiling then we use the same tile shape for the watershed postprocessing.
        # In this case, we also add the batch size to the initialize kwargs,
        # so that the segmentation decoder can be applied in a batched fashion.
        if isinstance(segmenter, InstanceSegmentationWithDecoder) and tile_shape is not None:
            generate_kwargs.update({"tile_shape": tile_shape, "halo": halo})
            initialize_kwargs["batch_size"] = batch_size

        segmenter.initialize(**initialize_kwargs)
        masks = segmenter.generate(**generate_kwargs)

        if isinstance(masks, list):
            # The predictions from 'generate' are a list of dicts,
            # which contain additional information required for post-processing, e.g. the area per object.
            if len(masks) == 0:
                instances = np.zeros(image_data.shape[:2], dtype="uint32")
            else:
                instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0)
        else:
            # If raw predictions are provided, store them as they are, without further post-processing.
            instances = masks

    else:
        # Valid 3d inputs are single-channel volumes or RGB volumes with channels-last.
        if image_data.ndim != 3 and not (image_data.ndim == 4 and image_data.shape[-1] == 3):
            raise ValueError(f"The input does not match the shape expectation of 3d inputs: {image_data.shape}")

        instances, image_embeddings = automatic_3d_segmentation(
            volume=image_data,
            predictor=predictor,
            segmentor=segmenter,
            embedding_path=embedding_path,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
            return_embeddings=True,
            batch_size=batch_size,
            **generate_kwargs
        )

    # Before starting to annotate, if at all desired, store the automatic segmentations in the first stage.
    if output_path is not None:
        _output_path = _add_suffix_to_output_path(output_path, "_automatic") if annotate else output_path
        imageio.imwrite(_output_path, instances, compression="zlib")
        print(f"The automatic segmentation results are stored at '{os.path.abspath(_output_path)}'.")

    # Allow opening the automatic segmentation in the annotator for further annotation, if desired.
    if annotate:
        from micro_sam.sam_annotator import annotator_2d, annotator_3d
        annotator_function = annotator_2d if ndim == 2 else annotator_3d

        viewer = annotator_function(
            image=image_data,
            model_type=predictor.model_name,
            embedding_path=image_embeddings,  # Providing the precomputed image embeddings.
            segmentation_result=instances,  # Initializes the automatic segmentation to the annotator.
            tile_shape=tile_shape,
            halo=halo,
            return_viewer=True,  # Returns the viewer, which allows the user to store the updated segmentations.
        )

        # Start the GUI here.
        import napari
        napari.run()

        # We extract the segmentation in the "committed_objects" layer, where the user either:
        # a) performed interactive segmentation / corrections and committed them, OR
        # b) did not do anything and closed the annotator, i.e. kept the segmentations as they are.
        instances = viewer.layers["committed_objects"].data

        # Save the instance segmentation, if 'output_path' is provided.
        if output_path is not None:
            imageio.imwrite(output_path, instances, compression="zlib")
            print(f"The final segmentation results are stored at '{os.path.abspath(output_path)}'.")

    if return_embeddings:
        return instances, image_embeddings
    else:
        return instances


def _get_inputs_from_paths(paths, pattern):
    "Get the filepaths to all inputs from file and/or directory paths."

    if isinstance(paths, str):
        paths = [paths]

    fpaths = []
    for path in paths:
        if _has_extension(path):  # It is just one filepath.
            fpaths.append(path)
        else:  # Otherwise, if the path is a directory, fetch all inputs matching the pattern.
            assert pattern is not None, \
                f"You must provide a pattern to search for files in the directory: '{os.path.abspath(path)}'."
            fpaths.extend(glob(os.path.join(path, pattern)))

    return fpaths


def _has_extension(fpath: Union[os.PathLike, str]) -> bool:
    "Returns whether the provided path has a file extension."
    return bool(os.path.splitext(fpath)[1])


def main():
    """@private"""
    import argparse

    available_models = list(util.get_model_names())
    available_models = ", ".join(available_models)

    parser = argparse.ArgumentParser(description="Run automatic segmentation for an image.")
    parser.add_argument(
        "-i", "--input_path", required=True, type=str, nargs="+",
        help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "
        "or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter."
    )
    parser.add_argument(
        "-o", "--output_path", required=True, type=str,
        help="The filepath to store the instance segmentation. Currently the segmentation is stored as a 'tif' file."
    )
    parser.add_argument(
        "-e", "--embedding_path", default=None, type=str, help="The path where the embeddings will be saved."
    )
    parser.add_argument(
        "--pattern", type=str, help="Pattern / wildcard for selecting files in a folder. To select all files use '*'."
    )
    parser.add_argument(
        "-k", "--key", default=None, type=str,
        help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, "
        "for an image stack it is a wild-card, e.g. '*.png' and for mrc it is 'data'."
    )
    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL, type=str,
        help=f"The segment anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None, type=str, help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, help="The halo for using tiled prediction.", default=None
    )
    parser.add_argument(
        "-n", "--ndim", default=None, type=int,
        help="The number of spatial dimensions in the data. Please specify this if your data has a channel dimension."
    )
    parser.add_argument(
        "--mode", default="auto", type=str,
        help="The choice of automatic segmentation with the Segment Anything models. Either 'auto', 'amg' or 'ais'."
    )
    parser.add_argument(
        "--annotate", action="store_true",
        help="Whether to continue annotation after the automatic segmentation is generated."
    )
    parser.add_argument(
        "-d", "--device", default=None, type=str,
        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only MAC). "
        "By default the most performant available device will be selected."
    )
    parser.add_argument(
        "--batch_size", type=int, default=1,
        help="The batch size for computing image embeddings over tiles or z-planes. "
        "By default, computes the image embeddings for one tile / z-plane at a time."
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="Whether to print verbose outputs."
    )

    args, parameter_args = parser.parse_known_args()

    def _convert_argval(value):
        # The values for the parsed arguments need to be converted to their expected types,
        # i.e. integers and floats should be passed in their original types.
        try:
            return int(value)
        except ValueError:
            return float(value)

    # NOTE: The code below catches additional parsed arguments which correspond to
    # the automatic segmentation post-processing parameters (e.g. 'center_distance_threshold' in AIS).
    extra_kwargs = {
        parameter_args[i].lstrip("--"): _convert_argval(parameter_args[i + 1]) for i in range(0, len(parameter_args), 2)
    }

    # Separate the extra arguments according to where they should be passed in the automatic segmentation class.
    # This ensures that the extra arguments are allocated to the desired location,
    # e.g. for AMG, 'points_per_side' is expected by '__init__',
    # and 'stability_score_thresh' is expected by the 'generate' method.
    amg_class = AutomaticMaskGenerator if args.tile_shape is None else TiledAutomaticMaskGenerator
    amg_kwargs, generate_kwargs = split_kwargs(amg_class, **extra_kwargs)

    # Validate the automatic segmentation mode.
    # By default it is set to 'auto', i.e. we search for a decoder state and prioritize AIS for finetuned models.
    # Otherwise, AMG is run for all models.
    amg = None
    if args.mode != "auto":
        assert args.mode in ["ais", "amg"], \
            f"'{args.mode}' is not a valid automatic segmentation mode. Please choose either 'amg' or 'ais'."
        amg = (args.mode == "amg")

    predictor, segmenter = get_predictor_and_segmenter(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        device=args.device,
        amg=amg,
        is_tiled=args.tile_shape is not None,
        **amg_kwargs,
    )

    # Get the filepaths to the input images (and the paths for storing outputs, e.g. segmentations and embeddings).
    # Check whether the inputs are as expected, otherwise sort them accordingly.
    input_paths = _get_inputs_from_paths(args.input_path, args.pattern)
    assert len(input_paths) > 0, "'micro-sam' could not extract any image data internally."

    output_path = args.output_path
    embedding_path = args.embedding_path
    has_one_input = len(input_paths) == 1

    # Run automatic segmentation per image.
    for path in tqdm(input_paths, desc="Run automatic segmentation"):
        if has_one_input:  # If we have one image only.
            _output_fpath = str(Path(output_path).with_suffix(".tif"))
            _embedding_fpath = embedding_path

        else:  # If we have multiple images, we need to make the other target filepaths compatible.
            # Let's check for 'embedding_path'.
            _embedding_fpath = embedding_path
            if embedding_path:
                if _has_extension(embedding_path):  # In this case, use the image filename as additional suffix.
                    _embedding_fpath = str(Path(embedding_path).with_suffix(".zarr"))
                    _embedding_fpath = _embedding_fpath.replace(".zarr", f"_{Path(path).stem}.zarr")
                else:  # Otherwise, for a directory, use the image filename for multiple images.
                    os.makedirs(embedding_path, exist_ok=True)
                    _embedding_fpath = os.path.join(embedding_path, Path(os.path.basename(path)).with_suffix(".zarr"))

            # Next, let's check for the output file to store the segmentation.
            if _has_extension(output_path):  # In this case, use the image filename as additional suffix.
                _output_fpath = str(Path(output_path).with_suffix(".tif"))
                _output_fpath = _output_fpath.replace(".tif", f"_{Path(path).stem}.tif")
            else:  # Otherwise, for a directory, use the image filename for multiple images.
                os.makedirs(output_path, exist_ok=True)
                _output_fpath = os.path.join(output_path, Path(os.path.basename(path)).with_suffix(".tif"))

        automatic_instance_segmentation(
            predictor=predictor,
            segmenter=segmenter,
            input_path=path,
            output_path=_output_fpath,
            embedding_path=_embedding_fpath,
            key=args.key,
            ndim=args.ndim,
            tile_shape=args.tile_shape,
            halo=args.halo,
            annotate=args.annotate,
            verbose=args.verbose,
            batch_size=args.batch_size,
            **generate_kwargs,
        )
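The command line interface wired up by main() above is essentially a loop over get_predictor_and_segmenter and automatic_instance_segmentation. A rough Python equivalent for a folder of images is sketched below; the folder names, the file pattern, and the model name "vit_b_lm" are illustrative assumptions, not requirements:

import os
from glob import glob
from pathlib import Path

from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter

# Hypothetical input folder and pattern, mirroring the '-i' and '--pattern' CLI options.
input_paths = sorted(glob(os.path.join("images", "*.tif")))
os.makedirs("segmentations", exist_ok=True)

# The model is loaded once and reused for all images, as in main().
predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm")

for path in input_paths:
    automatic_instance_segmentation(
        predictor=predictor,
        segmenter=segmenter,
        input_path=path,
        # One tif per input image, named after the image, as the CLI does for output directories.
        output_path=os.path.join("segmentations", f"{Path(path).stem}.tif"),
    )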
def get_predictor_and_segmenter(
    model_type: str,
    checkpoint: Optional[Union[str, os.PathLike]] = None,
    device: Optional[str] = None,
    amg: Optional[bool] = None,
    is_tiled: bool = False,
    **kwargs,
) -> Tuple[mobile_sam.predictor.SamPredictor, Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder]]:
Get the Segment Anything model and class for automatic instance segmentation.
Arguments:
- model_type: The Segment Anything model choice.
- checkpoint: The filepath to the stored model checkpoints.
- device: The torch device.
- amg: Whether to perform automatic segmentation in AMG mode. Otherwise AIS will be used, which requires a special segmentation decoder. If not specified, AIS will be used if it is available, and AMG otherwise.
- is_tiled: Whether to return a segmenter that performs the segmentation in a tiled fashion.
- kwargs: Keyword arguments for the automatic mask generation class.
Returns:
The Segment Anything model. The automatic instance segmentation class.
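For example, a minimal sketch of the different modes; the model name "vit_b_lm" is an assumption for illustration, any name returned by micro_sam.util.get_model_names() works:

from micro_sam.automatic_segmentation import get_predictor_and_segmenter

# 'amg=None' (the default) selects AIS if the model ships a decoder state, and AMG otherwise.
predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm")

# Force AMG mode and request a tiled segmenter for large images.
predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm", amg=True, is_tiled=True)

Note that passing amg=False for a model without a decoder state raises a RuntimeError, as the source above shows.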
def automatic_instance_segmentation(
    predictor: mobile_sam.predictor.SamPredictor,
    segmenter: Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder],
    input_path: Union[os.PathLike, str, numpy.ndarray],
    output_path: Optional[Union[str, os.PathLike]] = None,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    key: Optional[str] = None,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    annotate: bool = False,
    batch_size: int = 1,
    **generate_kwargs,
) -> numpy.ndarray:
Run automatic segmentation for the input image.
Arguments:
- predictor: The Segment Anything model.
- segmenter: The automatic instance segmentation class.
- input_path: The input image file(s). Can either be a single image file (e.g. tif or png), or a container file (e.g. hdf5 or zarr).
- output_path: The output path where the instance segmentations will be saved.
- embedding_path: The path where the embeddings are cached already / will be saved.
- key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr) or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
- ndim: The dimensionality of the data. By default, it is derived from the input data. If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
- tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
- halo: Overlap of the tiles for tiled prediction.
- verbose: Verbosity flag.
- return_embeddings: Whether to return the precomputed image embeddings.
- annotate: Whether to open the annotator to continue the annotation process.
- batch_size: The batch size to compute image embeddings over tiles / z-planes. By default, they are computed sequentially, i.e. one after the other.
- generate_kwargs: Optional keyword arguments for the generate function of the AMG or AIS class.
Returns:
The segmentation result.
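Putting both functions together, a hedged end-to-end sketch for a single image; the file names, the model choice, and the tiling parameters are illustrative assumptions:

from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter

# 'is_tiled=True' must match the tiling parameters passed to automatic_instance_segmentation below.
predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm", is_tiled=True)

instances = automatic_instance_segmentation(
    predictor=predictor,
    segmenter=segmenter,
    input_path="image.tif",            # A single 2d image file (hypothetical path).
    output_path="segmentation.tif",    # The segmentation is always stored as tif.
    embedding_path="embeddings.zarr",  # Embeddings are cached here and reused on a second run.
    ndim=2,                            # Required explicitly for RGB data; harmless otherwise.
    tile_shape=(512, 512),             # Tiled prediction for large images ...
    halo=(64, 64),                     # ... with overlap between adjacent tiles.
)

Setting annotate=True would additionally open the napari-based annotator on the automatic result and store the corrected segmentation under output_path.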