micro_sam.automatic_segmentation

  1import os
  2from glob import glob
  3from tqdm import tqdm
  4from pathlib import Path
  5from typing import Optional, Union, Tuple
  6
  7import numpy as np
  8import imageio.v3 as imageio
  9
 10from torch_em.data.datasets.util import split_kwargs
 11
 12from . import util
 13from .instance_segmentation import (
 14    get_amg, get_decoder, mask_data_to_segmentation, InstanceSegmentationWithDecoder,
 15    AMGBase, AutomaticMaskGenerator, TiledAutomaticMaskGenerator
 16)
 17from .multi_dimensional_segmentation import automatic_3d_segmentation
 18
 19
 20def get_predictor_and_segmenter(
 21    model_type: str,
 22    checkpoint: Optional[Union[os.PathLike, str]] = None,
 23    device: str = None,
 24    amg: Optional[bool] = None,
 25    is_tiled: bool = False,
 26    **kwargs,
 27) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]:
 28    """Get the Segment Anything model and class for automatic instance segmentation.
 29
 30    Args:
 31        model_type: The Segment Anything model choice.
 32        checkpoint: The filepath to the stored model checkpoints.
 33        device: The torch device.
 34        amg: Whether to perform automatic segmentation in AMG mode.
 35            Otherwise AIS will be used, which requires a special segmentation decoder.
 36            If not specified AIS will be used if it is available and otherwise AMG will be used.
 37        is_tiled: Whether to return segmenter for performing segmentation in tiling window style.
 38        kwargs: Keyword arguments for the automatic mask generation class.
 39
 40    Returns:
 41        The Segment Anything model.
 42        The automatic instance segmentation class.
 43    """
 44    # Get the device
 45    device = util.get_device(device=device)
 46
 47    # Get the predictor and state for Segment Anything models.
 48    predictor, state = util.get_sam_model(
 49        model_type=model_type, device=device, checkpoint_path=checkpoint, return_state=True,
 50    )
 51
 52    if amg is None:
 53        amg = "decoder_state" not in state
 54
 55    if amg:
 56        decoder = None
 57    else:
 58        if "decoder_state" not in state:
 59            raise RuntimeError("You have passed 'amg=False', but your model does not contain a segmentation decoder.")
 60        decoder_state = state["decoder_state"]
 61        decoder = get_decoder(image_encoder=predictor.model.image_encoder, decoder_state=decoder_state, device=device)
 62
 63    segmenter = get_amg(predictor=predictor, is_tiled=is_tiled, decoder=decoder, **kwargs)
 64
 65    return predictor, segmenter
 66
 67
 68def _add_suffix_to_output_path(output_path: Union[str, os.PathLike], suffix: str) -> str:
 69    fpath = Path(output_path).resolve()
 70    fext = fpath.suffix if fpath.suffix else ".tif"
 71    return str(fpath.with_name(f"{fpath.stem}{suffix}{fext}"))
 72
 73
 74def automatic_instance_segmentation(
 75    predictor: util.SamPredictor,
 76    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
 77    input_path: Union[Union[os.PathLike, str], np.ndarray],
 78    output_path: Optional[Union[os.PathLike, str]] = None,
 79    embedding_path: Optional[Union[os.PathLike, str]] = None,
 80    key: Optional[str] = None,
 81    ndim: Optional[int] = None,
 82    tile_shape: Optional[Tuple[int, int]] = None,
 83    halo: Optional[Tuple[int, int]] = None,
 84    verbose: bool = True,
 85    return_embeddings: bool = False,
 86    annotate: bool = False,
 87    batch_size: int = 1,
 88    **generate_kwargs
 89) -> np.ndarray:
 90    """Run automatic segmentation for the input image.
 91
 92    Args:
 93        predictor: The Segment Anything model.
 94        segmenter: The automatic instance segmentation class.
 95        input_path: input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
 96            or a container file (e.g. hdf5 or zarr).
 97        output_path: The output path where the instance segmentations will be saved.
 98        embedding_path: The path where the embeddings are cached already / will be saved.
 99        key: The key to the input file. This is needed for container files (eg. hdf5 or zarr)
100            or to load several images as 3d volume. Provide a glob patterm, eg. "*.tif", for this case.
101        ndim: The dimensionality of the data. By default the dimensionality of the data will be used.
102            If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
103        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
104        halo: Overlap of the tiles for tiled prediction.
105        verbose: Verbosity flag.
106        return_embeddings: Whether to return the precomputed image embeddings.
107        annotate: Whether to activate the annotator for continue annotation process.
108        batch_size: The batch size to compute image embeddings over tiles / z-planes.
109            By default, does it sequentially, i.e. one after the other.
110        generate_kwargs: optional keyword arguments for the generate function of the AMG or AIS class.
111
112    Returns:
113        The segmentation result.
114    """
115    # Avoid overwriting already stored segmentations.
116    if output_path is not None:
117        output_path = Path(output_path).with_suffix(".tif")
118        if os.path.exists(output_path):
119            print(f"The segmentation results are already stored at '{os.path.abspath(output_path)}'.")
120            return
121
122    # Load the input image file.
123    if isinstance(input_path, np.ndarray):
124        image_data = input_path
125    else:
126        image_data = util.load_image_data(input_path, key)
127
128    ndim = image_data.ndim if ndim is None else ndim
129
130    # We perform additional post-processing for AMG-only.
131    # Otherwise, we ignore additional post-processing for AIS.
132    if isinstance(segmenter, InstanceSegmentationWithDecoder):
133        generate_kwargs["output_mode"] = None
134
135    if ndim == 2:
136        if (image_data.ndim != 2) and (image_data.ndim != 3 and image_data.shape[-1] != 3):
137            raise ValueError(f"The inputs does not match the shape expectation of 2d inputs: {image_data.shape}")
138
139        # Precompute the image embeddings.
140        image_embeddings = util.precompute_image_embeddings(
141            predictor=predictor,
142            input_=image_data,
143            save_path=embedding_path,
144            ndim=ndim,
145            tile_shape=tile_shape,
146            halo=halo,
147            verbose=verbose,
148            batch_size=batch_size,
149        )
150        initialize_kwargs = dict(image=image_data, image_embeddings=image_embeddings, verbose=verbose)
151
152        # If we run AIS with tiling then we use the same tile shape for the watershed postprocessing.
153        # In this case, we also add the batch size to the initialize kwargs,
154        # so that the segmentation decoder can be applied in a batched fashion.
155        if isinstance(segmenter, InstanceSegmentationWithDecoder) and tile_shape is not None:
156            generate_kwargs.update({"tile_shape": tile_shape, "halo": halo})
157            initialize_kwargs["batch_size"] = batch_size
158
159        segmenter.initialize(**initialize_kwargs)
160        masks = segmenter.generate(**generate_kwargs)
161
162        if isinstance(masks, list):
163            # whether the predictions from 'generate' are list of dict,
164            # which contains additional info req. for post-processing, eg. area per object.
165            if len(masks) == 0:
166                instances = np.zeros(image_data.shape[:2], dtype="uint32")
167            else:
168                instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0)
169        else:
170            # if (raw) predictions provided, store them as it is w/o further post-processing.
171            instances = masks
172
173    else:
174        if (image_data.ndim != 3) and (image_data.ndim != 4 and image_data.shape[-1] != 3):
175            raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}")
176
177        instances, image_embeddings = automatic_3d_segmentation(
178            volume=image_data,
179            predictor=predictor,
180            segmentor=segmenter,
181            embedding_path=embedding_path,
182            tile_shape=tile_shape,
183            halo=halo,
184            verbose=verbose,
185            return_embeddings=True,
186            batch_size=batch_size,
187            **generate_kwargs
188        )
189
190    # Before starting to annotate, if at all desired, store the automatic segmentations in the first stage.
191    if output_path is not None:
192        _output_path = _add_suffix_to_output_path(output_path, "_automatic") if annotate else output_path
193        imageio.imwrite(_output_path, instances, compression="zlib")
194        print(f"The automatic segmentation results are stored at '{os.path.abspath(_output_path)}'.")
195
196    # Allow opening the automatic segmentation in the annotator for further annotation, if desired.
197    if annotate:
198        from micro_sam.sam_annotator import annotator_2d, annotator_3d
199        annotator_function = annotator_2d if ndim == 2 else annotator_3d
200
201        viewer = annotator_function(
202            image=image_data,
203            model_type=predictor.model_name,
204            embedding_path=image_embeddings,  # Providing the precomputed image embeddings.
205            segmentation_result=instances,  # Initializes the automatic segmentation to the annotator.
206            tile_shape=tile_shape,
207            halo=halo,
208            return_viewer=True,  # Returns the viewer, which allows the user to store the updated segmentations.
209        )
210
211        # Start the GUI here
212        import napari
213        napari.run()
214
215        # We extract the segmentation in "committed_objects" layer, where the user either:
216        # a) Performed interactive segmentation / corrections and committed them, OR
217        # b) Did not do anything and closed the annotator, i.e. keeps the segmentations as it is.
218        instances = viewer.layers["committed_objects"].data
219
220        # Save the instance segmentation, if 'output_path' provided.
221        if output_path is not None:
222            imageio.imwrite(output_path, instances, compression="zlib")
223            print(f"The final segmentation results are stored at '{os.path.abspath(output_path)}'.")
224
225    if return_embeddings:
226        return instances, image_embeddings
227    else:
228        return instances
229
230
def _get_inputs_from_paths(paths, pattern):
    """Collect the input filepaths from the given file and / or directory paths.

    Args:
        paths: A single path or a list of paths. A path with a file extension is taken as a file,
            a path without extension is treated as a directory that is searched with `pattern`.
        pattern: The glob pattern for selecting files in a directory. Required if a directory is given.

    Returns:
        The list of input filepaths. Files found in a directory are returned in sorted order.

    Raises:
        AssertionError: If a directory is given, but no pattern is provided.
    """
    # A single path (string or path-like object) is treated as a list with one entry.
    if isinstance(paths, (str, os.PathLike)):
        paths = [paths]

    fpaths = []
    for path in paths:
        if _has_extension(path):  # It is just one filepath.
            fpaths.append(path)
        else:  # Otherwise, if the path is a directory, fetch all inputs provided with a pattern.
            assert pattern is not None, \
                f"You must provide a pattern to search for files in the directory: '{os.path.abspath(path)}'."
            # The order of glob results depends on the file system, so we sort for reproducibility.
            fpaths.extend(sorted(glob(os.path.join(path, pattern))))

    return fpaths
247
248
def _has_extension(fpath: Union[os.PathLike, str]) -> bool:
    """Report whether the given path ends in a file extension."""
    _, extension = os.path.splitext(fpath)
    return len(extension) > 0
252
253
def main():
    """@private"""
    # Command line entry point: parses the arguments, builds the predictor and segmenter once
    # and then runs the automatic segmentation for each of the selected input images.
    import argparse

    available_models = ", ".join(util.get_model_names())

    parser = argparse.ArgumentParser(description="Run automatic segmentation for an image.")
    parser.add_argument(
        "-i", "--input_path", required=True, type=str, nargs="+",
        help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "
        "or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter."
    )
    parser.add_argument(
        "-o", "--output_path", required=True, type=str,
        help="The filepath to store the instance segmentation. The current support stores segmentation in a 'tif' file."
    )
    parser.add_argument(
        "-e", "--embedding_path", default=None, type=str, help="The path where the embeddings will be saved."
    )
    parser.add_argument(
        "--pattern", type=str, help="Pattern / wildcard for selecting files in a folder. To select all files use '*'."
    )
    parser.add_argument(
        "-k", "--key", default=None, type=str,
        help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, "
        "for an image stack it is a wild-card, e.g. '*.png' and for mrc it is 'data'."
    )
    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL, type=str,
        help=f"The segment anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None, type=str, help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, help="The halo for using tiled prediction.", default=None
    )
    parser.add_argument(
        "-n", "--ndim", default=None, type=int,
        help="The number of spatial dimensions in the data. Please specify this if your data has a channel dimension."
    )
    parser.add_argument(
        "--mode", default="auto", type=str,
        help="The choice of automatic segmentation with the Segment Anything models. Either 'auto', 'amg' or 'ais'."
    )
    parser.add_argument(
        "--annotate", action="store_true",
        help="Whether to continue annotation after the automatic segmentation is generated."
    )
    parser.add_argument(
        "-d", "--device", default=None, type=str,
        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only MAC). "
        "By default the most performant available device will be selected."
    )
    parser.add_argument(
        "--batch_size", type=int, default=1,
        help="The batch size for computing image embeddings over tiles or z-plane. "
        "By default, computes the image embeddings for one tile / z-plane at a time."
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="Whether to allow verbosity of outputs."
    )

    args, parameter_args = parser.parse_known_args()

    def _convert_argval(value):
        # The values for the parsed arguments need to be in the expected input structure as provided,
        # i.e. integers and floats are converted to their original types.
        # Values that are not numeric (e.g. names of modes) are passed on as strings.
        for cast in (int, float):
            try:
                return cast(value)
            except ValueError:
                continue
        return value

    def _add_stem_to_filepath(target_path, extension, stem):
        # Builds '<target>_<stem><extension>'. Only the filename is changed,
        # so directories that contain the extension in their name stay untouched.
        base = Path(target_path).with_suffix(extension)
        return str(base.with_name(f"{base.stem}_{stem}{extension}"))

    # NOTE: the script below allows the possibility to catch additional parsed arguments which correspond to
    # the automatic segmentation post-processing parameters (eg. 'center_distance_threshold' in AIS).
    # They have to be passed as '--name value' pairs.
    if len(parameter_args) % 2 != 0:
        parser.error(
            f"Additional arguments must be passed as '--name value' pairs, but got: {' '.join(parameter_args)}"
        )
    extra_kwargs = {
        name.lstrip("-"): _convert_argval(value) for name, value in zip(parameter_args[::2], parameter_args[1::2])
    }

    # Separate extra arguments as per where they should be passed in the automatic segmentation class.
    # This is done to ensure the extra arguments are allocated to the desired location.
    # eg. for AMG, 'points_per_side' is expected by '__init__',
    # and 'stability_score_thresh' is expected in 'generate' method.
    amg_class = AutomaticMaskGenerator if args.tile_shape is None else TiledAutomaticMaskGenerator
    amg_kwargs, generate_kwargs = split_kwargs(amg_class, **extra_kwargs)

    # Validate for the expected automatic segmentation mode.
    # By default, it is set to 'auto', i.e. searches for the decoder state to prioritize AIS for finetuned models.
    # Otherwise, runs AMG for all models in any case.
    if args.mode not in ("auto", "ais", "amg"):
        parser.error(f"'{args.mode}' is not a valid automatic segmentation mode. Please choose either 'amg' or 'ais'.")
    amg = None if args.mode == "auto" else (args.mode == "amg")

    # Get the filepaths to input images (and other paths to store stuff, eg. segmentations and embeddings).
    # This is done before loading the model, so that invalid inputs are reported right away.
    input_paths = _get_inputs_from_paths(args.input_path, args.pattern)
    if len(input_paths) == 0:
        parser.error("'micro-sam' could not extract any image data internally.")

    predictor, segmenter = get_predictor_and_segmenter(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        device=args.device,
        amg=amg,
        is_tiled=args.tile_shape is not None,
        **amg_kwargs,
    )

    output_path = args.output_path
    embedding_path = args.embedding_path
    has_one_input = len(input_paths) == 1

    # Run automatic segmentation per image.
    for path in tqdm(input_paths, desc="Run automatic segmentation"):
        if has_one_input:  # if we have one image only.
            _output_fpath = str(Path(output_path).with_suffix(".tif"))
            _embedding_fpath = embedding_path

        else:  # if we have multiple images, we need to make the other target filepaths compatible.
            input_name = Path(os.path.basename(path))

            # Let's check for 'embedding_path'.
            _embedding_fpath = embedding_path
            if embedding_path:
                if _has_extension(embedding_path):  # in this case, use filename as addl. suffix to provided path.
                    _embedding_fpath = _add_stem_to_filepath(embedding_path, ".zarr", Path(path).stem)
                else:   # otherwise, for directory, use image filename for multiple images.
                    os.makedirs(embedding_path, exist_ok=True)
                    _embedding_fpath = os.path.join(embedding_path, input_name.with_suffix(".zarr"))

            # Next, let's check for output file to store segmentation.
            if _has_extension(output_path):  # in this case, use filename as addl. suffix to provided path.
                _output_fpath = _add_stem_to_filepath(output_path, ".tif", Path(path).stem)
            else:  # otherwise, for directory, use image filename for multiple images.
                os.makedirs(output_path, exist_ok=True)
                _output_fpath = os.path.join(output_path, input_name.with_suffix(".tif"))

        automatic_instance_segmentation(
            predictor=predictor,
            segmenter=segmenter,
            input_path=path,
            output_path=_output_fpath,
            embedding_path=_embedding_fpath,
            key=args.key,
            ndim=args.ndim,
            tile_shape=args.tile_shape,
            halo=args.halo,
            annotate=args.annotate,
            verbose=args.verbose,
            batch_size=args.batch_size,
            **generate_kwargs,
        )
def get_predictor_and_segmenter( model_type: str, checkpoint: Union[str, os.PathLike, NoneType] = None, device: str = None, amg: Optional[bool] = None, is_tiled: bool = False, **kwargs) -> Tuple[mobile_sam.predictor.SamPredictor, Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder]]:
def get_predictor_and_segmenter(
    model_type: str,
    checkpoint: Optional[Union[os.PathLike, str]] = None,
    device: str = None,
    amg: Optional[bool] = None,
    is_tiled: bool = False,
    **kwargs,
) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]:
    """Build the Segment Anything predictor and the matching automatic segmentation object.

    Args:
        model_type: Name of the Segment Anything model to load.
        checkpoint: Optional filepath of the model checkpoint to load.
        device: The torch device. It is resolved via `util.get_device`.
        amg: Selects the segmentation mode. True requests AMG; False requests AIS,
            which needs a segmentation decoder stored with the model.
            If None, AIS is chosen when the model state contains a decoder and AMG otherwise.
        is_tiled: Whether the returned segmenter should operate in tiled (tiling window) mode.
        kwargs: Additional keyword arguments forwarded to `get_amg`.

    Returns:
        The Segment Anything model.
        The automatic instance segmentation class.

    Raises:
        RuntimeError: If AIS is requested explicitly but the model state has no decoder.
    """
    device = util.get_device(device=device)

    # Load the model together with its state; the state tells us whether a decoder is available.
    predictor, state = util.get_sam_model(
        model_type=model_type, device=device, checkpoint_path=checkpoint, return_state=True,
    )

    has_decoder = "decoder_state" in state
    use_amg = (not has_decoder) if amg is None else amg

    # The decoder is only required (and only built) for AIS.
    decoder = None
    if not use_amg:
        if not has_decoder:
            raise RuntimeError("You have passed 'amg=False', but your model does not contain a segmentation decoder.")
        decoder = get_decoder(
            image_encoder=predictor.model.image_encoder, decoder_state=state["decoder_state"], device=device
        )

    return predictor, get_amg(predictor=predictor, is_tiled=is_tiled, decoder=decoder, **kwargs)

Get the Segment Anything model and class for automatic instance segmentation.

Arguments:
  • model_type: The Segment Anything model choice.
  • checkpoint: The filepath to the stored model checkpoints.
  • device: The torch device.
  • amg: Whether to perform automatic segmentation in AMG mode. Otherwise AIS will be used, which requires a special segmentation decoder. If not specified AIS will be used if it is available and otherwise AMG will be used.
  • is_tiled: Whether to return segmenter for performing segmentation in tiling window style.
  • kwargs: Keyword arguments for the automatic mask generation class.
Returns:

A tuple containing the Segment Anything model (the predictor) and the automatic instance segmentation class (the segmenter).

def automatic_instance_segmentation( predictor: mobile_sam.predictor.SamPredictor, segmenter: Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder], input_path: Union[os.PathLike, str, numpy.ndarray], output_path: Union[str, os.PathLike, NoneType] = None, embedding_path: Union[str, os.PathLike, NoneType] = None, key: Optional[str] = None, ndim: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = True, return_embeddings: bool = False, annotate: bool = False, batch_size: int = 1, **generate_kwargs) -> numpy.ndarray:
 75def automatic_instance_segmentation(
 76    predictor: util.SamPredictor,
 77    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
 78    input_path: Union[Union[os.PathLike, str], np.ndarray],
 79    output_path: Optional[Union[os.PathLike, str]] = None,
 80    embedding_path: Optional[Union[os.PathLike, str]] = None,
 81    key: Optional[str] = None,
 82    ndim: Optional[int] = None,
 83    tile_shape: Optional[Tuple[int, int]] = None,
 84    halo: Optional[Tuple[int, int]] = None,
 85    verbose: bool = True,
 86    return_embeddings: bool = False,
 87    annotate: bool = False,
 88    batch_size: int = 1,
 89    **generate_kwargs
 90) -> np.ndarray:
 91    """Run automatic segmentation for the input image.
 92
 93    Args:
 94        predictor: The Segment Anything model.
 95        segmenter: The automatic instance segmentation class.
 96        input_path: input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
 97            or a container file (e.g. hdf5 or zarr).
 98        output_path: The output path where the instance segmentations will be saved.
 99        embedding_path: The path where the embeddings are cached already / will be saved.
100        key: The key to the input file. This is needed for container files (eg. hdf5 or zarr)
101            or to load several images as 3d volume. Provide a glob patterm, eg. "*.tif", for this case.
102        ndim: The dimensionality of the data. By default the dimensionality of the data will be used.
103            If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
104        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
105        halo: Overlap of the tiles for tiled prediction.
106        verbose: Verbosity flag.
107        return_embeddings: Whether to return the precomputed image embeddings.
108        annotate: Whether to activate the annotator for continue annotation process.
109        batch_size: The batch size to compute image embeddings over tiles / z-planes.
110            By default, does it sequentially, i.e. one after the other.
111        generate_kwargs: optional keyword arguments for the generate function of the AMG or AIS class.
112
113    Returns:
114        The segmentation result.
115    """
116    # Avoid overwriting already stored segmentations.
117    if output_path is not None:
118        output_path = Path(output_path).with_suffix(".tif")
119        if os.path.exists(output_path):
120            print(f"The segmentation results are already stored at '{os.path.abspath(output_path)}'.")
121            return
122
123    # Load the input image file.
124    if isinstance(input_path, np.ndarray):
125        image_data = input_path
126    else:
127        image_data = util.load_image_data(input_path, key)
128
129    ndim = image_data.ndim if ndim is None else ndim
130
131    # We perform additional post-processing for AMG-only.
132    # Otherwise, we ignore additional post-processing for AIS.
133    if isinstance(segmenter, InstanceSegmentationWithDecoder):
134        generate_kwargs["output_mode"] = None
135
136    if ndim == 2:
137        if (image_data.ndim != 2) and (image_data.ndim != 3 and image_data.shape[-1] != 3):
138            raise ValueError(f"The inputs does not match the shape expectation of 2d inputs: {image_data.shape}")
139
140        # Precompute the image embeddings.
141        image_embeddings = util.precompute_image_embeddings(
142            predictor=predictor,
143            input_=image_data,
144            save_path=embedding_path,
145            ndim=ndim,
146            tile_shape=tile_shape,
147            halo=halo,
148            verbose=verbose,
149            batch_size=batch_size,
150        )
151        initialize_kwargs = dict(image=image_data, image_embeddings=image_embeddings, verbose=verbose)
152
153        # If we run AIS with tiling then we use the same tile shape for the watershed postprocessing.
154        # In this case, we also add the batch size to the initialize kwargs,
155        # so that the segmentation decoder can be applied in a batched fashion.
156        if isinstance(segmenter, InstanceSegmentationWithDecoder) and tile_shape is not None:
157            generate_kwargs.update({"tile_shape": tile_shape, "halo": halo})
158            initialize_kwargs["batch_size"] = batch_size
159
160        segmenter.initialize(**initialize_kwargs)
161        masks = segmenter.generate(**generate_kwargs)
162
163        if isinstance(masks, list):
164            # whether the predictions from 'generate' are list of dict,
165            # which contains additional info req. for post-processing, eg. area per object.
166            if len(masks) == 0:
167                instances = np.zeros(image_data.shape[:2], dtype="uint32")
168            else:
169                instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0)
170        else:
171            # if (raw) predictions provided, store them as it is w/o further post-processing.
172            instances = masks
173
174    else:
175        if (image_data.ndim != 3) and (image_data.ndim != 4 and image_data.shape[-1] != 3):
176            raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}")
177
178        instances, image_embeddings = automatic_3d_segmentation(
179            volume=image_data,
180            predictor=predictor,
181            segmentor=segmenter,
182            embedding_path=embedding_path,
183            tile_shape=tile_shape,
184            halo=halo,
185            verbose=verbose,
186            return_embeddings=True,
187            batch_size=batch_size,
188            **generate_kwargs
189        )
190
191    # Before starting to annotate, if at all desired, store the automatic segmentations in the first stage.
192    if output_path is not None:
193        _output_path = _add_suffix_to_output_path(output_path, "_automatic") if annotate else output_path
194        imageio.imwrite(_output_path, instances, compression="zlib")
195        print(f"The automatic segmentation results are stored at '{os.path.abspath(_output_path)}'.")
196
197    # Allow opening the automatic segmentation in the annotator for further annotation, if desired.
198    if annotate:
199        from micro_sam.sam_annotator import annotator_2d, annotator_3d
200        annotator_function = annotator_2d if ndim == 2 else annotator_3d
201
202        viewer = annotator_function(
203            image=image_data,
204            model_type=predictor.model_name,
205            embedding_path=image_embeddings,  # Providing the precomputed image embeddings.
206            segmentation_result=instances,  # Initializes the automatic segmentation to the annotator.
207            tile_shape=tile_shape,
208            halo=halo,
209            return_viewer=True,  # Returns the viewer, which allows the user to store the updated segmentations.
210        )
211
212        # Start the GUI here
213        import napari
214        napari.run()
215
216        # We extract the segmentation in "committed_objects" layer, where the user either:
217        # a) Performed interactive segmentation / corrections and committed them, OR
218        # b) Did not do anything and closed the annotator, i.e. keeps the segmentations as it is.
219        instances = viewer.layers["committed_objects"].data
220
221        # Save the instance segmentation, if 'output_path' provided.
222        if output_path is not None:
223            imageio.imwrite(output_path, instances, compression="zlib")
224            print(f"The final segmentation results are stored at '{os.path.abspath(output_path)}'.")
225
226    if return_embeddings:
227        return instances, image_embeddings
228    else:
229        return instances

Run automatic segmentation for the input image.

Arguments:
  • predictor: The Segment Anything model.
  • segmenter: The automatic instance segmentation class.
  • input_path: The input image file(s). Can either be a single image file (e.g. tif or png), or a container file (e.g. hdf5 or zarr). Image data that is already loaded as a numpy array is used directly.
  • output_path: The output path where the instance segmentations will be saved.
  • embedding_path: The path where the embeddings are cached already / will be saved.
  • key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr) or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
  • ndim: The dimensionality of the data. By default the dimensionality of the data will be used. If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
  • tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
  • halo: Overlap of the tiles for tiled prediction.
  • verbose: Verbosity flag.
  • return_embeddings: Whether to return the precomputed image embeddings.
  • annotate: Whether to open the annotator after the automatic segmentation in order to continue with interactive annotation.
  • batch_size: The batch size to compute image embeddings over tiles / z-planes. By default, does it sequentially, i.e. one after the other.
  • generate_kwargs: optional keyword arguments for the generate function of the AMG or AIS class.
Returns:

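The segmentation result. If return_embeddings is set, the segmentation result and the image embeddings are returned as a tuple. If a segmentation is already stored at output_path, nothing is computed and None is returned.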