micro_sam.automatic_segmentation

import os
from pathlib import Path
from typing import Optional, Union, Tuple

import numpy as np
import imageio.v3 as imageio

from torch_em.data.datasets.util import split_kwargs

from . import util
from .instance_segmentation import (
    get_amg, get_decoder, mask_data_to_segmentation, InstanceSegmentationWithDecoder,
    AMGBase, AutomaticMaskGenerator, TiledAutomaticMaskGenerator
)
from .multi_dimensional_segmentation import automatic_3d_segmentation


def get_predictor_and_segmenter(
    model_type: str,
    checkpoint: Optional[Union[os.PathLike, str]] = None,
    device: Optional[str] = None,
    amg: Optional[bool] = None,
    is_tiled: bool = False,
    **kwargs,
) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]:
    """Get the Segment Anything model and class for automatic instance segmentation.

    Args:
        model_type: The Segment Anything model choice.
        checkpoint: The filepath to the stored model checkpoints.
        device: The torch device.
        amg: Whether to perform automatic segmentation in AMG mode.
            Otherwise AIS is used, which requires a special segmentation decoder.
            If not specified, AIS is used if it is available, otherwise AMG is used.
        is_tiled: Whether to return a segmenter that runs segmentation in a tiled fashion.
        kwargs: Keyword arguments for the automatic mask generation class.

    Returns:
        The Segment Anything model.
        The automatic instance segmentation class.
    """
    # Get the device.
    device = util.get_device(device=device)

    # Get the predictor and state for the Segment Anything model.
    predictor, state = util.get_sam_model(
        model_type=model_type, device=device, checkpoint_path=checkpoint, return_state=True,
    )

    if amg is None:
        amg = "decoder_state" not in state

    if amg:
        decoder = None
    else:
        if "decoder_state" not in state:
            raise RuntimeError("You have passed 'amg=False', but your model does not contain a segmentation decoder.")
        decoder_state = state["decoder_state"]
        decoder = get_decoder(image_encoder=predictor.model.image_encoder, decoder_state=decoder_state, device=device)

    segmenter = get_amg(predictor=predictor, is_tiled=is_tiled, decoder=decoder, **kwargs)

    return predictor, segmenter


def automatic_instance_segmentation(
    predictor: util.SamPredictor,
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    input_path: Union[os.PathLike, str, np.ndarray],
    output_path: Optional[Union[os.PathLike, str]] = None,
    embedding_path: Optional[Union[os.PathLike, str]] = None,
    key: Optional[str] = None,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    annotate: bool = False,
    **generate_kwargs
) -> np.ndarray:
    """Run automatic segmentation for the input image.

    Args:
        predictor: The Segment Anything model.
        segmenter: The automatic instance segmentation class.
        input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
            or a container file (e.g. hdf5 or zarr).
        output_path: The output path where the instance segmentations will be saved.
        embedding_path: The path where the embeddings are cached already / will be saved.
        key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
            or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
        ndim: The dimensionality of the data. By default the dimensionality of the data will be used.
            If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction.
        verbose: Verbosity flag.
        return_embeddings: Whether to return the precomputed image embeddings.
        annotate: Whether to activate the annotator to continue the annotation process.
        generate_kwargs: Optional keyword arguments for the generate function of the AMG or AIS class.

    Returns:
        The segmentation result.
    """
    # Avoid overwriting already stored segmentations.
    if output_path is not None:
        output_path = Path(output_path).with_suffix(".tif")
        if os.path.exists(output_path):
            print(f"The segmentation results are already stored at '{os.path.abspath(output_path)}'.")
            return

    # Load the input image file.
    if isinstance(input_path, np.ndarray):
        image_data = input_path
    else:
        image_data = util.load_image_data(input_path, key)

    ndim = image_data.ndim if ndim is None else ndim

    if ndim == 2:
        if image_data.ndim != 2 and not (image_data.ndim == 3 and image_data.shape[-1] == 3):
            raise ValueError(f"The input does not match the shape expectation of 2d inputs: {image_data.shape}")

        # Precompute the image embeddings.
        image_embeddings = util.precompute_image_embeddings(
            predictor=predictor,
            input_=image_data,
            save_path=embedding_path,
            ndim=ndim,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
        )

        # If we run AIS with tiling then we use the same tile shape for the watershed postprocessing.
        if isinstance(segmenter, InstanceSegmentationWithDecoder) and tile_shape is not None:
            generate_kwargs.update({"tile_shape": tile_shape, "halo": halo})

        segmenter.initialize(image=image_data, image_embeddings=image_embeddings, verbose=verbose)
        masks = segmenter.generate(**generate_kwargs)

        if isinstance(masks, list):
            # The predictions from 'generate' can be a list of dicts, which contain
            # additional info required for post-processing, e.g. the area per object.
            if len(masks) == 0:
                instances = np.zeros(image_data.shape[:2], dtype="uint32")
            else:
                instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0)
        else:
            # If (raw) predictions are provided, store them as they are, without further post-processing.
            instances = masks

    else:
        if image_data.ndim != 3 and not (image_data.ndim == 4 and image_data.shape[-1] == 3):
            raise ValueError(f"The input does not match the shape expectation of 3d inputs: {image_data.shape}")

        outputs = automatic_3d_segmentation(
            volume=image_data,
            predictor=predictor,
            segmentor=segmenter,
            embedding_path=embedding_path,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
            return_embeddings=return_embeddings,
            **generate_kwargs
        )

        if return_embeddings:
            instances, image_embeddings = outputs
        else:
            instances = outputs

    # Allow opening the automatic segmentation in the annotator for further annotation, if desired.
    if annotate:
        from micro_sam.sam_annotator import annotator_2d, annotator_3d
        annotator_function = annotator_2d if ndim == 2 else annotator_3d

        viewer = annotator_function(
            image=image_data,
            model_type=predictor.model_name,
            embedding_path=embedding_path,
            segmentation_result=instances,  # Initializes the automatic segmentation for the annotator.
            tile_shape=tile_shape,
            halo=halo,
            return_viewer=True,  # Returns the viewer, which allows the user to store the updated segmentations.
        )

        # Start the GUI here.
        import napari
        napari.run()

        # We extract the segmentation from the "committed_objects" layer, where the user either:
        # a) performed interactive segmentation / corrections and committed them, OR
        # b) did not do anything and closed the annotator, i.e. kept the segmentation as it is.
        instances = viewer.layers["committed_objects"].data

    # Save the instance segmentation, if 'output_path' is provided.
    if output_path is not None:
        imageio.imwrite(output_path, instances, compression="zlib")
        print(f"The segmentation results are stored at '{os.path.abspath(output_path)}'.")

    if return_embeddings:
        return instances, image_embeddings
    else:
        return instances


def main():
    """@private"""
    import argparse

    available_models = list(util.get_model_names())
    available_models = ", ".join(available_models)

    parser = argparse.ArgumentParser(description="Run automatic segmentation for an image.")
    parser.add_argument(
        "-i", "--input_path", required=True,
        help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "
        "or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter."
    )
    parser.add_argument(
        "-o", "--output_path", required=True,
        help="The filepath to store the instance segmentation. The segmentation is currently stored as a 'tif' file."
    )
    parser.add_argument(
        "-e", "--embedding_path", default=None, type=str, help="The path where the embeddings will be saved."
    )
    parser.add_argument(
        "--pattern", help="Pattern / wildcard for selecting files in a folder. To select all files use '*'."
    )
    parser.add_argument(
        "-k", "--key",
        help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, "
        "for an image stack it is a wildcard, e.g. '*.png', and for mrc it is 'data'."
    )
    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL,
        help=f"The Segment Anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None, help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, help="The halo for using tiled prediction.", default=None
    )
    parser.add_argument(
        "-n", "--ndim", type=int, default=None,
        help="The number of spatial dimensions in the data. Please specify this if your data has a channel dimension."
    )
    parser.add_argument(
        "--mode", type=str, default="auto",
        help="The choice of automatic segmentation with the Segment Anything models. Either 'auto', 'amg' or 'ais'."
    )
    parser.add_argument(
        "--annotate", action="store_true",
        help="Whether to continue annotation after the automatic segmentation is generated."
    )
    parser.add_argument(
        "-d", "--device", default=None,
        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only Mac). "
        "By default the most performant available device will be selected."
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="Whether to print verbose outputs."
    )

    args, parameter_args = parser.parse_known_args()

    def _convert_argval(value):
        # The values for the parsed arguments need to be in the expected input structure as provided,
        # i.e. integers and floats should be in their original types.
        try:
            return int(value)
        except ValueError:
            return float(value)

    # NOTE: The code below catches additional parsed arguments which correspond to
    # the automatic segmentation post-processing parameters (e.g. 'center_distance_threshold' in AIS).
    extra_kwargs = {
        parameter_args[i].lstrip("--"): _convert_argval(parameter_args[i + 1]) for i in range(0, len(parameter_args), 2)
    }

    # Separate the extra arguments according to where they should be passed in the automatic segmentation class.
    # This is done to ensure that the extra arguments are allocated to the desired location,
    # e.g. for AMG, 'points_per_side' is expected by '__init__'
    # and 'stability_score_thresh' is expected by the 'generate' method.
    amg_class = AutomaticMaskGenerator if args.tile_shape is None else TiledAutomaticMaskGenerator
    amg_kwargs, generate_kwargs = split_kwargs(amg_class, **extra_kwargs)

    # Validate the expected automatic segmentation mode.
    # By default, it is set to 'auto', i.e. it searches for a decoder state and prioritizes AIS for finetuned models.
    # Otherwise, AMG is run for all models in any case.
    amg = None
    if args.mode != "auto":
        assert args.mode in ["ais", "amg"], \
            f"'{args.mode}' is not a valid automatic segmentation mode. Please choose either 'amg' or 'ais'."
        amg = (args.mode == "amg")

    predictor, segmenter = get_predictor_and_segmenter(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        device=args.device,
        amg=amg,
        is_tiled=args.tile_shape is not None,
        **amg_kwargs,
    )

    # We perform the additional post-processing for AMG only;
    # for AIS the additional post-processing is skipped.
    if isinstance(segmenter, InstanceSegmentationWithDecoder):
        generate_kwargs["output_mode"] = None

    automatic_instance_segmentation(
        predictor=predictor,
        segmenter=segmenter,
        input_path=args.input_path,
        output_path=args.output_path,
        embedding_path=args.embedding_path,
        key=args.key,
        ndim=args.ndim,
        tile_shape=args.tile_shape,
        halo=args.halo,
        annotate=args.annotate,
        verbose=args.verbose,
        **generate_kwargs,
    )


if __name__ == "__main__":
    main()
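
The 'main' function above also serves as the command line entry point; unknown flags (e.g. '--stability_score_thresh 0.9') are collected into 'extra_kwargs' and routed via 'split_kwargs'. A minimal sketch of that routing, with parameter names taken from the comments in 'main' (the exact split depends on the installed micro_sam / torch_em versions):

from torch_em.data.datasets.util import split_kwargs

from micro_sam.instance_segmentation import AutomaticMaskGenerator

# Two hypothetical extra parameters: one for the class constructor, one for 'generate'.
extra_kwargs = {"points_per_side": 16, "stability_score_thresh": 0.9}
amg_kwargs, generate_kwargs = split_kwargs(AutomaticMaskGenerator, **extra_kwargs)

# Expected outcome (see the comment in 'main'): 'points_per_side' is an '__init__'
# argument, so it ends up in amg_kwargs; 'stability_score_thresh' is expected by
# 'generate' and therefore lands in generate_kwargs.
print(amg_kwargs, generate_kwargs)
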
def get_predictor_and_segmenter(model_type: str, checkpoint: Optional[Union[os.PathLike, str]] = None, device: Optional[str] = None, amg: Optional[bool] = None, is_tiled: bool = False, **kwargs) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]:

Get the Segment Anything model and class for automatic instance segmentation.

Arguments:
  • model_type: The Segment Anything model choice.
  • checkpoint: The filepath to the stored model checkpoints.
  • device: The torch device.
  • amg: Whether to perform automatic segmentation in AMG mode. Otherwise AIS is used, which requires a special segmentation decoder. If not specified, AIS is used if it is available, otherwise AMG is used.
  • is_tiled: Whether to return a segmenter that runs segmentation in a tiled fashion.
  • kwargs: Keyword arguments for the automatic mask generation class.
Returns:
  • The Segment Anything model.
  • The automatic instance segmentation class.
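
A minimal usage sketch, assuming the finetuned 'vit_b_lm' model (any of the available model names works; models are downloaded automatically on first use):

from micro_sam.automatic_segmentation import get_predictor_and_segmenter

# 'vit_b_lm' ships with a segmentation decoder, so with the default amg=None
# this returns an AIS segmenter; pass amg=True to force AMG instead.
predictor, segmenter = get_predictor_and_segmenter(
    model_type="vit_b_lm",
    device=None,      # selects the most performant available device
    is_tiled=False,   # set to True for tiled segmentation of large images
)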

def automatic_instance_segmentation(predictor: util.SamPredictor, segmenter: Union[AMGBase, InstanceSegmentationWithDecoder], input_path: Union[os.PathLike, str, np.ndarray], output_path: Optional[Union[os.PathLike, str]] = None, embedding_path: Optional[Union[os.PathLike, str]] = None, key: Optional[str] = None, ndim: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = True, return_embeddings: bool = False, annotate: bool = False, **generate_kwargs) -> np.ndarray:

Run automatic segmentation for the input image.

Arguments:
  • predictor: The Segment Anything model.
  • segmenter: The automatic instance segmentation class.
  • input_path: The input image file(s). Can either be a single image file (e.g. tif or png), or a container file (e.g. hdf5 or zarr).
  • output_path: The output path where the instance segmentations will be saved.
  • embedding_path: The path where the embeddings are cached already / will be saved.
  • key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr) or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
  • ndim: The dimensionality of the data. By default the dimensionality of the data will be used. If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
  • tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
  • halo: Overlap of the tiles for tiled prediction.
  • verbose: Verbosity flag.
  • return_embeddings: Whether to return the precomputed image embeddings.
  • annotate: Whether to activate the annotator to continue the annotation process.
  • generate_kwargs: Optional keyword arguments for the generate function of the AMG or AIS class.
Returns:
  • The segmentation result.
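
An end-to-end sketch under the same assumptions as above ('vit_b_lm' as the model; file paths are hypothetical):

import numpy as np

from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter

predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm")

# The input can be a filepath (e.g. "cells.tif") or a numpy array; we use a random
# stand-in image here, so 'ndim' is passed explicitly.
image = np.random.rand(512, 512).astype("float32")

instances = automatic_instance_segmentation(
    predictor=predictor,
    segmenter=segmenter,
    input_path=image,
    output_path="cells_segmentation.tif",  # hypothetical path; results are stored as tif
    embedding_path="embeddings.zarr",      # optional on-disk cache for the embeddings
    ndim=2,
)
# 'instances' is a label image: 0 is background, each object has a unique id.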