micro_sam.automatic_segmentation
import os
from pathlib import Path
from typing import Optional, Union, Tuple, Dict

import numpy as np
import imageio.v3 as imageio

from . import util
from .instance_segmentation import (
    get_amg, get_decoder, mask_data_to_segmentation, InstanceSegmentationWithDecoder, AMGBase
)
from .multi_dimensional_segmentation import automatic_3d_segmentation


def get_predictor_and_segmenter(
    model_type: str,
    checkpoint: Optional[Union[os.PathLike, str]] = None,
    device: Optional[str] = None,
    amg: Optional[bool] = None,
    is_tiled: bool = False,
    **kwargs,
) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]:
    """Get the Segment Anything model and class for automatic instance segmentation.

    Args:
        model_type: The Segment Anything model choice.
        checkpoint: The filepath to the stored model checkpoints.
        device: The torch device.
        amg: Whether to perform automatic segmentation in AMG mode.
            Otherwise AIS will be used, which requires a special segmentation decoder.
            If not specified AIS will be used if it is available and otherwise AMG will be used.
        is_tiled: Whether to return segmenter for performing segmentation in tiling window style.
        kwargs: Keyword arguments for the automatic instance segmentation class.

    Returns:
        The Segment Anything model.
        The automatic instance segmentation class.

    Raises:
        RuntimeError: If `amg=False` is requested but the loaded model state
            has no "decoder_state" entry.
    """
    # Get the device
    device = util.get_device(device=device)

    # Get the predictor and state for Segment Anything models.
    predictor, state = util.get_sam_model(
        model_type=model_type, device=device, checkpoint_path=checkpoint, return_state=True,
    )

    # Without an explicit choice, fall back to AMG only if the state carries no decoder.
    if amg is None:
        amg = "decoder_state" not in state
    if amg:
        decoder = None
    else:
        if "decoder_state" not in state:
            raise RuntimeError("You have passed amg=False, but your model does not contain a segmentation decoder.")
        decoder_state = state["decoder_state"]
        decoder = get_decoder(image_encoder=predictor.model.image_encoder, decoder_state=decoder_state, device=device)

    segmenter = get_amg(
        predictor=predictor,
        is_tiled=is_tiled,
        decoder=decoder,
        **kwargs
    )
    return predictor, segmenter


def automatic_instance_segmentation(
    predictor: util.SamPredictor,
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    input_path: Union[Union[os.PathLike, str], np.ndarray],
    output_path: Optional[Union[os.PathLike, str]] = None,
    embedding_path: Optional[Union[os.PathLike, str]] = None,
    key: Optional[str] = None,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    **generate_kwargs
) -> np.ndarray:
    """Run automatic segmentation for the input image.

    Args:
        predictor: The Segment Anything model.
        segmenter: The automatic instance segmentation class.
        input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
            or a container file (e.g. hdf5 or zarr). An already loaded numpy array is used as is.
        output_path: The output path where the instance segmentations will be saved.
            The suffix is always replaced by '.tif'.
        embedding_path: The path where the embeddings are cached already / will be saved.
        key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
            or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
        ndim: The dimensionality of the data. By default the dimensionality of the data will be used.
            If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction.
        verbose: Verbosity flag.
        generate_kwargs: optional keyword arguments for the generate function of the AMG or AIS class.

    Returns:
        The segmentation result.

    Raises:
        ValueError: If the shape of the input does not match the requested dimensionality.
    """
    # Load the input image file.
    if isinstance(input_path, np.ndarray):
        image_data = input_path
    else:
        image_data = util.load_image_data(input_path, key)

    ndim = image_data.ndim if ndim is None else ndim

    if ndim == 2:
        # Accept a plain 2d image or a 2d image with a trailing 3-channel axis.
        # The previous check combined the conditions with 'and' and therefore never rejected 3d arrays.
        is_2d = image_data.ndim == 2
        is_2d_rgb = image_data.ndim == 3 and image_data.shape[-1] == 3
        if not (is_2d or is_2d_rgb):
            raise ValueError(f"The inputs does not match the shape expectation of 2d inputs: {image_data.shape}")

        # Precompute the image embeddings.
        image_embeddings = util.precompute_image_embeddings(
            predictor=predictor,
            input_=image_data,
            save_path=embedding_path,
            ndim=ndim,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
        )

        segmenter.initialize(image=image_data, image_embeddings=image_embeddings)
        masks = segmenter.generate(**generate_kwargs)

        if len(masks) == 0:  # instance segmentation can have no masks, hence we just save empty labels
            if isinstance(segmenter, InstanceSegmentationWithDecoder):
                this_shape = segmenter._foreground.shape
            elif isinstance(segmenter, AMGBase):
                this_shape = segmenter._original_size
            else:
                this_shape = image_data.shape[-2:]

            instances = np.zeros(this_shape, dtype="uint32")
        else:
            instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0)
    else:
        # Accept a plain 3d volume or a 3d volume with a trailing 3-channel axis.
        is_3d = image_data.ndim == 3
        is_3d_rgb = image_data.ndim == 4 and image_data.shape[-1] == 3
        if not (is_3d or is_3d_rgb):
            raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}")

        instances = automatic_3d_segmentation(
            volume=image_data,
            predictor=predictor,
            segmentor=segmenter,
            embedding_path=embedding_path,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
            **generate_kwargs
        )

    if output_path is not None:
        # Save the instance segmentation
        output_path = Path(output_path).with_suffix(".tif")
        imageio.imwrite(output_path, instances, compression="zlib")

    return instances


def main():
    """@private"""
    import argparse

    available_models = list(util.get_model_names())
    available_models = ", ".join(available_models)

    parser = argparse.ArgumentParser(description="Run automatic segmentation for an image.")
    parser.add_argument(
        "-i", "--input_path", required=True,
        help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "
        "or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter."
    )
    parser.add_argument(
        "-o", "--output_path", required=True,
        help="The filepath to store the instance segmentation. The current support stores segmentation in a 'tif' file."
    )
    parser.add_argument(
        "-e", "--embedding_path", default=None, type=str, help="The path where the embeddings will be saved."
    )
    # NOTE(review): '--pattern' is parsed but not used anywhere below - confirm whether it should be removed
    # or forwarded to the data loading.
    parser.add_argument(
        "--pattern", help="Pattern / wildcard for selecting files in a folder. To select all files use '*'."
    )
    parser.add_argument(
        "-k", "--key",
        help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, "
        "for an image stack it is a wild-card, e.g. '*.png' and for mrc it is 'data'."
    )
    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL,
        help=f"The segment anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None,
        help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, help="The halo for using tiled prediction.", default=None
    )
    parser.add_argument(
        "-n", "--ndim", type=int, default=None,
        help="The number of spatial dimensions in the data. Please specify this if your data has a channel dimension."
    )
    parser.add_argument(
        "--amg", action="store_true", help="Whether to use automatic mask generation with the model."
    )
    parser.add_argument(
        "-d", "--device", default=None,
        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only MAC). "
        "By default the most performant available device will be selected."
    )

    args, parameter_args = parser.parse_known_args()

    def _convert_argval(value):
        # The values for the parsed arguments need to be in the expected input structure as provided.
        # i.e. integers and floats should be in their original types.
        # Values that are not numeric are passed on unchanged as strings.
        for cast in (int, float):
            try:
                return cast(value)
            except ValueError:
                continue
        return value

    # NOTE: the script below allows the possibility to catch additional parsed arguments which correspond to
    # the automatic segmentation post-processing parameters (eg. 'center_distance_threshold' in AIS)
    if len(parameter_args) % 2 != 0:
        parser.error(f"Additional arguments must be given as '--name value' pairs, got: {parameter_args}")
    generate_kwargs = {
        name.lstrip("-"): _convert_argval(value)
        for name, value in zip(parameter_args[::2], parameter_args[1::2])
    }

    predictor, segmenter = get_predictor_and_segmenter(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        device=args.device,
        # Only force AMG when the flag is set; otherwise keep the automatic choice between AIS and AMG.
        amg=True if args.amg else None,
        # Tiled embeddings are computed whenever a tile shape is given, so the segmenter must be tiled too.
        is_tiled=args.tile_shape is not None,
    )

    automatic_instance_segmentation(
        predictor=predictor,
        segmenter=segmenter,
        input_path=args.input_path,
        output_path=args.output_path,
        embedding_path=args.embedding_path,
        key=args.key,
        ndim=args.ndim,
        tile_shape=args.tile_shape,
        halo=args.halo,
        **generate_kwargs,
    )


if __name__ == "__main__":
    main()
def
get_predictor_and_segmenter( model_type: str, checkpoint: Union[os.PathLike, str, NoneType] = None, device: str = None, amg: Optional[bool] = None, is_tiled: bool = False, **kwargs) -> Tuple[mobile_sam.predictor.SamPredictor, Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder]]:
def get_predictor_and_segmenter(
    model_type: str,
    checkpoint: Optional[Union[os.PathLike, str]] = None,
    device: str = None,
    amg: Optional[bool] = None,
    is_tiled: bool = False,
    **kwargs,
) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]:
    """Load a Segment Anything model and build the matching automatic segmentation class.

    Args:
        model_type: The Segment Anything model choice.
        checkpoint: The filepath to the stored model checkpoints.
        device: The torch device.
        amg: Whether to perform automatic segmentation in AMG mode.
            Otherwise AIS will be used, which requires a special segmentation decoder.
            If not specified AIS will be used if it is available and otherwise AMG will be used.
        is_tiled: Whether to return segmenter for performing segmentation in tiling window style.
        kwargs: Keyword arguments for the automatic instance segmentation class.

    Returns:
        The Segment Anything model.
        The automatic instance segmentation class.

    Raises:
        RuntimeError: If AIS is requested explicitly but the model state has no "decoder_state" entry.
    """
    device = util.get_device(device=device)

    # The state is requested together with the predictor because it tells us if a decoder is available.
    predictor, state = util.get_sam_model(
        model_type=model_type, device=device, checkpoint_path=checkpoint, return_state=True,
    )
    has_decoder = "decoder_state" in state

    # No explicit mode given: use AIS when a decoder is available and AMG otherwise.
    if amg is None:
        amg = not has_decoder

    if not amg and not has_decoder:
        raise RuntimeError("You have passed amg=False, but your model does not contain a segmentation decoder.")

    decoder = None if amg else get_decoder(
        image_encoder=predictor.model.image_encoder, decoder_state=state["decoder_state"], device=device
    )

    segmenter = get_amg(predictor=predictor, is_tiled=is_tiled, decoder=decoder, **kwargs)
    return predictor, segmenter
Get the Segment Anything model and class for automatic instance segmentation.
Arguments:
- model_type: The Segment Anything model choice.
- checkpoint: The filepath to the stored model checkpoints.
- device: The torch device.
- amg: Whether to perform automatic segmentation in AMG mode. Otherwise AIS will be used, which requires a special segmentation decoder. If not specified AIS will be used if it is available and otherwise AMG will be used.
- is_tiled: Whether to return segmenter for performing segmentation in tiling window style.
- kwargs: Keyword arguments for the automatic instance segmentation class.
Returns:
The Segment Anything model. The automatic instance segmentation class.
def
automatic_instance_segmentation( predictor: mobile_sam.predictor.SamPredictor, segmenter: Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder], input_path: Union[os.PathLike, str, numpy.ndarray], output_path: Union[os.PathLike, str, NoneType] = None, embedding_path: Union[os.PathLike, str, NoneType] = None, key: Optional[str] = None, ndim: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = True, **generate_kwargs) -> numpy.ndarray:
def automatic_instance_segmentation(
    predictor: util.SamPredictor,
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    input_path: Union[Union[os.PathLike, str], np.ndarray],
    output_path: Optional[Union[os.PathLike, str]] = None,
    embedding_path: Optional[Union[os.PathLike, str]] = None,
    key: Optional[str] = None,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    **generate_kwargs
) -> np.ndarray:
    """Run automatic segmentation for the input image.

    Args:
        predictor: The Segment Anything model.
        segmenter: The automatic instance segmentation class.
        input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
            or a container file (e.g. hdf5 or zarr). An already loaded numpy array is used as is.
        output_path: The output path where the instance segmentations will be saved.
            The suffix is always replaced by '.tif'.
        embedding_path: The path where the embeddings are cached already / will be saved.
        key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
            or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
        ndim: The dimensionality of the data. By default the dimensionality of the data will be used.
            If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction.
        verbose: Verbosity flag.
        generate_kwargs: optional keyword arguments for the generate function of the AMG or AIS class.

    Returns:
        The segmentation result.

    Raises:
        ValueError: If the shape of the input does not match the requested dimensionality.
    """
    # Load the input image file.
    if isinstance(input_path, np.ndarray):
        image_data = input_path
    else:
        image_data = util.load_image_data(input_path, key)

    ndim = image_data.ndim if ndim is None else ndim

    if ndim == 2:
        # Accept a plain 2d image or a 2d image with a trailing 3-channel axis.
        # The previous check combined the conditions with 'and' and therefore never rejected 3d arrays.
        is_2d = image_data.ndim == 2
        is_2d_rgb = image_data.ndim == 3 and image_data.shape[-1] == 3
        if not (is_2d or is_2d_rgb):
            raise ValueError(f"The inputs does not match the shape expectation of 2d inputs: {image_data.shape}")

        # Precompute the image embeddings.
        image_embeddings = util.precompute_image_embeddings(
            predictor=predictor,
            input_=image_data,
            save_path=embedding_path,
            ndim=ndim,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
        )

        segmenter.initialize(image=image_data, image_embeddings=image_embeddings)
        masks = segmenter.generate(**generate_kwargs)

        if len(masks) == 0:  # instance segmentation can have no masks, hence we just save empty labels
            if isinstance(segmenter, InstanceSegmentationWithDecoder):
                this_shape = segmenter._foreground.shape
            elif isinstance(segmenter, AMGBase):
                this_shape = segmenter._original_size
            else:
                this_shape = image_data.shape[-2:]

            instances = np.zeros(this_shape, dtype="uint32")
        else:
            instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0)
    else:
        # Accept a plain 3d volume or a 3d volume with a trailing 3-channel axis.
        is_3d = image_data.ndim == 3
        is_3d_rgb = image_data.ndim == 4 and image_data.shape[-1] == 3
        if not (is_3d or is_3d_rgb):
            raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}")

        instances = automatic_3d_segmentation(
            volume=image_data,
            predictor=predictor,
            segmentor=segmenter,
            embedding_path=embedding_path,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
            **generate_kwargs
        )

    if output_path is not None:
        # Save the instance segmentation
        output_path = Path(output_path).with_suffix(".tif")
        imageio.imwrite(output_path, instances, compression="zlib")

    return instances
Run automatic segmentation for the input image.
Arguments:
- predictor: The Segment Anything model.
- segmenter: The automatic instance segmentation class.
- input_path: The input image file(s). Can either be a single image file (e.g. tif or png), or a container file (e.g. hdf5 or zarr).
- output_path: The output path where the instance segmentations will be saved.
- embedding_path: The path where the embeddings are cached already / will be saved.
- key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr) or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
- ndim: The dimensionality of the data. By default the dimensionality of the data will be used. If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
- tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
- halo: Overlap of the tiles for tiled prediction.
- verbose: Verbosity flag.
- generate_kwargs: optional keyword arguments for the generate function of the AMG or AIS class.
Returns:
The segmentation result.