micro_sam.automatic_segmentation

import os
import warnings
from glob import glob
from tqdm import tqdm
from pathlib import Path
from functools import partial
from typing import Dict, List, Optional, Union, Tuple

import numpy as np
import imageio.v3 as imageio

from torch_em.data.datasets.util import split_kwargs

from . import util
from .instance_segmentation import (
    get_instance_segmentation_generator, get_decoder, AMGBase,
    AutomaticMaskGenerator, TiledAutomaticMaskGenerator,
    AutomaticPromptGenerator, TiledAutomaticPromptGenerator,
    InstanceSegmentationWithDecoder, TiledInstanceSegmentationWithDecoder,
    DEFAULT_SEGMENTATION_MODE_WITH_DECODER,
)
from .multi_dimensional_segmentation import automatic_3d_segmentation, automatic_tracking_implementation


def get_predictor_and_segmenter(
    model_type: str,
    checkpoint: Optional[Union[os.PathLike, str]] = None,
    device: Optional[str] = None,
    segmentation_mode: Optional[str] = None,
    is_tiled: bool = False,
    predictor=None,
    state=None,
    **kwargs,
) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]:
    f"""Get the Segment Anything model and class for automatic instance segmentation.

    Args:
        model_type: The Segment Anything model choice.
        checkpoint: The filepath to the stored model checkpoint.
        device: The torch device. By default, automatically chooses the best available device.
        segmentation_mode: The segmentation mode. One of 'amg', 'ais', or 'apg'.
            By default, '{DEFAULT_SEGMENTATION_MODE_WITH_DECODER}' is used
            if a decoder is passed, otherwise 'amg' is used.
        is_tiled: Whether to return a segmenter that performs segmentation in a tiled fashion.
            By default, set to 'False'.
        predictor: The pre-loaded predictor (optional).
        state: The pre-loaded state (optional).
        kwargs: Keyword arguments for the automatic mask generation class.

    Returns:
        The Segment Anything model.
        The automatic instance segmentation class.
    """
    # Get the predictor and state for the Segment Anything model.
    if predictor is None:
        device = util.get_device(device=device)
        predictor, state = util.get_sam_model(
            model_type=model_type, device=device, checkpoint_path=checkpoint, return_state=True
        )
    else:
        assert state is not None

    if segmentation_mode in (None, "auto"):
        segmentation_mode = DEFAULT_SEGMENTATION_MODE_WITH_DECODER if "decoder_state" in state else "amg"

    if segmentation_mode.lower() == "amg":
        decoder = None
    else:
        if "decoder_state" not in state:
            raise RuntimeError(
                f"You have passed 'segmentation_mode={segmentation_mode}', but your model does not contain a decoder."
            )
        decoder_state = state["decoder_state"]
        decoder = get_decoder(image_encoder=predictor.model.image_encoder, decoder_state=decoder_state, device=device)

    segmenter = get_instance_segmentation_generator(
        predictor=predictor, is_tiled=is_tiled, decoder=decoder, segmentation_mode=segmentation_mode, **kwargs
    )
    return predictor, segmenter
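
# A minimal usage sketch (not part of the library source; "vit_b_lm" is an
# assumed model name from the micro_sam model zoo):
#
#     predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm")
#     # 'segmenter' uses the model's decoder-based default mode if a decoder is
#     # available, otherwise AMG.
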
def _add_suffix_to_output_path(output_path: Union[str, os.PathLike], suffix: str) -> str:
    fpath = Path(output_path).resolve()
    fext = fpath.suffix if fpath.suffix else ".tif"
    return str(fpath.with_name(f"{fpath.stem}{suffix}{fext}"))
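
# For example, _add_suffix_to_output_path("seg.tif", "_automatic") yields the
# absolute path to "seg_automatic.tif"; a path without an extension gets ".tif" appended.
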
def automatic_tracking(
    predictor: util.SamPredictor,
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    input_path: Union[Union[os.PathLike, str], np.ndarray],
    output_path: Optional[Union[os.PathLike, str]] = None,
    embedding_path: Optional[Union[os.PathLike, str]] = None,
    key: Optional[str] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    annotate: bool = False,
    batch_size: int = 1,
    **generate_kwargs
) -> Tuple[np.ndarray, List[Dict]]:
    """Run automatic tracking for the input timeseries.

    Args:
        predictor: The Segment Anything model.
        segmenter: The automatic instance segmentation class.
        input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
            or a container file (e.g. hdf5 or zarr).
        output_path: The folder where the tracking outputs will be saved in CTC format.
        embedding_path: The path where the embeddings are cached already / will be saved.
        key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
            or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings.
            By default, does not return the embeddings.
        annotate: Whether to activate the annotator to continue the annotation process.
            By default, does not activate the annotator.
        batch_size: The batch size for computing image embeddings over tiles / z-planes.
            By default, processes them sequentially, i.e. one after the other.
        generate_kwargs: Optional keyword arguments for the generate function of the AMG, APG, or AIS class.

    Returns:
        The tracking result as a timeseries, where each object is labeled by its track id.
        The lineages representing cell divisions, stored as a list of dictionaries.
    """
    # Load the input image file.
    if isinstance(input_path, np.ndarray):
        image_data = input_path
    else:
        image_data = util.load_image_data(input_path, key)

    if not (image_data.ndim == 3 or (image_data.ndim == 4 and image_data.shape[-1] == 3)):
        raise ValueError(f"The input does not match the shape expectation of 3d inputs: {image_data.shape}")

    gap_closing, min_time_extent = generate_kwargs.get("gap_closing"), generate_kwargs.get("min_time_extent")
    segmentation, lineage, image_embeddings = automatic_tracking_implementation(
        image_data,
        predictor,
        segmenter,
        embedding_path=embedding_path,
        gap_closing=gap_closing,
        min_time_extent=min_time_extent,
        tile_shape=tile_shape,
        halo=halo,
        verbose=verbose,
        batch_size=batch_size,
        return_embeddings=True,
        output_folder=output_path,
        **generate_kwargs,
    )

    if annotate:
        # TODO We need to support initialization of the tracking annotator with the tracking result for this.
        raise NotImplementedError("Annotation after running the automated tracking is currently not supported.")

    if return_embeddings:
        return segmentation, lineage, image_embeddings
    else:
        return segmentation, lineage
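
# A minimal usage sketch (not part of the library source; the file name is a
# placeholder and "vit_b_lm" an assumed model choice):
#
#     predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm")
#     tracks, lineages = automatic_tracking(
#         predictor, segmenter, input_path="timeseries.tif", output_path="tracking_results"
#     )
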
def automatic_instance_segmentation(
    predictor: util.SamPredictor,
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    input_path: Union[Union[os.PathLike, str], np.ndarray],
    output_path: Optional[Union[os.PathLike, str]] = None,
    embedding_path: Optional[Union[os.PathLike, str]] = None,
    key: Optional[str] = None,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    annotate: bool = False,
    batch_size: int = 1,
    **generate_kwargs
) -> np.ndarray:
    """Run automatic segmentation for the input image.

    Args:
        predictor: The Segment Anything model.
        segmenter: The automatic instance segmentation class.
        input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
            or a container file (e.g. hdf5 or zarr).
        output_path: The output path where the instance segmentations will be saved.
        embedding_path: The path where the embeddings are cached already / will be saved.
        key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
            or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
        ndim: The dimensionality of the data. By default the dimensionality of the data will be used.
            If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings.
            By default, does not return the embeddings.
        annotate: Whether to activate the annotator to continue the annotation process.
            By default, does not activate the annotator.
        batch_size: The batch size for computing image embeddings over tiles / z-planes.
            By default, processes them sequentially, i.e. one after the other.
        generate_kwargs: Optional keyword arguments for the generate function of the AMG or AIS class.

    Returns:
        The segmentation result.
    """
    # Avoid overwriting already stored segmentations.
    if output_path is not None:
        output_path = Path(output_path).with_suffix(".tif")
        if os.path.exists(output_path):
            print(f"The segmentation results are already stored at '{os.path.abspath(output_path)}'.")
            return

    # Load the input image file.
    if isinstance(input_path, np.ndarray):
        image_data = input_path
    else:
        image_data = util.load_image_data(input_path, key)

    ndim = image_data.ndim if ndim is None else ndim

    if ndim == 2:
        if not (image_data.ndim == 2 or (image_data.ndim == 3 and image_data.shape[-1] == 3)):
            raise ValueError(f"The input does not match the shape expectation of 2d inputs: {image_data.shape}")

        # Precompute the image embeddings.
        image_embeddings = util.precompute_image_embeddings(
            predictor=predictor,
            input_=image_data,
            save_path=embedding_path,
            ndim=ndim,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
            batch_size=batch_size,
        )
        initialize_kwargs = dict(image=image_data, image_embeddings=image_embeddings, verbose=verbose)

        # If we run AIS with tiling, then we use the same tile shape for the watershed postprocessing.
        # In this case, we also add the batch size to the initialize kwargs,
        # so that the segmentation decoder can be applied in a batched fashion.
        if isinstance(segmenter, InstanceSegmentationWithDecoder) and tile_shape is not None:
            generate_kwargs.update({"tile_shape": tile_shape, "halo": halo})
            initialize_kwargs["batch_size"] = batch_size

        segmenter.initialize(**initialize_kwargs)
        instances = segmenter.generate(**generate_kwargs)

    else:
        if not (image_data.ndim == 3 or (image_data.ndim == 4 and image_data.shape[-1] == 3)):
            raise ValueError(f"The input does not match the shape expectation of 3d inputs: {image_data.shape}")

        instances, image_embeddings = automatic_3d_segmentation(
            volume=image_data,
            predictor=predictor,
            segmentor=segmenter,
            embedding_path=embedding_path,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
            return_embeddings=True,
            batch_size=batch_size,
            **generate_kwargs
        )

    # Before starting to annotate, if at all desired, store the automatic segmentation of the first stage.
    if output_path is not None:
        _output_path = _add_suffix_to_output_path(output_path, "_automatic") if annotate else output_path
        imageio.imwrite(_output_path, instances, compression="zlib")
        if verbose:
            print(f"The automatic segmentation results are stored at '{os.path.abspath(_output_path)}'.")

    # Allow opening the automatic segmentation in the annotator for further annotation, if desired.
    if annotate:
        from micro_sam.sam_annotator import annotator_2d, annotator_3d
        annotator_function = annotator_2d if ndim == 2 else annotator_3d

        viewer = annotator_function(
            image=image_data,
            model_type=predictor.model_name,
            embedding_path=image_embeddings,  # Provide the precomputed image embeddings.
            segmentation_result=instances,  # Initialize the annotator with the automatic segmentation.
            tile_shape=tile_shape,
            halo=halo,
            return_viewer=True,  # Return the viewer, which allows the user to store the updated segmentations.
        )

        # Start the GUI here.
        import napari
        napari.run()

        # We extract the segmentation from the "committed_objects" layer, where the user either:
        # a) performed interactive segmentation / corrections and committed them, OR
        # b) did not do anything and closed the annotator, i.e. kept the segmentations as they are.
        instances = viewer.layers["committed_objects"].data

        # Save the instance segmentation, if 'output_path' is provided.
        if output_path is not None:
            imageio.imwrite(output_path, instances, compression="zlib")
            if verbose:
                print(f"The final segmentation results are stored at '{os.path.abspath(output_path)}'.")

    if return_embeddings:
        return instances, image_embeddings
    else:
        return instances
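
# A minimal usage sketch (not part of the library source; the file names and the
# "vit_b_lm" model are placeholders/assumptions):
#
#     predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm")
#     instances = automatic_instance_segmentation(
#         predictor, segmenter, input_path="image.tif", output_path="segmentation.tif", ndim=2
#     )
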
def _get_inputs_from_paths(paths, pattern):
    """Get the filepaths to all inputs, from single files and/or directories."""

    if isinstance(paths, str):
        paths = [paths]

    fpaths = []
    for path in paths:
        if os.path.isfile(path):  # It is just one filepath.
            fpaths.append(path)
        else:  # Otherwise, if the path is a directory, fetch all inputs matching the pattern.
            assert pattern is not None, \
                f"You must provide a pattern to search for files in the directory: '{os.path.abspath(path)}'."
            fpaths.extend(glob(os.path.join(path, pattern)))

    return fpaths
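
# For example, _get_inputs_from_paths(["a.tif", "data_dir"], pattern="*.tif")
# returns ["a.tif"] plus all tif files found directly inside "data_dir".
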
def main():
    """@private"""
    import argparse

    available_models = list(util.get_model_names())
    available_models = ", ".join(available_models)

    parser = argparse.ArgumentParser(
        description="Run automatic segmentation or tracking for 2d, 3d or timeseries data.\n"
        "Either a single input file or multiple input files are supported. You can specify multiple files "
        "by either providing multiple filepaths to the '-i/--input_path' argument, or by providing an argument "
        "to '--pattern' to use a wildcard pattern ('*') for selecting multiple files.\n"
        "NOTE: for automatic 3d segmentation or tracking the data has to be stored as volume / timeseries, "
        "stacking individual tif images is not supported.\n"
        "Segmentation is performed using one of the three modes supported by micro_sam: \n"
        "automatic instance segmentation (AIS), automatic prompt generation (APG) or automatic mask generation (AMG).\n"
        "In addition to the options listed below, "
        "you can also pass additional arguments for these three segmentation modes:\n"
        "For AIS: '--center_distance_threshold', '--boundary_distance_threshold' and other arguments of `InstanceSegmentationWithDecoder.generate`.\n"  # noqa
        "For APG: '--center_distance_threshold', '--boundary_distance_threshold' and other arguments of `AutomaticPromptGenerator.generate`.\n"  # noqa
        "For AMG: '--pred_iou_thresh', '--stability_score_thresh' and other arguments of `AutomaticMaskGenerator.generate`."  # noqa
    )
    parser.add_argument(
        "-i", "--input_path", required=True, type=str, nargs="+",
        help="The filepath(s) to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "  # noqa
        "or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter."
    )
    parser.add_argument(
        "-o", "--output_path", required=True, type=str,
        help="The filepath to store the results. If multiple inputs are provided, "
        "this should be a folder. For a single image, you should provide the path to a tif file for the output segmentation. "  # noqa
        "NOTE: Segmentation results are stored as tif files, tracking results in the CTC file format."
    )
    parser.add_argument(
        "-e", "--embedding_path", default=None, type=str,
        help="An optional path where the embeddings will be saved. If multiple inputs are provided, "
        "this should be a folder. Otherwise, you can store the embeddings in a single zarr file."
    )
    parser.add_argument(
        "--pattern", type=str, help="Pattern / wildcard for selecting files in a folder. To select all files use '*'."
    )
    parser.add_argument(
        "-k", "--key", default=None, type=str,
        help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, "
        "for an image stack it is a wildcard, e.g. '*.png', and for mrc it is 'data'."
    )
    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL, type=str,
        help=f"The Segment Anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None, type=str, help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, help="The halo for using tiled prediction.", default=None
    )
    parser.add_argument(
        "-n", "--ndim", default=None, type=int,
        help="The number of spatial dimensions in the data. Please specify this if your data has a channel dimension."
    )
    parser.add_argument(
        "--mode", default="auto", type=str,
        help="The choice of automatic segmentation mode. Either 'auto', 'amg', 'apg', or 'ais'."
    )
    parser.add_argument(
        "--annotate", action="store_true",
        help="Whether to continue annotation after the automatic segmentation is generated."
    )
    parser.add_argument(
        "-d", "--device", default=None, type=str,
        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (Mac only). "
        "By default the most performant available device will be selected."
    )
    parser.add_argument(
        "--batch_size", type=int, default=1,
        help="The batch size for computing image embeddings over tiles or z-planes. "
        "By default, computes the image embeddings for one tile / z-plane at a time."
    )
    parser.add_argument(
        "--tracking", action="store_true", help="Run automatic tracking instead of instance segmentation. "
        "NOTE: It is only supported for timeseries inputs."
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="Whether to print verbose outputs."
    )

    args, parameter_args = parser.parse_known_args()

    def _convert_argval(value):
        # The values for the parsed arguments need to be in the expected input structure as provided,
        # i.e. integers and floats should keep their original types, while everything else stays a string.
        try:
            return int(value)
        except ValueError:
            try:
                return float(value)
            except ValueError:
                return value

    # NOTE: The code below catches additional parsed arguments, which correspond to
    # the automatic segmentation post-processing parameters (e.g. 'center_distance_threshold' in AIS).
    extra_kwargs = {
        parameter_args[i].lstrip("--"): _convert_argval(parameter_args[i + 1]) for i in range(0, len(parameter_args), 2)
    }

    # Separate the extra arguments according to where they should be passed in the automatic segmentation class.
    # This is done to ensure the extra arguments are allocated to the desired location,
    # e.g. for AMG, 'points_per_side' is expected by '__init__',
    # while 'stability_score_thresh' is expected by the 'generate' method.
    mode = args.mode
    if mode in ("auto", None):
        # We have to load the state to see if we have a decoder in this case.
        device = util.get_device(device=args.device)
        predictor, state = util.get_sam_model(
            model_type=args.model_type, device=device, checkpoint_path=args.checkpoint, return_state=True
        )
        mode = DEFAULT_SEGMENTATION_MODE_WITH_DECODER if "decoder_state" in state else "amg"
    else:
        predictor, state = None, None

    if mode.lower() == "amg":
        segmenter_class = AutomaticMaskGenerator if args.tile_shape is None else TiledAutomaticMaskGenerator
    elif mode.lower() == "ais":
        segmenter_class = InstanceSegmentationWithDecoder if args.tile_shape is None \
            else TiledInstanceSegmentationWithDecoder
    elif mode.lower() == "apg":
        segmenter_class = AutomaticPromptGenerator if args.tile_shape is None else TiledAutomaticPromptGenerator
    else:
        raise ValueError(f"Invalid segmentation_mode: {mode}. Choose one of 'amg', 'ais', or 'apg'.")
    init_kwargs, generate_kwargs = split_kwargs(segmenter_class, **extra_kwargs)

    predictor, segmenter = get_predictor_and_segmenter(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        device=args.device,
        segmentation_mode=mode,
        is_tiled=args.tile_shape is not None,
        predictor=predictor,
        state=state,
        **init_kwargs,
    )

    # Get the filepaths to the input images (and the paths for storing outputs, e.g. segmentations and embeddings).
    input_paths = _get_inputs_from_paths(args.input_path, args.pattern)
    assert len(input_paths) > 0, "'micro-sam' could not find any matching input files."

    output_path = args.output_path
    embedding_path = args.embedding_path
    has_one_input = len(input_paths) == 1

    instance_seg_function = automatic_tracking if args.tracking else partial(
        automatic_instance_segmentation, ndim=args.ndim
    )

    # Run automatic segmentation per image.
    for input_path in tqdm(input_paths, desc="Run automatic " + ("tracking" if args.tracking else "segmentation")):
        if has_one_input:  # When we have only one image / volume.
            _embedding_fpath = embedding_path  # Either a folder or a zarr file works here.

            output_fdir = os.path.splitext(output_path)[0]
            os.makedirs(output_fdir, exist_ok=True)

            # For tracking, we ensure that the output path is a folder,
            # i.e. does not have an extension. We throw a warning if the user provided an extension.
            if args.tracking:
                if os.path.splitext(output_path)[-1]:
                    warnings.warn(
                        f"The output folder has an extension '{os.path.splitext(output_path)[-1]}'. "
                        "We remove it and treat it as a folder to store tracking outputs in CTC format."
                    )
                _output_fpath = output_fdir
            else:  # Otherwise, we store the output directly at the provided filepath, ensuring the '.tif' extension.
                _output_fpath = f"{output_fdir}.tif"

        else:  # When we have multiple images.
            # Get the input filename, without the extension.
            input_name = str(Path(input_path).stem)

            # Let's check the 'embedding_path'.
            if embedding_path is None:  # For computing embeddings on-the-fly, we don't care about the path logic.
                _embedding_fpath = embedding_path
            else:  # Otherwise, store the embeddings for each input inside a folder.
                embedding_folder = os.path.splitext(embedding_path)[0]  # Treat the provided embedding path as a folder.
                os.makedirs(embedding_folder, exist_ok=True)
                _embedding_fpath = os.path.join(embedding_folder, f"{input_name}.zarr")  # One embedding file per input.

            # Get the output folder name.
            output_folder = os.path.splitext(output_path)[0]
            os.makedirs(output_folder, exist_ok=True)

            # Next, let's define the output file to store the segmentation (or tracks).
            if args.tracking:  # For tracking, we store CTC outputs in subfolders, named after each input.
                _output_fpath = os.path.join(output_folder, input_name)
            else:  # Otherwise, store each result as a tif file inside the folder.
                _output_fpath = os.path.join(output_folder, f"{input_name}.tif")

        instance_seg_function(
            predictor=predictor,
            segmenter=segmenter,
            input_path=input_path,
            output_path=_output_fpath,
            embedding_path=_embedding_fpath,
            key=args.key,
            tile_shape=args.tile_shape,
            halo=args.halo,
            annotate=args.annotate,
            verbose=args.verbose,
            batch_size=args.batch_size,
            **generate_kwargs,
        )
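
# Example CLI invocation (a sketch; it assumes the console script installed with
# micro_sam is named "micro_sam.automatic_segmentation" and that "vit_b_lm" is an
# available model choice). Extra post-processing parameters can be appended, as
# described in the parser help above:
#
#     micro_sam.automatic_segmentation -i image.tif -o segmentation.tif \
#         -m vit_b_lm --mode ais --center_distance_threshold 0.5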