micro_sam.automatic_segmentation

  1import os
  2from pathlib import Path
  3from typing import Optional, Union, Tuple, Dict
  4
  5import numpy as np
  6import imageio.v3 as imageio
  7
  8from . import util
  9from .instance_segmentation import (
 10    get_amg, get_decoder, mask_data_to_segmentation, InstanceSegmentationWithDecoder, AMGBase
 11)
 12from .multi_dimensional_segmentation import automatic_3d_segmentation
 13
 14
 15def get_predictor_and_segmenter(
 16    model_type: str,
 17    checkpoint: Optional[Union[os.PathLike, str]] = None,
 18    device: str = None,
 19    amg: Optional[bool] = None,
 20    is_tiled: bool = False,
 21    **kwargs,
 22) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]:
 23    """Get the Segment Anything model and class for automatic instance segmentation.
 24
 25    Args:
 26        model_type: The Segment Anything model choice.
 27        checkpoint: The filepath to the stored model checkpoints.
 28        device: The torch device.
 29        amg: Whether to perform automatic segmentation in AMG mode.
 30            Otherwise AIS will be used, which requires a special segmentation decoder.
 31            If not specified AIS will be used if it is available and otherwise AMG will be used.
 32        is_tiled: Whether to return segmenter for performing segmentation in tiling window style.
 33        kwargs: Keyword arguments for the automatic instance segmentation class.
 34
 35    Returns:
 36        The Segment Anything model.
 37        The automatic instance segmentation class.
 38    """
 39    # Get the device
 40    device = util.get_device(device=device)
 41
 42    # Get the predictor and state for Segment Anything models.
 43    predictor, state = util.get_sam_model(
 44        model_type=model_type, device=device, checkpoint_path=checkpoint, return_state=True,
 45    )
 46
 47    if amg is None:
 48        amg = "decoder_state" not in state
 49    if amg:
 50        decoder = None
 51    else:
 52        if "decoder_state" not in state:
 53            raise RuntimeError("You have passed amg=False, but your model does not contain a segmentation decoder.")
 54        decoder_state = state["decoder_state"]
 55        decoder = get_decoder(image_encoder=predictor.model.image_encoder, decoder_state=decoder_state, device=device)
 56
 57    segmenter = get_amg(
 58        predictor=predictor,
 59        is_tiled=is_tiled,
 60        decoder=decoder,
 61        **kwargs
 62    )
 63    return predictor, segmenter
 64
 65
 66def automatic_instance_segmentation(
 67    predictor: util.SamPredictor,
 68    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
 69    input_path: Union[Union[os.PathLike, str], np.ndarray],
 70    output_path: Optional[Union[os.PathLike, str]] = None,
 71    embedding_path: Optional[Union[os.PathLike, str]] = None,
 72    key: Optional[str] = None,
 73    ndim: Optional[int] = None,
 74    tile_shape: Optional[Tuple[int, int]] = None,
 75    halo: Optional[Tuple[int, int]] = None,
 76    verbose: bool = True,
 77    **generate_kwargs
 78) -> np.ndarray:
 79    """Run automatic segmentation for the input image.
 80
 81    Args:
 82        predictor: The Segment Anything model.
 83        segmenter: The automatic instance segmentation class.
 84        input_path: input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
 85            or a container file (e.g. hdf5 or zarr).
 86        output_path: The output path where the instance segmentations will be saved.
 87        embedding_path: The path where the embeddings are cached already / will be saved.
 88        key: The key to the input file. This is needed for container files (eg. hdf5 or zarr)
 89            or to load several images as 3d volume. Provide a glob patterm, eg. "*.tif", for this case.
 90        ndim: The dimensionality of the data. By default the dimensionality of the data will be used.
 91            If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
 92        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
 93        halo: Overlap of the tiles for tiled prediction.
 94        verbose: Verbosity flag.
 95        generate_kwargs: optional keyword arguments for the generate function of the AMG or AIS class.
 96
 97    Returns:
 98        The segmentation result.
 99    """
100    # Load the input image file.
101    if isinstance(input_path, np.ndarray):
102        image_data = input_path
103    else:
104        image_data = util.load_image_data(input_path, key)
105
106    ndim = image_data.ndim if ndim is None else ndim
107
108    if ndim == 2:
109        if (image_data.ndim != 2) and (image_data.ndim != 3 and image_data.shape[-1] != 3):
110            raise ValueError(f"The inputs does not match the shape expectation of 2d inputs: {image_data.shape}")
111
112        # Precompute the image embeddings.
113        image_embeddings = util.precompute_image_embeddings(
114            predictor=predictor,
115            input_=image_data,
116            save_path=embedding_path,
117            ndim=ndim,
118            tile_shape=tile_shape,
119            halo=halo,
120            verbose=verbose,
121        )
122
123        segmenter.initialize(image=image_data, image_embeddings=image_embeddings)
124        masks = segmenter.generate(**generate_kwargs)
125
126        if len(masks) == 0:  # instance segmentation can have no masks, hence we just save empty labels
127            if isinstance(segmenter, InstanceSegmentationWithDecoder):
128                this_shape = segmenter._foreground.shape
129            elif isinstance(segmenter, AMGBase):
130                this_shape = segmenter._original_size
131            else:
132                this_shape = image_data.shape[-2:]
133
134            instances = np.zeros(this_shape, dtype="uint32")
135        else:
136            instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0)
137    else:
138        if (image_data.ndim != 3) and (image_data.ndim != 4 and image_data.shape[-1] != 3):
139            raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}")
140
141        instances = automatic_3d_segmentation(
142            volume=image_data,
143            predictor=predictor,
144            segmentor=segmenter,
145            embedding_path=embedding_path,
146            tile_shape=tile_shape,
147            halo=halo,
148            verbose=verbose,
149            **generate_kwargs
150        )
151
152    if output_path is not None:
153        # Save the instance segmentation
154        output_path = Path(output_path).with_suffix(".tif")
155        imageio.imwrite(output_path, instances, compression="zlib")
156
157    return instances
158
159
def main():
    """@private

    Command line entry point: parse the arguments, build the predictor and segmenter,
    and run the automatic segmentation for the given input.

    Arguments that the parser does not know are collected as '--name value' pairs and
    passed on as keyword arguments to the segmentation (e.g. 'center_distance_threshold' in AIS).
    """
    import argparse

    available_models = list(util.get_model_names())
    available_models = ", ".join(available_models)

    parser = argparse.ArgumentParser(description="Run automatic segmentation for an image.")
    parser.add_argument(
        "-i", "--input_path", required=True,
        help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "
        "or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter."
    )
    parser.add_argument(
        "-o", "--output_path", required=True,
        help="The filepath to store the instance segmentation. The current support stores segmentation in a 'tif' file."
    )
    parser.add_argument(
        "-e", "--embedding_path", default=None, type=str, help="The path where the embeddings will be saved."
    )
    # NOTE(review): '--pattern' is parsed but not used anywhere in this function.
    parser.add_argument(
        "--pattern", help="Pattern / wildcard for selecting files in a folder. To select all files use '*'."
    )
    parser.add_argument(
        "-k", "--key",
        help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, "
        "for an image stack it is a wild-card, e.g. '*.png' and for mrc it is 'data'."
    )
    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL,
        help=f"The segment anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None,
        help="Checkpoint from which the SAM model will be loaded loaded."
    )
    parser.add_argument(
        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, help="The halo for using tiled prediction.", default=None
    )
    parser.add_argument(
        "-n", "--ndim", type=int, default=None,
        help="The number of spatial dimensions in the data. Please specify this if your data has a channel dimension."
    )
    parser.add_argument(
        "--amg", action="store_true", help="Whether to use automatic mask generation with the model."
    )
    parser.add_argument(
        "-d", "--device", default=None,
        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only MAC)."
        "By default the most performant available device will be selected."
    )

    args, parameter_args = parser.parse_known_args()

    def _convert_argval(value):
        # The values for the parsed arguments need to be in the expected input structure as provided.
        # i.e. integers and floats should be in their original types.
        # Values that are not numeric are passed on unchanged as strings.
        for cast in (int, float):
            try:
                return cast(value)
            except ValueError:
                continue
        return value

    # The extra arguments must come as '--name value' pairs; report a dangling entry
    # through the parser instead of failing with an IndexError.
    if len(parameter_args) % 2 != 0:
        parser.error(f"Expected additional arguments as '--name value' pairs, got: {' '.join(parameter_args)}")

    # NOTE: the script below allows the possibility to catch additional parsed arguments which correspond to
    # the automatic segmentation post-processing parameters (eg. 'center_distance_threshold' in AIS)
    generate_kwargs = {
        name.lstrip("-"): _convert_argval(value)
        for name, value in zip(parameter_args[::2], parameter_args[1::2])
    }

    predictor, segmenter = get_predictor_and_segmenter(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        device=args.device,
        # '--amg' forces AMG; without the flag the mode is selected from the loaded model.
        amg=True if args.amg else None,
        # Tiled embeddings need a segmenter that works on tiles.
        is_tiled=args.tile_shape is not None,
    )

    automatic_instance_segmentation(
        predictor=predictor,
        segmenter=segmenter,
        input_path=args.input_path,
        output_path=args.output_path,
        embedding_path=args.embedding_path,
        key=args.key,
        ndim=args.ndim,
        tile_shape=args.tile_shape,
        halo=args.halo,
        **generate_kwargs,
    )
248
# Run the command line interface when this module is executed as a script.
if __name__ == "__main__":
    main()
def get_predictor_and_segmenter( model_type: str, checkpoint: Union[os.PathLike, str, NoneType] = None, device: str = None, amg: Optional[bool] = None, is_tiled: bool = False, **kwargs) -> Tuple[mobile_sam.predictor.SamPredictor, Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder]]:
def get_predictor_and_segmenter(
    model_type: str,
    checkpoint: Optional[Union[os.PathLike, str]] = None,
    device: str = None,
    amg: Optional[bool] = None,
    is_tiled: bool = False,
    **kwargs,
) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]:
    """Load a Segment Anything predictor together with a matching automatic segmentation object.

    Args:
        model_type: The name of the Segment Anything model to load.
        checkpoint: Optional filepath to stored model weights.
        device: The torch device. It is resolved through `util.get_device`.
        amg: Selects the segmentation mode. If True, automatic mask generation (AMG) is used.
            If False, automatic instance segmentation (AIS) is used; this requires that the
            loaded model state contains a segmentation decoder.
            If None, AIS is chosen when the model state contains a decoder and AMG otherwise.
        is_tiled: Whether the returned segmenter should work on tiled inputs. Forwarded to `get_amg`.
        kwargs: Additional keyword arguments forwarded to `get_amg`.

    Returns:
        The Segment Anything predictor.
        The automatic instance segmentation object.

    Raises:
        RuntimeError: If AIS is requested but the model state has no "decoder_state" entry.
    """
    resolved_device = util.get_device(device=device)

    # The model state is requested as well because it may carry the weights of the AIS decoder.
    predictor, model_state = util.get_sam_model(
        model_type=model_type, device=resolved_device, checkpoint_path=checkpoint, return_state=True,
    )

    has_decoder = "decoder_state" in model_state
    # Without an explicit choice, prefer AIS whenever a decoder is available.
    use_amg = (not has_decoder) if amg is None else amg

    decoder = None
    if not use_amg:
        if not has_decoder:
            raise RuntimeError("You have passed amg=False, but your model does not contain a segmentation decoder.")
        decoder = get_decoder(
            image_encoder=predictor.model.image_encoder,
            decoder_state=model_state["decoder_state"],
            device=resolved_device,
        )

    return predictor, get_amg(predictor=predictor, is_tiled=is_tiled, decoder=decoder, **kwargs)

Get the Segment Anything model and class for automatic instance segmentation.

Arguments:
  • model_type: The Segment Anything model choice.
  • checkpoint: The filepath to the stored model checkpoints.
  • device: The torch device.
  • amg: Whether to perform automatic segmentation in AMG mode. Otherwise AIS will be used, which requires a special segmentation decoder. If not specified AIS will be used if it is available and otherwise AMG will be used.
  • is_tiled: Whether to return segmenter for performing segmentation in tiling window style.
  • kwargs: Keyword arguments for the automatic instance segmentation class.
Returns:

The Segment Anything model. The automatic instance segmentation class.

def automatic_instance_segmentation( predictor: mobile_sam.predictor.SamPredictor, segmenter: Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder], input_path: Union[os.PathLike, str, numpy.ndarray], output_path: Union[os.PathLike, str, NoneType] = None, embedding_path: Union[os.PathLike, str, NoneType] = None, key: Optional[str] = None, ndim: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = True, **generate_kwargs) -> numpy.ndarray:
 67def automatic_instance_segmentation(
 68    predictor: util.SamPredictor,
 69    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
 70    input_path: Union[Union[os.PathLike, str], np.ndarray],
 71    output_path: Optional[Union[os.PathLike, str]] = None,
 72    embedding_path: Optional[Union[os.PathLike, str]] = None,
 73    key: Optional[str] = None,
 74    ndim: Optional[int] = None,
 75    tile_shape: Optional[Tuple[int, int]] = None,
 76    halo: Optional[Tuple[int, int]] = None,
 77    verbose: bool = True,
 78    **generate_kwargs
 79) -> np.ndarray:
 80    """Run automatic segmentation for the input image.
 81
 82    Args:
 83        predictor: The Segment Anything model.
 84        segmenter: The automatic instance segmentation class.
 85        input_path: input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
 86            or a container file (e.g. hdf5 or zarr).
 87        output_path: The output path where the instance segmentations will be saved.
 88        embedding_path: The path where the embeddings are cached already / will be saved.
 89        key: The key to the input file. This is needed for container files (eg. hdf5 or zarr)
 90            or to load several images as 3d volume. Provide a glob patterm, eg. "*.tif", for this case.
 91        ndim: The dimensionality of the data. By default the dimensionality of the data will be used.
 92            If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
 93        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
 94        halo: Overlap of the tiles for tiled prediction.
 95        verbose: Verbosity flag.
 96        generate_kwargs: optional keyword arguments for the generate function of the AMG or AIS class.
 97
 98    Returns:
 99        The segmentation result.
100    """
101    # Load the input image file.
102    if isinstance(input_path, np.ndarray):
103        image_data = input_path
104    else:
105        image_data = util.load_image_data(input_path, key)
106
107    ndim = image_data.ndim if ndim is None else ndim
108
109    if ndim == 2:
110        if (image_data.ndim != 2) and (image_data.ndim != 3 and image_data.shape[-1] != 3):
111            raise ValueError(f"The inputs does not match the shape expectation of 2d inputs: {image_data.shape}")
112
113        # Precompute the image embeddings.
114        image_embeddings = util.precompute_image_embeddings(
115            predictor=predictor,
116            input_=image_data,
117            save_path=embedding_path,
118            ndim=ndim,
119            tile_shape=tile_shape,
120            halo=halo,
121            verbose=verbose,
122        )
123
124        segmenter.initialize(image=image_data, image_embeddings=image_embeddings)
125        masks = segmenter.generate(**generate_kwargs)
126
127        if len(masks) == 0:  # instance segmentation can have no masks, hence we just save empty labels
128            if isinstance(segmenter, InstanceSegmentationWithDecoder):
129                this_shape = segmenter._foreground.shape
130            elif isinstance(segmenter, AMGBase):
131                this_shape = segmenter._original_size
132            else:
133                this_shape = image_data.shape[-2:]
134
135            instances = np.zeros(this_shape, dtype="uint32")
136        else:
137            instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0)
138    else:
139        if (image_data.ndim != 3) and (image_data.ndim != 4 and image_data.shape[-1] != 3):
140            raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}")
141
142        instances = automatic_3d_segmentation(
143            volume=image_data,
144            predictor=predictor,
145            segmentor=segmenter,
146            embedding_path=embedding_path,
147            tile_shape=tile_shape,
148            halo=halo,
149            verbose=verbose,
150            **generate_kwargs
151        )
152
153    if output_path is not None:
154        # Save the instance segmentation
155        output_path = Path(output_path).with_suffix(".tif")
156        imageio.imwrite(output_path, instances, compression="zlib")
157
158    return instances

Run automatic segmentation for the input image.

Arguments:
  • predictor: The Segment Anything model.
  • segmenter: The automatic instance segmentation class.
  • input_path: The input image file(s). Can either be a single image file (e.g. tif or png), a container file (e.g. hdf5 or zarr), or an already loaded numpy array.
  • output_path: The output path where the instance segmentations will be saved.
  • embedding_path: The path where the embeddings are cached already / will be saved.
  • key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr) or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
  • ndim: The dimensionality of the data. By default the dimensionality of the data will be used. If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
  • tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
  • halo: Overlap of the tiles for tiled prediction.
  • verbose: Verbosity flag.
  • generate_kwargs: optional keyword arguments for the generate function of the AMG or AIS class.
Returns:

The segmentation result.