micro_sam.automatic_segmentation
import os
from pathlib import Path
from typing import Optional, Union, Tuple

import numpy as np
import imageio.v3 as imageio

from torch_em.data.datasets.util import split_kwargs

from . import util
from .instance_segmentation import (
    get_amg, get_decoder, mask_data_to_segmentation, InstanceSegmentationWithDecoder,
    AMGBase, AutomaticMaskGenerator, TiledAutomaticMaskGenerator
)
from .multi_dimensional_segmentation import automatic_3d_segmentation


def get_predictor_and_segmenter(
    model_type: str,
    checkpoint: Optional[Union[os.PathLike, str]] = None,
    device: Optional[str] = None,
    amg: Optional[bool] = None,
    is_tiled: bool = False,
    **kwargs,
) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]:
    """Get the Segment Anything model and class for automatic instance segmentation.

    Args:
        model_type: The Segment Anything model choice.
        checkpoint: The filepath to the stored model checkpoints.
        device: The torch device. If None, the device is selected by `util.get_device`.
        amg: Whether to perform automatic segmentation in AMG mode.
            Otherwise AIS will be used, which requires a special segmentation decoder.
            If not specified AIS will be used if it is available and otherwise AMG will be used.
        is_tiled: Whether to return a segmenter that performs tiled (sliding-window) segmentation.
        kwargs: Keyword arguments for the automatic mask generation class.

    Returns:
        The Segment Anything model.
        The automatic instance segmentation class.

    Raises:
        RuntimeError: If `amg=False` is passed but the model does not contain a segmentation decoder.
    """
    device = util.get_device(device=device)

    # The model state is requested as well, so that we can check whether it contains a segmentation decoder.
    predictor, state = util.get_sam_model(
        model_type=model_type, device=device, checkpoint_path=checkpoint, return_state=True,
    )

    has_decoder = "decoder_state" in state
    if amg is None:
        # Prefer AIS whenever the checkpoint ships a segmentation decoder.
        amg = not has_decoder

    if amg:
        decoder = None
    else:
        if not has_decoder:
            raise RuntimeError("You have passed 'amg=False', but your model does not contain a segmentation decoder.")
        decoder = get_decoder(
            image_encoder=predictor.model.image_encoder, decoder_state=state["decoder_state"], device=device
        )

    segmenter = get_amg(predictor=predictor, is_tiled=is_tiled, decoder=decoder, **kwargs)
    return predictor, segmenter


def automatic_instance_segmentation(
    predictor: util.SamPredictor,
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    input_path: Union[Union[os.PathLike, str], np.ndarray],
    output_path: Optional[Union[os.PathLike, str]] = None,
    embedding_path: Optional[Union[os.PathLike, str]] = None,
    key: Optional[str] = None,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    annotate: bool = False,
    **generate_kwargs
) -> np.ndarray:
    """Run automatic segmentation for the input image.

    Args:
        predictor: The Segment Anything model.
        segmenter: The automatic instance segmentation class.
        input_path: The input image file(s) or an already loaded numpy array.
            A file can either be a single image file (e.g. tif or png), or a container file (e.g. hdf5 or zarr).
        output_path: The output path where the instance segmentation will be saved.
            The file suffix is replaced by '.tif'. If this file already exists,
            the segmentation is not recomputed and None is returned.
        embedding_path: The path where the embeddings are cached already / will be saved.
        key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
            or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
        ndim: The dimensionality of the data. By default the dimensionality of the data will be used.
            If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction.
        verbose: Verbosity flag.
        return_embeddings: Whether to return the precomputed image embeddings.
        annotate: Whether to open the annotator to continue annotating the automatic segmentation.
        generate_kwargs: Optional keyword arguments for the generate function of the AMG or AIS class.

    Returns:
        The segmentation result. If `return_embeddings` is True, a tuple of the segmentation
        and the image embeddings.

    Raises:
        ValueError: If the shape of the input data does not match the requested dimensionality.
    """
    # Avoid overwriting already stored segmentations.
    if output_path is not None:
        output_path = Path(output_path).with_suffix(".tif")
        if os.path.exists(output_path):
            print(f"The segmentation results are already stored at '{os.path.abspath(output_path)}'.")
            return

    # Load the input image file.
    if isinstance(input_path, np.ndarray):
        image_data = input_path
    else:
        image_data = util.load_image_data(input_path, key)

    ndim = image_data.ndim if ndim is None else ndim

    if ndim == 2:
        # Valid 2d inputs are (Y, X) or RGB data of shape (Y, X, 3).
        if not (image_data.ndim == 2 or (image_data.ndim == 3 and image_data.shape[-1] == 3)):
            raise ValueError(f"The inputs does not match the shape expectation of 2d inputs: {image_data.shape}")

        # Precompute the image embeddings.
        image_embeddings = util.precompute_image_embeddings(
            predictor=predictor,
            input_=image_data,
            save_path=embedding_path,
            ndim=ndim,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
        )

        # If we run AIS with tiling then we use the same tile shape for the watershed postprocessing.
        if isinstance(segmenter, InstanceSegmentationWithDecoder) and tile_shape is not None:
            generate_kwargs.update({"tile_shape": tile_shape, "halo": halo})

        segmenter.initialize(image=image_data, image_embeddings=image_embeddings, verbose=verbose)
        masks = segmenter.generate(**generate_kwargs)

        if isinstance(masks, list):
            # The predictions from 'generate' are a list of dicts, which contain additional info
            # required for post-processing, e.g. the area per object.
            if not masks:
                instances = np.zeros(image_data.shape[:2], dtype="uint32")
            else:
                instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0)
        else:
            # If (raw) predictions are returned, store them as they are without further post-processing.
            instances = masks

    else:
        # Valid 3d inputs are (Z, Y, X) or RGB data of shape (Z, Y, X, 3).
        if not (image_data.ndim == 3 or (image_data.ndim == 4 and image_data.shape[-1] == 3)):
            raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}")

        outputs = automatic_3d_segmentation(
            volume=image_data,
            predictor=predictor,
            segmentor=segmenter,
            embedding_path=embedding_path,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
            return_embeddings=return_embeddings,
            **generate_kwargs
        )

        if return_embeddings:
            instances, image_embeddings = outputs
        else:
            instances = outputs

    # Allow opening the automatic segmentation in the annotator for further annotation, if desired.
    if annotate:
        from micro_sam.sam_annotator import annotator_2d, annotator_3d
        annotator_function = annotator_2d if ndim == 2 else annotator_3d

        viewer = annotator_function(
            image=image_data,
            model_type=predictor.model_name,
            embedding_path=embedding_path,
            segmentation_result=instances,  # Initializes the automatic segmentation to the annotator.
            tile_shape=tile_shape,
            halo=halo,
            return_viewer=True,  # Returns the viewer, which allows the user to store the updated segmentations.
        )

        # Start the GUI here
        import napari
        napari.run()

        # We extract the segmentation in "committed_objects" layer, where the user either:
        # a) Performed interactive segmentation / corrections and committed them, OR
        # b) Did not do anything and closed the annotator, i.e. keeps the segmentations as it is.
        instances = viewer.layers["committed_objects"].data

    # Save the instance segmentation, if 'output_path' provided.
    if output_path is not None:
        imageio.imwrite(output_path, instances, compression="zlib")
        print(f"The segmentation results are stored at '{os.path.abspath(output_path)}'.")

    if return_embeddings:
        return instances, image_embeddings
    return instances


def main():
    """@private"""
    import argparse
    import warnings

    available_models = ", ".join(util.get_model_names())

    parser = argparse.ArgumentParser(description="Run automatic segmentation for an image.")
    parser.add_argument(
        "-i", "--input_path", required=True,
        help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "
        "or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter."
    )
    parser.add_argument(
        "-o", "--output_path", required=True,
        help="The filepath to store the instance segmentation. The current support stores segmentation in a 'tif' file."
    )
    parser.add_argument(
        "-e", "--embedding_path", default=None, type=str, help="The path where the embeddings will be saved."
    )
    parser.add_argument(
        "--pattern", help="Pattern / wildcard for selecting files in a folder. To select all files use '*'."
    )
    parser.add_argument(
        "-k", "--key",
        help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, "
        "for an image stack it is a wild-card, e.g. '*.png' and for mrc it is 'data'."
    )
    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL,
        help=f"The segment anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None, help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, help="The halo for using tiled prediction.", default=None
    )
    parser.add_argument(
        "-n", "--ndim", type=int, default=None,
        help="The number of spatial dimensions in the data. Please specify this if your data has a channel dimension."
    )
    parser.add_argument(
        "--mode", type=str, default="auto", choices=["auto", "amg", "ais"],
        help="The choice of automatic segmentation with the Segment Anything models. Either 'auto', 'amg' or 'ais'."
    )
    parser.add_argument(
        "--annotate", action="store_true",
        help="Whether to continue annotation after the automatic segmentation is generated."
    )
    parser.add_argument(
        "-d", "--device", default=None,
        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only MAC). "
        "By default the most performant available device will be selected."
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="Whether to allow verbosity of outputs."
    )

    args, parameter_args = parser.parse_known_args()

    # NOTE: '--pattern' is accepted for CLI compatibility, but nothing in this script consumes it.
    if args.pattern is not None:
        warnings.warn("The '--pattern' argument is currently not used and will be ignored.")

    def _convert_argval(value):
        # Restore the expected python type of an extra argument given on the command line:
        # int, then float, then bool / None literals; anything else stays a string (e.g. a mode name).
        for convert in (int, float):
            try:
                return convert(value)
            except ValueError:
                pass
        lowered = value.lower()
        if lowered in ("true", "false"):
            return lowered == "true"
        if lowered == "none":
            return None
        return value

    def _parse_extra_kwargs(extra_args):
        # Parse extra arguments given as '--name value' or '--name=value' into a dict.
        kwargs = {}
        idx = 0
        while idx < len(extra_args):
            token = extra_args[idx]
            if not token.startswith("-"):
                parser.error(f"Unexpected argument '{token}'. Extra arguments must be passed as '--name value'.")
            name, sep, value = token.lstrip("-").partition("=")
            if not sep:
                if idx + 1 >= len(extra_args):
                    parser.error(f"Missing value for the extra argument '{token}'.")
                idx += 1
                value = extra_args[idx]
            kwargs[name] = _convert_argval(value)
            idx += 1
        return kwargs

    # NOTE: this allows catching additional arguments which correspond to
    # the automatic segmentation post-processing parameters (e.g. 'center_distance_threshold' in AIS).
    extra_kwargs = _parse_extra_kwargs(parameter_args)

    # Separate extra arguments as per where they should be passed in the automatic segmentation class.
    # e.g. for AMG, 'points_per_side' is expected by '__init__',
    # and 'stability_score_thresh' is expected in 'generate' method.
    amg_class = AutomaticMaskGenerator if args.tile_shape is None else TiledAutomaticMaskGenerator
    amg_kwargs, generate_kwargs = split_kwargs(amg_class, **extra_kwargs)

    # 'auto' (-> None) searches for the decoder state to prioritize AIS for finetuned models;
    # 'amg' / 'ais' force the respective mode. Invalid choices are rejected by argparse.
    amg = None if args.mode == "auto" else (args.mode == "amg")

    predictor, segmenter = get_predictor_and_segmenter(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        device=args.device,
        amg=amg,
        is_tiled=args.tile_shape is not None,
        **amg_kwargs,
    )

    # We perform additional post-processing for AMG-only.
    # Otherwise, we ignore additional post-processing for AIS.
    if isinstance(segmenter, InstanceSegmentationWithDecoder):
        generate_kwargs["output_mode"] = None

    automatic_instance_segmentation(
        predictor=predictor,
        segmenter=segmenter,
        input_path=args.input_path,
        output_path=args.output_path,
        embedding_path=args.embedding_path,
        key=args.key,
        ndim=args.ndim,
        tile_shape=args.tile_shape,
        halo=args.halo,
        annotate=args.annotate,
        verbose=args.verbose,
        **generate_kwargs,
    )


if __name__ == "__main__":
    main()
def get_predictor_and_segmenter(
    model_type: str,
    checkpoint: Optional[Union[os.PathLike, str]] = None,
    device: Optional[str] = None,
    amg: Optional[bool] = None,
    is_tiled: bool = False,
    **kwargs,
) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]:
    """Get the Segment Anything model and class for automatic instance segmentation.

    Args:
        model_type: The Segment Anything model choice.
        checkpoint: The filepath to the stored model checkpoints.
        device: The torch device. If None, the device is selected by `util.get_device`.
        amg: Whether to perform automatic segmentation in AMG mode.
            Otherwise AIS will be used, which requires a special segmentation decoder.
            If not specified AIS will be used if it is available and otherwise AMG will be used.
        is_tiled: Whether to return a segmenter that performs tiled (sliding-window) segmentation.
        kwargs: Keyword arguments for the automatic mask generation class.

    Returns:
        The Segment Anything model.
        The automatic instance segmentation class.

    Raises:
        RuntimeError: If `amg=False` is passed but the model does not contain a segmentation decoder.
    """
    device = util.get_device(device=device)

    # The model state is requested as well, so that we can check whether it contains a segmentation decoder.
    predictor, state = util.get_sam_model(
        model_type=model_type, device=device, checkpoint_path=checkpoint, return_state=True,
    )

    has_decoder = "decoder_state" in state
    if amg is None:
        # Prefer AIS whenever the checkpoint ships a segmentation decoder.
        amg = not has_decoder

    if amg:
        decoder = None
    else:
        if not has_decoder:
            raise RuntimeError("You have passed 'amg=False', but your model does not contain a segmentation decoder.")
        decoder = get_decoder(
            image_encoder=predictor.model.image_encoder, decoder_state=state["decoder_state"], device=device
        )

    segmenter = get_amg(predictor=predictor, is_tiled=is_tiled, decoder=decoder, **kwargs)
    return predictor, segmenter
Get the Segment Anything model and class for automatic instance segmentation.
Arguments:
- model_type: The Segment Anything model choice.
- checkpoint: The filepath to the stored model checkpoints.
- device: The torch device.
- amg: Whether to perform automatic segmentation in AMG mode. Otherwise AIS will be used, which requires a special segmentation decoder. If not specified AIS will be used if it is available and otherwise AMG will be used.
- is_tiled: Whether to return a segmenter that performs tiled (sliding-window) segmentation.
- kwargs: Keyword arguments for the automatic mask generation class.
Returns:
The Segment Anything model. The automatic instance segmentation class.
def automatic_instance_segmentation(
    predictor: util.SamPredictor,
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    input_path: Union[Union[os.PathLike, str], np.ndarray],
    output_path: Optional[Union[os.PathLike, str]] = None,
    embedding_path: Optional[Union[os.PathLike, str]] = None,
    key: Optional[str] = None,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    annotate: bool = False,
    **generate_kwargs
) -> np.ndarray:
    """Run automatic segmentation for the input image.

    Args:
        predictor: The Segment Anything model.
        segmenter: The automatic instance segmentation class.
        input_path: The input image file(s) or an already loaded numpy array.
            A file can either be a single image file (e.g. tif or png), or a container file (e.g. hdf5 or zarr).
        output_path: The output path where the instance segmentation will be saved.
            The file suffix is replaced by '.tif'. If this file already exists,
            the segmentation is not recomputed and None is returned.
        embedding_path: The path where the embeddings are cached already / will be saved.
        key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
            or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
        ndim: The dimensionality of the data. By default the dimensionality of the data will be used.
            If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction.
        verbose: Verbosity flag.
        return_embeddings: Whether to return the precomputed image embeddings.
        annotate: Whether to open the annotator to continue annotating the automatic segmentation.
        generate_kwargs: Optional keyword arguments for the generate function of the AMG or AIS class.

    Returns:
        The segmentation result. If `return_embeddings` is True, a tuple of the segmentation
        and the image embeddings.

    Raises:
        ValueError: If the shape of the input data does not match the requested dimensionality.
    """
    # Avoid overwriting already stored segmentations.
    if output_path is not None:
        output_path = Path(output_path).with_suffix(".tif")
        if os.path.exists(output_path):
            print(f"The segmentation results are already stored at '{os.path.abspath(output_path)}'.")
            return

    # Load the input image file.
    if isinstance(input_path, np.ndarray):
        image_data = input_path
    else:
        image_data = util.load_image_data(input_path, key)

    ndim = image_data.ndim if ndim is None else ndim

    if ndim == 2:
        # Valid 2d inputs are (Y, X) or RGB data of shape (Y, X, 3).
        if not (image_data.ndim == 2 or (image_data.ndim == 3 and image_data.shape[-1] == 3)):
            raise ValueError(f"The inputs does not match the shape expectation of 2d inputs: {image_data.shape}")

        # Precompute the image embeddings.
        image_embeddings = util.precompute_image_embeddings(
            predictor=predictor,
            input_=image_data,
            save_path=embedding_path,
            ndim=ndim,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
        )

        # If we run AIS with tiling then we use the same tile shape for the watershed postprocessing.
        if isinstance(segmenter, InstanceSegmentationWithDecoder) and tile_shape is not None:
            generate_kwargs.update({"tile_shape": tile_shape, "halo": halo})

        segmenter.initialize(image=image_data, image_embeddings=image_embeddings, verbose=verbose)
        masks = segmenter.generate(**generate_kwargs)

        if isinstance(masks, list):
            # The predictions from 'generate' are a list of dicts, which contain additional info
            # required for post-processing, e.g. the area per object.
            if not masks:
                instances = np.zeros(image_data.shape[:2], dtype="uint32")
            else:
                instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0)
        else:
            # If (raw) predictions are returned, store them as they are without further post-processing.
            instances = masks

    else:
        # Valid 3d inputs are (Z, Y, X) or RGB data of shape (Z, Y, X, 3).
        if not (image_data.ndim == 3 or (image_data.ndim == 4 and image_data.shape[-1] == 3)):
            raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}")

        outputs = automatic_3d_segmentation(
            volume=image_data,
            predictor=predictor,
            segmentor=segmenter,
            embedding_path=embedding_path,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
            return_embeddings=return_embeddings,
            **generate_kwargs
        )

        if return_embeddings:
            instances, image_embeddings = outputs
        else:
            instances = outputs

    # Allow opening the automatic segmentation in the annotator for further annotation, if desired.
    if annotate:
        from micro_sam.sam_annotator import annotator_2d, annotator_3d
        annotator_function = annotator_2d if ndim == 2 else annotator_3d

        viewer = annotator_function(
            image=image_data,
            model_type=predictor.model_name,
            embedding_path=embedding_path,
            segmentation_result=instances,  # Initializes the automatic segmentation to the annotator.
            tile_shape=tile_shape,
            halo=halo,
            return_viewer=True,  # Returns the viewer, which allows the user to store the updated segmentations.
        )

        # Start the GUI here
        import napari
        napari.run()

        # We extract the segmentation in "committed_objects" layer, where the user either:
        # a) Performed interactive segmentation / corrections and committed them, OR
        # b) Did not do anything and closed the annotator, i.e. keeps the segmentations as it is.
        instances = viewer.layers["committed_objects"].data

    # Save the instance segmentation, if 'output_path' provided.
    if output_path is not None:
        imageio.imwrite(output_path, instances, compression="zlib")
        print(f"The segmentation results are stored at '{os.path.abspath(output_path)}'.")

    if return_embeddings:
        return instances, image_embeddings
    return instances
Run automatic segmentation for the input image.
Arguments:
- predictor: The Segment Anything model.
- segmenter: The automatic instance segmentation class.
- input_path: The input image file(s) or an already loaded numpy array. A file can either be a single image file (e.g. tif or png), or a container file (e.g. hdf5 or zarr).
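- output_path: The output path where the instance segmentation will be saved. The file suffix is replaced by '.tif'; if this file already exists, the segmentation is not recomputed and None is returned.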
- embedding_path: The path where the embeddings are cached already / will be saved.
- key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr) or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
- ndim: The dimensionality of the data. By default the dimensionality of the data will be used. If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
- tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
- halo: Overlap of the tiles for tiled prediction.
- verbose: Verbosity flag.
- return_embeddings: Whether to return the precomputed image embeddings.
- annotate: Whether to open the annotator to continue annotating the automatic segmentation.
- generate_kwargs: optional keyword arguments for the generate function of the AMG or AIS class.
Returns:
The segmentation result. If return_embeddings is True, a tuple of the segmentation and the image embeddings.