micro_sam.automatic_segmentation

  1import os
  2from pathlib import Path
  3from typing import Optional, Union, Tuple
  4
  5import numpy as np
  6import imageio.v3 as imageio
  7
  8from torch_em.data.datasets.util import split_kwargs
  9
 10from . import util
 11from .instance_segmentation import (
 12    get_amg, get_decoder, mask_data_to_segmentation, InstanceSegmentationWithDecoder,
 13    AMGBase, AutomaticMaskGenerator, TiledAutomaticMaskGenerator
 14)
 15from .multi_dimensional_segmentation import automatic_3d_segmentation
 16
 17
 18def get_predictor_and_segmenter(
 19    model_type: str,
 20    checkpoint: Optional[Union[os.PathLike, str]] = None,
 21    device: str = None,
 22    amg: Optional[bool] = None,
 23    is_tiled: bool = False,
 24    **kwargs,
 25) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]:
 26    """Get the Segment Anything model and class for automatic instance segmentation.
 27
 28    Args:
 29        model_type: The Segment Anything model choice.
 30        checkpoint: The filepath to the stored model checkpoints.
 31        device: The torch device.
 32        amg: Whether to perform automatic segmentation in AMG mode.
 33            Otherwise AIS will be used, which requires a special segmentation decoder.
 34            If not specified AIS will be used if it is available and otherwise AMG will be used.
 35        is_tiled: Whether to return segmenter for performing segmentation in tiling window style.
 36        kwargs: Keyword arguments for the automatic mask generation class.
 37
 38    Returns:
 39        The Segment Anything model.
 40        The automatic instance segmentation class.
 41    """
 42    # Get the device
 43    device = util.get_device(device=device)
 44
 45    # Get the predictor and state for Segment Anything models.
 46    predictor, state = util.get_sam_model(
 47        model_type=model_type, device=device, checkpoint_path=checkpoint, return_state=True,
 48    )
 49
 50    if amg is None:
 51        amg = "decoder_state" not in state
 52
 53    if amg:
 54        decoder = None
 55    else:
 56        if "decoder_state" not in state:
 57            raise RuntimeError("You have passed 'amg=False', but your model does not contain a segmentation decoder.")
 58        decoder_state = state["decoder_state"]
 59        decoder = get_decoder(image_encoder=predictor.model.image_encoder, decoder_state=decoder_state, device=device)
 60
 61    segmenter = get_amg(predictor=predictor, is_tiled=is_tiled, decoder=decoder, **kwargs)
 62
 63    return predictor, segmenter
 64
 65
 66def automatic_instance_segmentation(
 67    predictor: util.SamPredictor,
 68    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
 69    input_path: Union[Union[os.PathLike, str], np.ndarray],
 70    output_path: Optional[Union[os.PathLike, str]] = None,
 71    embedding_path: Optional[Union[os.PathLike, str]] = None,
 72    key: Optional[str] = None,
 73    ndim: Optional[int] = None,
 74    tile_shape: Optional[Tuple[int, int]] = None,
 75    halo: Optional[Tuple[int, int]] = None,
 76    verbose: bool = True,
 77    **generate_kwargs
 78) -> np.ndarray:
 79    """Run automatic segmentation for the input image.
 80
 81    Args:
 82        predictor: The Segment Anything model.
 83        segmenter: The automatic instance segmentation class.
 84        input_path: input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
 85            or a container file (e.g. hdf5 or zarr).
 86        output_path: The output path where the instance segmentations will be saved.
 87        embedding_path: The path where the embeddings are cached already / will be saved.
 88        key: The key to the input file. This is needed for container files (eg. hdf5 or zarr)
 89            or to load several images as 3d volume. Provide a glob patterm, eg. "*.tif", for this case.
 90        ndim: The dimensionality of the data. By default the dimensionality of the data will be used.
 91            If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
 92        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
 93        halo: Overlap of the tiles for tiled prediction.
 94        verbose: Verbosity flag.
 95        generate_kwargs: optional keyword arguments for the generate function of the AMG or AIS class.
 96
 97    Returns:
 98        The segmentation result.
 99    """
100    # Load the input image file.
101    if isinstance(input_path, np.ndarray):
102        image_data = input_path
103    else:
104        image_data = util.load_image_data(input_path, key)
105
106    ndim = image_data.ndim if ndim is None else ndim
107
108    if ndim == 2:
109        if (image_data.ndim != 2) and (image_data.ndim != 3 and image_data.shape[-1] != 3):
110            raise ValueError(f"The inputs does not match the shape expectation of 2d inputs: {image_data.shape}")
111
112        # Precompute the image embeddings.
113        image_embeddings = util.precompute_image_embeddings(
114            predictor=predictor,
115            input_=image_data,
116            save_path=embedding_path,
117            ndim=ndim,
118            tile_shape=tile_shape,
119            halo=halo,
120            verbose=verbose,
121        )
122
123        segmenter.initialize(image=image_data, image_embeddings=image_embeddings)
124        masks = segmenter.generate(**generate_kwargs)
125
126        if len(masks) == 0:  # instance segmentation can have no masks, hence we just save empty labels
127            if isinstance(segmenter, InstanceSegmentationWithDecoder):
128                this_shape = segmenter._foreground.shape
129            elif isinstance(segmenter, AMGBase):
130                this_shape = segmenter._original_size
131            else:
132                this_shape = image_data.shape[-2:]
133
134            instances = np.zeros(this_shape, dtype="uint32")
135        else:
136            instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0)
137
138    else:
139        if (image_data.ndim != 3) and (image_data.ndim != 4 and image_data.shape[-1] != 3):
140            raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}")
141
142        instances = automatic_3d_segmentation(
143            volume=image_data,
144            predictor=predictor,
145            segmentor=segmenter,
146            embedding_path=embedding_path,
147            tile_shape=tile_shape,
148            halo=halo,
149            verbose=verbose,
150            **generate_kwargs
151        )
152
153    if output_path is not None:
154        # Save the instance segmentation
155        output_path = Path(output_path).with_suffix(".tif")
156        imageio.imwrite(output_path, instances, compression="zlib")
157
158    return instances
159
160
def main():
    """@private

    Command line entry point for automatic instance segmentation.

    Known arguments are parsed with argparse. Any additional '--name value' pairs are collected and
    distributed via `split_kwargs` to either the constructor of the automatic mask generator
    (e.g. 'points_per_side') or its 'generate' method (e.g. 'stability_score_thresh').
    """
    import argparse

    available_models = ", ".join(util.get_model_names())

    parser = argparse.ArgumentParser(description="Run automatic segmentation for an image.")
    parser.add_argument(
        "-i", "--input_path", required=True,
        help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "
        "or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter."
    )
    parser.add_argument(
        "-o", "--output_path", required=True,
        help="The filepath to store the instance segmentation. The current support stores segmentation in a 'tif' file."
    )
    parser.add_argument(
        "-e", "--embedding_path", default=None, type=str, help="The path where the embeddings will be saved."
    )
    # NOTE: '--pattern' is accepted for interface compatibility, but it is currently not used below.
    parser.add_argument(
        "--pattern", help="Pattern / wildcard for selecting files in a folder. To select all files use '*'."
    )
    parser.add_argument(
        "-k", "--key",
        help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, "
        "for an image stack it is a wild-card, e.g. '*.png' and for mrc it is 'data'."
    )
    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL,
        help=f"The segment anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None,
        help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, help="The halo for using tiled prediction.", default=None
    )
    parser.add_argument(
        "-n", "--ndim", type=int, default=None,
        help="The number of spatial dimensions in the data. Please specify this if your data has a channel dimension."
    )
    # Restricting the choices lets argparse reject invalid modes with a proper usage error,
    # instead of relying on an assert that is stripped when running python with '-O'.
    parser.add_argument(
        "--mode", type=str, default=None, choices=["amg", "ais"],
        help="The choice of automatic segmentation with the Segment Anything models. Either 'amg' or 'ais'."
    )
    parser.add_argument(
        "-d", "--device", default=None,
        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only MAC). "
        "By default the most performant available device will be selected."
    )

    args, parameter_args = parser.parse_known_args()

    def _convert_argval(value):
        # The values for the parsed arguments need to be in the expected input structure as provided,
        # i.e. integers and floats should be in their original types.
        # Values that are not numeric are forwarded unchanged as strings.
        for cast in (int, float):
            try:
                return cast(value)
            except ValueError:
                continue
        return value

    # NOTE: the script below allows the possibility to catch additional parsed arguments which correspond to
    # the automatic segmentation post-processing parameters (eg. 'center_distance_threshold' in AIS)
    if len(parameter_args) % 2 != 0:
        parser.error(
            "Additional arguments have to be passed as '--name value' pairs, but got: " + " ".join(parameter_args)
        )
    extra_kwargs = {
        name.lstrip("-"): _convert_argval(value)
        for name, value in zip(parameter_args[::2], parameter_args[1::2])
    }

    # Separate extra arguments as per where they should be passed in the automatic segmentation class.
    # This is done to ensure the extra arguments are allocated to the desired location.
    # eg. for AMG, 'points_per_side' is expected by '__init__',
    # and 'stability_score_thresh' is expected in 'generate' method.
    is_tiled = args.tile_shape is not None
    amg_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator
    amg_kwargs, generate_kwargs = split_kwargs(amg_class, **extra_kwargs)

    # By default the mode is 'None', i.e. searches for the decoder state to prioritize AIS for finetuned models.
    # Otherwise the explicitly requested mode is used.
    amg = None if args.mode is None else (args.mode == "amg")

    predictor, segmenter = get_predictor_and_segmenter(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        device=args.device,
        amg=amg,
        is_tiled=is_tiled,
        **amg_kwargs,
    )

    automatic_instance_segmentation(
        predictor=predictor,
        segmenter=segmenter,
        input_path=args.input_path,
        output_path=args.output_path,
        embedding_path=args.embedding_path,
        key=args.key,
        ndim=args.ndim,
        tile_shape=args.tile_shape,
        halo=args.halo,
        **generate_kwargs,
    )


if __name__ == "__main__":
    main()
def get_predictor_and_segmenter( model_type: str, checkpoint: Union[os.PathLike, str, NoneType] = None, device: str = None, amg: Optional[bool] = None, is_tiled: bool = False, **kwargs) -> Tuple[mobile_sam.predictor.SamPredictor, Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder]]:
def get_predictor_and_segmenter(
    model_type: str,
    checkpoint: Optional[Union[os.PathLike, str]] = None,
    device: str = None,
    amg: Optional[bool] = None,
    is_tiled: bool = False,
    **kwargs,
) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]:
    """Build the Segment Anything predictor and the matching automatic segmentation class.

    Args:
        model_type: The Segment Anything model choice.
        checkpoint: The filepath to the stored model checkpoints.
        device: The torch device. It is resolved via `util.get_device`.
        amg: Whether to run automatic segmentation in AMG mode.
            If False, AIS is used, which requires a segmentation decoder in the model state.
            If None, AIS is chosen when the model state contains a decoder and AMG otherwise.
        is_tiled: Whether to return a segmenter for tiled (sliding window) segmentation.
        kwargs: Keyword arguments forwarded to the automatic mask generation class.

    Returns:
        The Segment Anything model.
        The automatic instance segmentation class.

    Raises:
        RuntimeError: If AIS was requested explicitly but the model state has no decoder.
    """
    device = util.get_device(device=device)

    # The state is requested as well, because it tells us whether a decoder is available.
    predictor, state = util.get_sam_model(
        model_type=model_type, device=device, checkpoint_path=checkpoint, return_state=True,
    )
    has_decoder = "decoder_state" in state

    # Without an explicit choice we fall back to AMG only if no decoder is available.
    use_amg = (not has_decoder) if amg is None else amg

    if use_amg:
        decoder = None
    elif not has_decoder:
        raise RuntimeError("You have passed 'amg=False', but your model does not contain a segmentation decoder.")
    else:
        decoder = get_decoder(
            image_encoder=predictor.model.image_encoder, decoder_state=state["decoder_state"], device=device
        )

    return predictor, get_amg(predictor=predictor, is_tiled=is_tiled, decoder=decoder, **kwargs)

Get the Segment Anything model and class for automatic instance segmentation.

Arguments:
  • model_type: The Segment Anything model choice.
  • checkpoint: The filepath to the stored model checkpoints.
  • device: The torch device.
  • amg: Whether to perform automatic segmentation in AMG mode. Otherwise AIS will be used, which requires a special segmentation decoder. If not specified AIS will be used if it is available and otherwise AMG will be used.
  • is_tiled: Whether to return segmenter for performing segmentation in tiling window style.
  • kwargs: Keyword arguments for the automatic mask generation class.
Returns:

The Segment Anything model. The automatic instance segmentation class.

def automatic_instance_segmentation( predictor: mobile_sam.predictor.SamPredictor, segmenter: Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder], input_path: Union[os.PathLike, str, numpy.ndarray], output_path: Union[os.PathLike, str, NoneType] = None, embedding_path: Union[os.PathLike, str, NoneType] = None, key: Optional[str] = None, ndim: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = True, **generate_kwargs) -> numpy.ndarray:
 67def automatic_instance_segmentation(
 68    predictor: util.SamPredictor,
 69    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
 70    input_path: Union[Union[os.PathLike, str], np.ndarray],
 71    output_path: Optional[Union[os.PathLike, str]] = None,
 72    embedding_path: Optional[Union[os.PathLike, str]] = None,
 73    key: Optional[str] = None,
 74    ndim: Optional[int] = None,
 75    tile_shape: Optional[Tuple[int, int]] = None,
 76    halo: Optional[Tuple[int, int]] = None,
 77    verbose: bool = True,
 78    **generate_kwargs
 79) -> np.ndarray:
 80    """Run automatic segmentation for the input image.
 81
 82    Args:
 83        predictor: The Segment Anything model.
 84        segmenter: The automatic instance segmentation class.
 85        input_path: input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
 86            or a container file (e.g. hdf5 or zarr).
 87        output_path: The output path where the instance segmentations will be saved.
 88        embedding_path: The path where the embeddings are cached already / will be saved.
 89        key: The key to the input file. This is needed for container files (eg. hdf5 or zarr)
 90            or to load several images as 3d volume. Provide a glob patterm, eg. "*.tif", for this case.
 91        ndim: The dimensionality of the data. By default the dimensionality of the data will be used.
 92            If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
 93        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
 94        halo: Overlap of the tiles for tiled prediction.
 95        verbose: Verbosity flag.
 96        generate_kwargs: optional keyword arguments for the generate function of the AMG or AIS class.
 97
 98    Returns:
 99        The segmentation result.
100    """
101    # Load the input image file.
102    if isinstance(input_path, np.ndarray):
103        image_data = input_path
104    else:
105        image_data = util.load_image_data(input_path, key)
106
107    ndim = image_data.ndim if ndim is None else ndim
108
109    if ndim == 2:
110        if (image_data.ndim != 2) and (image_data.ndim != 3 and image_data.shape[-1] != 3):
111            raise ValueError(f"The inputs does not match the shape expectation of 2d inputs: {image_data.shape}")
112
113        # Precompute the image embeddings.
114        image_embeddings = util.precompute_image_embeddings(
115            predictor=predictor,
116            input_=image_data,
117            save_path=embedding_path,
118            ndim=ndim,
119            tile_shape=tile_shape,
120            halo=halo,
121            verbose=verbose,
122        )
123
124        segmenter.initialize(image=image_data, image_embeddings=image_embeddings)
125        masks = segmenter.generate(**generate_kwargs)
126
127        if len(masks) == 0:  # instance segmentation can have no masks, hence we just save empty labels
128            if isinstance(segmenter, InstanceSegmentationWithDecoder):
129                this_shape = segmenter._foreground.shape
130            elif isinstance(segmenter, AMGBase):
131                this_shape = segmenter._original_size
132            else:
133                this_shape = image_data.shape[-2:]
134
135            instances = np.zeros(this_shape, dtype="uint32")
136        else:
137            instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0)
138
139    else:
140        if (image_data.ndim != 3) and (image_data.ndim != 4 and image_data.shape[-1] != 3):
141            raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}")
142
143        instances = automatic_3d_segmentation(
144            volume=image_data,
145            predictor=predictor,
146            segmentor=segmenter,
147            embedding_path=embedding_path,
148            tile_shape=tile_shape,
149            halo=halo,
150            verbose=verbose,
151            **generate_kwargs
152        )
153
154    if output_path is not None:
155        # Save the instance segmentation
156        output_path = Path(output_path).with_suffix(".tif")
157        imageio.imwrite(output_path, instances, compression="zlib")
158
159    return instances

Run automatic segmentation for the input image.

Arguments:
  • predictor: The Segment Anything model.
  • segmenter: The automatic instance segmentation class.
  • input_path: The input image file(s). Can either be a single image file (e.g. tif or png), or a container file (e.g. hdf5 or zarr). A numpy array holding the image data is accepted as well.
  • output_path: The output path where the instance segmentations will be saved.
  • embedding_path: The path where the embeddings are cached already / will be saved.
  • key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr) or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
  • ndim: The dimensionality of the data. By default the dimensionality of the data will be used. If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
  • tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
  • halo: Overlap of the tiles for tiled prediction.
  • verbose: Verbosity flag.
  • generate_kwargs: optional keyword arguments for the generate function of the AMG or AIS class.
Returns:

The segmentation result.