micro_sam.automatic_segmentation
import os
from pathlib import Path
from typing import Optional, Union, Tuple

import numpy as np
import imageio.v3 as imageio

from torch_em.data.datasets.util import split_kwargs

from . import util
from .instance_segmentation import (
    get_amg, get_decoder, mask_data_to_segmentation, InstanceSegmentationWithDecoder,
    AMGBase, AutomaticMaskGenerator, TiledAutomaticMaskGenerator
)
from .multi_dimensional_segmentation import automatic_3d_segmentation


def get_predictor_and_segmenter(
    model_type: str,
    checkpoint: Optional[Union[os.PathLike, str]] = None,
    device: str = None,
    amg: Optional[bool] = None,
    is_tiled: bool = False,
    **kwargs,
) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]:
    """Get the Segment Anything model and class for automatic instance segmentation.

    Args:
        model_type: The Segment Anything model choice.
        checkpoint: The filepath to the stored model checkpoints.
        device: The torch device. Resolved via `util.get_device` when not given.
        amg: Whether to perform automatic segmentation in AMG mode.
            Otherwise AIS will be used, which requires a special segmentation decoder.
            If not specified AIS will be used if it is available and otherwise AMG will be used.
        is_tiled: Whether to return segmenter for performing segmentation in tiling window style.
        kwargs: Keyword arguments for the automatic mask generation class.

    Returns:
        The Segment Anything model.
        The automatic instance segmentation class.

    Raises:
        RuntimeError: If `amg=False` is passed but the model state contains no segmentation decoder.
    """
    resolved_device = util.get_device(device=device)

    # Load the predictor together with the full model state, which may hold a decoder for AIS.
    predictor, state = util.get_sam_model(
        model_type=model_type, device=resolved_device, checkpoint_path=checkpoint, return_state=True,
    )
    has_decoder = "decoder_state" in state

    # By default prefer AIS whenever the checkpoint ships a segmentation decoder.
    use_amg = (not has_decoder) if amg is None else amg

    decoder = None
    if not use_amg:
        if not has_decoder:
            raise RuntimeError("You have passed 'amg=False', but your model does not contain a segmentation decoder.")
        decoder = get_decoder(
            image_encoder=predictor.model.image_encoder, decoder_state=state["decoder_state"], device=resolved_device,
        )

    segmenter = get_amg(predictor=predictor, is_tiled=is_tiled, decoder=decoder, **kwargs)
    return predictor, segmenter


def automatic_instance_segmentation(
    predictor: util.SamPredictor,
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    input_path: Union[Union[os.PathLike, str], np.ndarray],
    output_path: Optional[Union[os.PathLike, str]] = None,
    embedding_path: Optional[Union[os.PathLike, str]] = None,
    key: Optional[str] = None,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    **generate_kwargs
) -> np.ndarray:
    """Run automatic segmentation for the input image.

    Args:
        predictor: The Segment Anything model.
        segmenter: The automatic instance segmentation class.
        input_path: The input image file(s), or the image data as a numpy array.
            A file can either be a single image file (e.g. tif or png), or a container file (e.g. hdf5 or zarr).
        output_path: The output path where the instance segmentation will be saved.
            The file suffix is always replaced by '.tif'.
        embedding_path: The path where the embeddings are cached already / will be saved.
        key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
            or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
        ndim: The dimensionality of the data. By default the dimensionality of the data will be used.
            If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
            For ndim=2 the data must be 2-dimensional, or 3-dimensional with a trailing axis of size 3.
            Otherwise the data must be 3-dimensional, or 4-dimensional with a trailing axis of size 3.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction.
        verbose: Verbosity flag.
        generate_kwargs: optional keyword arguments for the generate function of the AMG or AIS class.

    Returns:
        The segmentation result.

    Raises:
        ValueError: If the shape of the input data does not match the requested dimensionality.
    """
    # Load the input image file, unless the image data was passed directly.
    if isinstance(input_path, np.ndarray):
        image_data = input_path
    else:
        image_data = util.load_image_data(input_path, key)

    ndim = image_data.ndim if ndim is None else ndim

    if ndim == 2:
        # Accept a 2d image, or a 2d image with a trailing RGB axis of size 3.
        is_rgb = image_data.ndim == 3 and image_data.shape[-1] == 3
        if image_data.ndim != 2 and not is_rgb:
            raise ValueError(f"The inputs does not match the shape expectation of 2d inputs: {image_data.shape}")

        # Precompute the image embeddings.
        image_embeddings = util.precompute_image_embeddings(
            predictor=predictor,
            input_=image_data,
            save_path=embedding_path,
            ndim=ndim,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
        )

        segmenter.initialize(image=image_data, image_embeddings=image_embeddings)
        masks = segmenter.generate(**generate_kwargs)

        if len(masks) == 0:
            # Instance segmentation can produce no masks; in that case return empty labels.
            if isinstance(segmenter, InstanceSegmentationWithDecoder):
                this_shape = segmenter._foreground.shape
            elif isinstance(segmenter, AMGBase):
                this_shape = segmenter._original_size
            else:
                # Use the leading two axes so that a trailing RGB axis is not mistaken for a spatial one.
                this_shape = image_data.shape[:2]

            instances = np.zeros(this_shape, dtype="uint32")
        else:
            instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0)

    else:
        # Accept a 3d volume, or a 3d volume with a trailing RGB axis of size 3.
        is_rgb = image_data.ndim == 4 and image_data.shape[-1] == 3
        if image_data.ndim != 3 and not is_rgb:
            raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}")

        instances = automatic_3d_segmentation(
            volume=image_data,
            predictor=predictor,
            segmentor=segmenter,
            embedding_path=embedding_path,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
            **generate_kwargs
        )

    if output_path is not None:
        # Save the instance segmentation; the output is always written as a tif file.
        output_path = Path(output_path).with_suffix(".tif")
        imageio.imwrite(output_path, instances, compression="zlib")

    return instances


def main():
    """@private"""
    import argparse

    available_models = ", ".join(util.get_model_names())

    parser = argparse.ArgumentParser(description="Run automatic segmentation for an image.")
    parser.add_argument(
        "-i", "--input_path", required=True,
        help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "
        "or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter."
    )
    parser.add_argument(
        "-o", "--output_path", required=True,
        help="The filepath to store the instance segmentation. The current support stores segmentation in a 'tif' file."
    )
    parser.add_argument(
        "-e", "--embedding_path", default=None, type=str, help="The path where the embeddings will be saved."
    )
    # NOTE(review): '--pattern' is parsed but never used below.
    parser.add_argument(
        "--pattern", help="Pattern / wildcard for selecting files in a folder. To select all files use '*'."
    )
    parser.add_argument(
        "-k", "--key",
        help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, "
        "for an image stack it is a wild-card, e.g. '*.png' and for mrc it is 'data'."
    )
    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL,
        help=f"The segment anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None,
        help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, help="The halo for using tiled prediction.", default=None
    )
    parser.add_argument(
        "-n", "--ndim", type=int, default=None,
        help="The number of spatial dimensions in the data. Please specify this if your data has a channel dimension."
    )
    parser.add_argument(
        "--mode", type=str, default=None,
        help="The choice of automatic segmentation with the Segment Anything models. Either 'amg' or 'ais'."
    )
    parser.add_argument(
        "-d", "--device", default=None,
        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only MAC)."
        "By default the most performant available device will be selected."
    )

    args, parameter_args = parser.parse_known_args()

    def _convert_argval(value):
        # The extra arguments arrive as strings: restore numbers to their original types (int before float)
        # and pass any other value through unchanged instead of crashing on it.
        for cast in (int, float):
            try:
                return cast(value)
            except ValueError:
                pass
        return value

    # NOTE: the script below allows the possibility to catch additional parsed arguments which correspond to
    # the automatic segmentation post-processing parameters (eg. 'center_distance_threshold' in AIS).
    # They must be given as '--name value' pairs.
    if len(parameter_args) % 2 != 0:
        parser.error(f"Additional arguments must be passed as '--name value' pairs, got: {parameter_args}")
    extra_kwargs = {
        name.lstrip("-"): _convert_argval(value) for name, value in zip(parameter_args[::2], parameter_args[1::2])
    }

    # Separate extra arguments as per where they should be passed in the automatic segmentation class.
    # This is done to ensure the extra arguments are allocated to the desired location.
    # eg. for AMG, 'points_per_side' is expected by '__init__',
    # and 'stability_score_thresh' is expected in 'generate' method.
    # NOTE(review): the split always uses the AMG classes, even when AIS ends up being used; confirm that
    # AIS-specific constructor arguments are not needed here.
    amg_class = AutomaticMaskGenerator if args.tile_shape is None else TiledAutomaticMaskGenerator
    amg_kwargs, generate_kwargs = split_kwargs(amg_class, **extra_kwargs)

    # Validate for the expected automatic segmentation mode.
    # By default, it is set to 'None', i.e. searches for the decoder state to prioritize AIS for finetuned models.
    # Otherwise, runs AMG for all models in any case.
    # An explicit check (instead of 'assert') keeps the validation active under 'python -O'.
    amg = None
    if args.mode is not None:
        if args.mode not in ("ais", "amg"):
            parser.error(
                f"'{args.mode}' is not a valid automatic segmentation mode. Please choose either 'amg' or 'ais'."
            )
        amg = args.mode == "amg"

    predictor, segmenter = get_predictor_and_segmenter(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        device=args.device,
        amg=amg,
        is_tiled=args.tile_shape is not None,
        **amg_kwargs,
    )

    automatic_instance_segmentation(
        predictor=predictor,
        segmenter=segmenter,
        input_path=args.input_path,
        output_path=args.output_path,
        embedding_path=args.embedding_path,
        key=args.key,
        ndim=args.ndim,
        tile_shape=args.tile_shape,
        halo=args.halo,
        **generate_kwargs,
    )


if __name__ == "__main__":
    main()
def
get_predictor_and_segmenter( model_type: str, checkpoint: Union[os.PathLike, str, NoneType] = None, device: str = None, amg: Optional[bool] = None, is_tiled: bool = False, **kwargs) -> Tuple[mobile_sam.predictor.SamPredictor, Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder]]:
def get_predictor_and_segmenter(
    model_type: str,
    checkpoint: Optional[Union[os.PathLike, str]] = None,
    device: str = None,
    amg: Optional[bool] = None,
    is_tiled: bool = False,
    **kwargs,
) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]:
    """Build a Segment Anything predictor together with a matching automatic segmenter.

    Args:
        model_type: Name of the Segment Anything model to load.
        checkpoint: Optional filepath of a stored model checkpoint.
        device: The torch device. Resolved via `util.get_device` when not given.
        amg: If true, use automatic mask generation (AMG). If false, use automatic instance
            segmentation (AIS), which needs a segmentation decoder in the model state.
            If None, AIS is chosen when the loaded state contains a decoder and AMG otherwise.
        is_tiled: Whether the returned segmenter should operate in tiling window style.
        kwargs: Additional keyword arguments forwarded to `get_amg`.

    Returns:
        The Segment Anything model.
        The automatic instance segmentation class.

    Raises:
        RuntimeError: If `amg=False` is passed but the model state contains no segmentation decoder.
    """
    resolved_device = util.get_device(device=device)

    # Load the predictor together with the full model state, which may hold a decoder for AIS.
    predictor, state = util.get_sam_model(
        model_type=model_type, device=resolved_device, checkpoint_path=checkpoint, return_state=True,
    )
    has_decoder = "decoder_state" in state

    # By default prefer AIS whenever the checkpoint ships a segmentation decoder.
    use_amg = (not has_decoder) if amg is None else amg

    decoder = None
    if not use_amg:
        if not has_decoder:
            raise RuntimeError("You have passed 'amg=False', but your model does not contain a segmentation decoder.")
        decoder = get_decoder(
            image_encoder=predictor.model.image_encoder, decoder_state=state["decoder_state"], device=resolved_device,
        )

    segmenter = get_amg(predictor=predictor, is_tiled=is_tiled, decoder=decoder, **kwargs)
    return predictor, segmenter
Get the Segment Anything model and class for automatic instance segmentation.
Arguments:
- model_type: The Segment Anything model choice.
- checkpoint: The filepath to the stored model checkpoints.
- device: The torch device.
- amg: Whether to perform automatic segmentation in AMG mode. Otherwise AIS will be used, which requires a special segmentation decoder. If not specified AIS will be used if it is available and otherwise AMG will be used.
- is_tiled: Whether to return segmenter for performing segmentation in tiling window style.
- kwargs: Keyword arguments for the automatic mask generation class.
Returns:
The Segment Anything model. The automatic instance segmentation class.
def
automatic_instance_segmentation( predictor: mobile_sam.predictor.SamPredictor, segmenter: Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder], input_path: Union[os.PathLike, str, numpy.ndarray], output_path: Union[os.PathLike, str, NoneType] = None, embedding_path: Union[os.PathLike, str, NoneType] = None, key: Optional[str] = None, ndim: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = True, **generate_kwargs) -> numpy.ndarray:
def automatic_instance_segmentation(
    predictor: util.SamPredictor,
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    input_path: Union[Union[os.PathLike, str], np.ndarray],
    output_path: Optional[Union[os.PathLike, str]] = None,
    embedding_path: Optional[Union[os.PathLike, str]] = None,
    key: Optional[str] = None,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    **generate_kwargs
) -> np.ndarray:
    """Run automatic segmentation for the input image.

    Args:
        predictor: The Segment Anything model.
        segmenter: The automatic instance segmentation class.
        input_path: The input image file(s), or the image data as a numpy array.
            A file can either be a single image file (e.g. tif or png), or a container file (e.g. hdf5 or zarr).
        output_path: The output path where the instance segmentation will be saved.
            The file suffix is always replaced by '.tif'.
        embedding_path: The path where the embeddings are cached already / will be saved.
        key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
            or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
        ndim: The dimensionality of the data. By default the dimensionality of the data will be used.
            If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
            For ndim=2 the data must be 2-dimensional, or 3-dimensional with a trailing axis of size 3.
            Otherwise the data must be 3-dimensional, or 4-dimensional with a trailing axis of size 3.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction.
        verbose: Verbosity flag.
        generate_kwargs: optional keyword arguments for the generate function of the AMG or AIS class.

    Returns:
        The segmentation result.

    Raises:
        ValueError: If the shape of the input data does not match the requested dimensionality.
    """
    # Load the input image file, unless the image data was passed directly.
    if isinstance(input_path, np.ndarray):
        image_data = input_path
    else:
        image_data = util.load_image_data(input_path, key)

    ndim = image_data.ndim if ndim is None else ndim

    if ndim == 2:
        # Accept a 2d image, or a 2d image with a trailing RGB axis of size 3.
        is_rgb = image_data.ndim == 3 and image_data.shape[-1] == 3
        if image_data.ndim != 2 and not is_rgb:
            raise ValueError(f"The inputs does not match the shape expectation of 2d inputs: {image_data.shape}")

        # Precompute the image embeddings.
        image_embeddings = util.precompute_image_embeddings(
            predictor=predictor,
            input_=image_data,
            save_path=embedding_path,
            ndim=ndim,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
        )

        segmenter.initialize(image=image_data, image_embeddings=image_embeddings)
        masks = segmenter.generate(**generate_kwargs)

        if len(masks) == 0:
            # Instance segmentation can produce no masks; in that case return empty labels.
            if isinstance(segmenter, InstanceSegmentationWithDecoder):
                this_shape = segmenter._foreground.shape
            elif isinstance(segmenter, AMGBase):
                this_shape = segmenter._original_size
            else:
                # Use the leading two axes so that a trailing RGB axis is not mistaken for a spatial one.
                this_shape = image_data.shape[:2]

            instances = np.zeros(this_shape, dtype="uint32")
        else:
            instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0)

    else:
        # Accept a 3d volume, or a 3d volume with a trailing RGB axis of size 3.
        is_rgb = image_data.ndim == 4 and image_data.shape[-1] == 3
        if image_data.ndim != 3 and not is_rgb:
            raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}")

        instances = automatic_3d_segmentation(
            volume=image_data,
            predictor=predictor,
            segmentor=segmenter,
            embedding_path=embedding_path,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
            **generate_kwargs
        )

    if output_path is not None:
        # Save the instance segmentation; the output is always written as a tif file.
        output_path = Path(output_path).with_suffix(".tif")
        imageio.imwrite(output_path, instances, compression="zlib")

    return instances
Run automatic segmentation for the input image.
Arguments:
- predictor: The Segment Anything model.
- segmenter: The automatic instance segmentation class.
- input_path: The input image file(s), or the image data as a numpy array. A file can either be a single image file (e.g. tif or png), or a container file (e.g. hdf5 or zarr).
- output_path: The output path where the instance segmentations will be saved.
- embedding_path: The path where the embeddings are cached already / will be saved.
- key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr) or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
- ndim: The dimensionality of the data. By default the dimensionality of the data will be used. If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
- tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
- halo: Overlap of the tiles for tiled prediction.
- verbose: Verbosity flag.
- generate_kwargs: optional keyword arguments for the generate function of the AMG or AIS class.
Returns:
The segmentation result.