micro_sam.automatic_segmentation

import os
import warnings
from glob import glob
from tqdm import tqdm
from pathlib import Path
from functools import partial
from typing import Dict, List, Optional, Union, Tuple

import numpy as np
import imageio.v3 as imageio

from torch_em.data.datasets.util import split_kwargs

from . import util
from .instance_segmentation import (
    get_amg, get_decoder, mask_data_to_segmentation, InstanceSegmentationWithDecoder,
    AMGBase, AutomaticMaskGenerator, TiledAutomaticMaskGenerator
)
from .multi_dimensional_segmentation import automatic_3d_segmentation, automatic_tracking_implementation


def get_predictor_and_segmenter(
    model_type: str,
    checkpoint: Optional[Union[os.PathLike, str]] = None,
    device: str = None,
    amg: Optional[bool] = None,
    is_tiled: bool = False,
    **kwargs,
) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]:
    """Get the Segment Anything model and class for automatic instance segmentation.

    Args:
        model_type: The Segment Anything model choice.
        checkpoint: The filepath to the stored model checkpoints.
        device: The torch device. By default, automatically chooses the best available device.
        amg: Whether to perform automatic segmentation in AMG mode.
            Otherwise AIS will be used, which requires a special segmentation decoder.
            If not specified, AIS will be used if it is available, otherwise AMG will be used.
        is_tiled: Whether to return a segmenter that performs the segmentation in a tiled fashion.
            By default, set to 'False'.
        kwargs: Keyword arguments for the automatic mask generation class.

    Returns:
        The Segment Anything model.
        The automatic instance segmentation class.
    """
    # Get the device.
    device = util.get_device(device=device)

    # Get the predictor and state for Segment Anything models.
    predictor, state = util.get_sam_model(
        model_type=model_type, device=device, checkpoint_path=checkpoint, return_state=True,
    )

    if amg is None:
        amg = "decoder_state" not in state

    if amg:
        decoder = None
    else:
        if "decoder_state" not in state:
            raise RuntimeError("You have passed 'amg=False', but your model does not contain a segmentation decoder.")
        decoder_state = state["decoder_state"]
        decoder = get_decoder(image_encoder=predictor.model.image_encoder, decoder_state=decoder_state, device=device)

    segmenter = get_amg(predictor=predictor, is_tiled=is_tiled, decoder=decoder, **kwargs)

    return predictor, segmenter


def _add_suffix_to_output_path(output_path: Union[str, os.PathLike], suffix: str) -> str:
    fpath = Path(output_path).resolve()
    fext = fpath.suffix if fpath.suffix else ".tif"
    return str(fpath.with_name(f"{fpath.stem}{suffix}{fext}"))


def automatic_tracking(
    predictor: util.SamPredictor,
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    input_path: Union[os.PathLike, str, np.ndarray],
    output_path: Optional[Union[os.PathLike, str]] = None,
    embedding_path: Optional[Union[os.PathLike, str]] = None,
    key: Optional[str] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    annotate: bool = False,
    batch_size: int = 1,
    **generate_kwargs
) -> Tuple[np.ndarray, List[Dict]]:
    """Run automatic tracking for the input timeseries.

    Args:
        predictor: The Segment Anything model.
        segmenter: The automatic instance segmentation class.
        input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
            or a container file (e.g. hdf5 or zarr).
        output_path: The folder where the tracking outputs will be saved in CTC format.
        embedding_path: The path where the embeddings are cached already / will be saved.
        key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
            or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings.
            By default, does not return the embeddings.
        annotate: Whether to activate the annotator to continue the annotation process.
            By default, does not activate the annotator.
        batch_size: The batch size for computing image embeddings over tiles / z-planes.
            By default, computes them sequentially, i.e. one after the other.
        generate_kwargs: Optional keyword arguments for the generate function of the AMG or AIS class.

    Returns:
        The tracking result as a timeseries, where each object is labeled by its track id.
        The lineages representing cell divisions, stored as a list of dictionaries.
    """
    # Load the input image file.
    if isinstance(input_path, np.ndarray):
        image_data = input_path
    else:
        image_data = util.load_image_data(input_path, key)

    # The additional post-processing in 'generate' is only required for AMG.
    # For AIS, we disable it by setting the output mode to None.
    if isinstance(segmenter, InstanceSegmentationWithDecoder):
        generate_kwargs["output_mode"] = None

    # The input must be a timeseries (3d) or a timeseries of RGB images (4d with 3 channels).
    if image_data.ndim != 3 and not (image_data.ndim == 4 and image_data.shape[-1] == 3):
        raise ValueError(f"The input does not match the shape expectation for timeseries inputs: {image_data.shape}")

    gap_closing, min_time_extent = generate_kwargs.get("gap_closing"), generate_kwargs.get("min_time_extent")
    segmentation, lineage, image_embeddings = automatic_tracking_implementation(
        image_data,
        predictor,
        segmenter,
        embedding_path=embedding_path,
        gap_closing=gap_closing,
        min_time_extent=min_time_extent,
        tile_shape=tile_shape,
        halo=halo,
        verbose=verbose,
        batch_size=batch_size,
        return_embeddings=True,
        output_folder=output_path,
        **generate_kwargs,
    )

    if annotate:
        # TODO We need to support initialization of the tracking annotator with the tracking result for this.
        raise NotImplementedError("Annotation after running the automated tracking is currently not supported.")

    if return_embeddings:
        return segmentation, lineage, image_embeddings
    else:
        return segmentation, lineage


def automatic_instance_segmentation(
    predictor: util.SamPredictor,
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    input_path: Union[os.PathLike, str, np.ndarray],
    output_path: Optional[Union[os.PathLike, str]] = None,
    embedding_path: Optional[Union[os.PathLike, str]] = None,
    key: Optional[str] = None,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    annotate: bool = False,
    batch_size: int = 1,
    **generate_kwargs
) -> np.ndarray:
    """Run automatic segmentation for the input image.

    Args:
        predictor: The Segment Anything model.
        segmenter: The automatic instance segmentation class.
        input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
            or a container file (e.g. hdf5 or zarr).
        output_path: The output path where the instance segmentations will be saved.
        embedding_path: The path where the embeddings are cached already / will be saved.
        key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
            or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
        ndim: The dimensionality of the data. By default, the dimensionality of the data will be used.
            If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings.
            By default, does not return the embeddings.
        annotate: Whether to activate the annotator to continue the annotation process.
            By default, does not activate the annotator.
        batch_size: The batch size for computing image embeddings over tiles / z-planes.
            By default, computes them sequentially, i.e. one after the other.
        generate_kwargs: Optional keyword arguments for the generate function of the AMG or AIS class.

    Returns:
        The segmentation result.
    """
    # Avoid overwriting already stored segmentations.
    if output_path is not None:
        output_path = Path(output_path).with_suffix(".tif")
        if os.path.exists(output_path):
            print(f"The segmentation results are already stored at '{os.path.abspath(output_path)}'.")
            return

    # Load the input image file.
    if isinstance(input_path, np.ndarray):
        image_data = input_path
    else:
        image_data = util.load_image_data(input_path, key)

    ndim = image_data.ndim if ndim is None else ndim

    # The additional post-processing in 'generate' is only required for AMG.
    # For AIS, we disable it by setting the output mode to None.
    if isinstance(segmenter, InstanceSegmentationWithDecoder):
        generate_kwargs["output_mode"] = None

    if ndim == 2:
        # The input must be a 2d image or a 2d RGB image (3d with 3 channels).
        if image_data.ndim != 2 and not (image_data.ndim == 3 and image_data.shape[-1] == 3):
            raise ValueError(f"The input does not match the shape expectation for 2d inputs: {image_data.shape}")

        # Precompute the image embeddings.
        image_embeddings = util.precompute_image_embeddings(
            predictor=predictor,
            input_=image_data,
            save_path=embedding_path,
            ndim=ndim,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
            batch_size=batch_size,
        )
        initialize_kwargs = dict(image=image_data, image_embeddings=image_embeddings, verbose=verbose)

        # If we run AIS with tiling then we use the same tile shape for the watershed postprocessing.
        # In this case, we also add the batch size to the initialize kwargs,
        # so that the segmentation decoder can be applied in a batched fashion.
        if isinstance(segmenter, InstanceSegmentationWithDecoder) and tile_shape is not None:
            generate_kwargs.update({"tile_shape": tile_shape, "halo": halo})
            initialize_kwargs["batch_size"] = batch_size

        segmenter.initialize(**initialize_kwargs)
        masks = segmenter.generate(**generate_kwargs)

        if isinstance(masks, list):
            # The predictions from 'generate' are a list of dicts, which contain
            # additional info required for post-processing, e.g. the area per object.
            if len(masks) == 0:
                instances = np.zeros(image_data.shape[:2], dtype="uint32")
            else:
                instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0)
        else:
            # Otherwise, the (raw) predictions are stored as they are, without further post-processing.
            instances = masks

    else:
        # The input must be a 3d volume or a 3d volume of RGB images (4d with 3 channels).
        if image_data.ndim != 3 and not (image_data.ndim == 4 and image_data.shape[-1] == 3):
            raise ValueError(f"The input does not match the shape expectation for 3d inputs: {image_data.shape}")

        instances, image_embeddings = automatic_3d_segmentation(
            volume=image_data,
            predictor=predictor,
            segmentor=segmenter,
            embedding_path=embedding_path,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
            return_embeddings=True,
            batch_size=batch_size,
            **generate_kwargs
        )

    # Store the automatic segmentation. If annotation follows, store it under a separate name first.
    if output_path is not None:
        _output_path = _add_suffix_to_output_path(output_path, "_automatic") if annotate else output_path
        imageio.imwrite(_output_path, instances, compression="zlib")
        if verbose:
            print(f"The automatic segmentation results are stored at '{os.path.abspath(_output_path)}'.")

    # Allow opening the automatic segmentation in the annotator for further annotation, if desired.
    if annotate:
        from micro_sam.sam_annotator import annotator_2d, annotator_3d
        annotator_function = annotator_2d if ndim == 2 else annotator_3d

        viewer = annotator_function(
            image=image_data,
            model_type=predictor.model_name,
            embedding_path=image_embeddings,  # Providing the precomputed image embeddings.
            segmentation_result=instances,  # Initializes the automatic segmentation in the annotator.
            tile_shape=tile_shape,
            halo=halo,
            return_viewer=True,  # Returns the viewer, which allows the user to store the updated segmentations.
        )

        # Start the GUI here.
        import napari
        napari.run()

        # We extract the segmentation from the "committed_objects" layer, where the user either:
        # a) performed interactive segmentation / corrections and committed them, OR
        # b) did not do anything and closed the annotator, i.e. kept the segmentation as it is.
        instances = viewer.layers["committed_objects"].data

        # Save the instance segmentation, if 'output_path' is provided.
        if output_path is not None:
            imageio.imwrite(output_path, instances, compression="zlib")
            if verbose:
                print(f"The final segmentation results are stored at '{os.path.abspath(output_path)}'.")

    if return_embeddings:
        return instances, image_embeddings
    else:
        return instances


def _get_inputs_from_paths(paths, pattern):
    """Get the filepaths for all inputs, from single files and/or directories with a pattern."""
    if isinstance(paths, str):
        paths = [paths]

    fpaths = []
    for path in paths:
        if os.path.isfile(path):  # It is just one filepath.
            fpaths.append(path)
        else:  # Otherwise, if the path is a directory, fetch all inputs that match the pattern.
            assert pattern is not None, \
                f"You must provide a pattern to search for files in the directory: '{os.path.abspath(path)}'."
            fpaths.extend(glob(os.path.join(path, pattern)))

    return fpaths


def main():
    """@private"""
    import argparse

    available_models = list(util.get_model_names())
    available_models = ", ".join(available_models)

    parser = argparse.ArgumentParser(
        description="Run automatic segmentation or tracking for 2d, 3d or timeseries data.\n"
        "Either a single input file or multiple input files are supported. You can specify multiple files "
        "by either providing multiple filepaths to the '-i/--input_path' argument, or by providing an argument "
        "to '--pattern' to use a wildcard pattern ('*') for selecting multiple files.\n"
        "NOTE: for automatic 3d segmentation or tracking the data has to be stored as volume / timeseries; "
        "stacking individual tif images is not supported.\n"
        "Segmentation is performed using one of the two modes supported by micro_sam:\n"
        "automatic instance segmentation (AIS) or automatic mask generation (AMG).\n"
        "In addition to the options listed below, "
        "you can also pass additional arguments for these two segmentation modes:\n"
        "For AIS: '--center_distance_threshold', '--boundary_distance_threshold' and other arguments of `InstanceSegmentationWithDecoder.generate`.\n"  # noqa
        "For AMG: '--pred_iou_thresh', '--stability_score_thresh' and other arguments of `AutomaticMaskGenerator.generate`."  # noqa
    )
    parser.add_argument(
        "-i", "--input_path", required=True, type=str, nargs="+",
        help="The filepath(s) to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "  # noqa
        "or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter."
    )
    parser.add_argument(
        "-o", "--output_path", required=True, type=str,
        help="The filepath to store the results. If multiple inputs are provided, "
        "this should be a folder. For a single image, you should provide the path to a tif file for the output segmentation. "  # noqa
        "NOTE: Segmentation results are stored as tif files, tracking results in the CTC file format."
    )
    parser.add_argument(
        "-e", "--embedding_path", default=None, type=str,
        help="An optional path where the embeddings will be saved. If multiple inputs are provided, "
        "this should be a folder. Otherwise you can store the embeddings in a single zarr file."
    )
    parser.add_argument(
        "--pattern", type=str, help="Pattern / wildcard for selecting files in a folder. To select all files use '*'."
    )
    parser.add_argument(
        "-k", "--key", default=None, type=str,
        help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, "
        "for an image stack it is a wild-card, e.g. '*.png', and for mrc it is 'data'."
    )
    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL, type=str,
        help=f"The Segment Anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None, type=str, help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, help="The halo for using tiled prediction.", default=None
    )
    parser.add_argument(
        "-n", "--ndim", default=None, type=int,
        help="The number of spatial dimensions in the data. Please specify this if your data has a channel dimension."
    )
    parser.add_argument(
        "--mode", default="auto", type=str,
        help="The choice of automatic segmentation with the Segment Anything models. Either 'auto', 'amg' or 'ais'."
    )
    parser.add_argument(
        "--annotate", action="store_true",
        help="Whether to continue annotation after the automatic segmentation is generated."
    )
    parser.add_argument(
        "-d", "--device", default=None, type=str,
        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only on Mac). "
        "By default the most performant available device will be selected."
    )
    parser.add_argument(
        "--batch_size", type=int, default=1,
        help="The batch size for computing image embeddings over tiles or z-planes. "
        "By default, computes the image embeddings for one tile / z-plane at a time."
    )
    parser.add_argument(
        "--tracking", action="store_true", help="Run automatic tracking instead of instance segmentation. "
        "NOTE: It is only supported for timeseries inputs."
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="Whether to print verbose outputs."
    )

    args, parameter_args = parser.parse_known_args()

    def _convert_argval(value):
        # The values for the parsed arguments need to be in the expected input structure as provided,
        # i.e. integers and floats should be in their original types.
        try:
            return int(value)
        except ValueError:
            return float(value)

    # NOTE: The code below allows catching additional parsed arguments, which correspond to
    # the automatic segmentation post-processing parameters (e.g. 'center_distance_threshold' in AIS).
    extra_kwargs = {
        parameter_args[i].lstrip("--"): _convert_argval(parameter_args[i + 1]) for i in range(0, len(parameter_args), 2)
    }

    # Separate the extra arguments depending on where they should be passed in the automatic segmentation class.
    # This is done to ensure the extra arguments are allocated to the desired location.
    # E.g. for AMG, 'points_per_side' is expected by '__init__',
    # and 'stability_score_thresh' is expected by the 'generate' method.
    amg_class = AutomaticMaskGenerator if args.tile_shape is None else TiledAutomaticMaskGenerator
    amg_kwargs, generate_kwargs = split_kwargs(amg_class, **extra_kwargs)

    # Validate the expected automatic segmentation mode.
    # By default, it is set to 'auto', i.e. it searches for the decoder state to prioritize AIS for finetuned models.
    # Otherwise, it runs AMG for all models in any case.
    amg = None
    if args.mode != "auto":
        assert args.mode in ["ais", "amg"], \
            f"'{args.mode}' is not a valid automatic segmentation mode. Please choose either 'amg' or 'ais'."
        amg = (args.mode == "amg")

    predictor, segmenter = get_predictor_and_segmenter(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        device=args.device,
        amg=amg,
        is_tiled=args.tile_shape is not None,
        **amg_kwargs,
    )

    # Get the filepaths to the input images (and the other paths to store outputs,
    # e.g. segmentations and embeddings), and check whether any inputs were found.
    input_paths = _get_inputs_from_paths(args.input_path, args.pattern)
    assert len(input_paths) > 0, "'micro-sam' could not extract any image data internally."

    output_path = args.output_path
    embedding_path = args.embedding_path
    has_one_input = len(input_paths) == 1

    instance_seg_function = automatic_tracking if args.tracking else partial(
        automatic_instance_segmentation, ndim=args.ndim
    )

    # Run automatic segmentation per image.
    for input_path in tqdm(input_paths, desc="Run automatic " + ("tracking" if args.tracking else "segmentation")):
        if has_one_input:  # When we have only one image / volume.
            _embedding_fpath = embedding_path  # Either a folder or a zarr file works in this case.

            output_fdir = os.path.splitext(output_path)[0]
            os.makedirs(output_fdir, exist_ok=True)

            # For tracking, we ensure that the output path is a folder,
            # i.e. does not have an extension. We throw a warning if the user provided an extension.
            if args.tracking:
                if os.path.splitext(output_path)[-1]:
                    warnings.warn(
                        f"The output folder has an extension '{os.path.splitext(output_path)[-1]}'. "
                        "We remove it and treat it as a folder to store tracking outputs in CTC format."
                    )
                _output_fpath = output_fdir
            else:  # Otherwise, we store the output directly at the provided filepath, ensuring the '.tif' extension.
                _output_fpath = f"{output_fdir}.tif"

        else:  # When we have multiple images.
            # Get the input filename, without the extension.
            input_name = str(Path(input_path).stem)

            # Let's check the 'embedding_path'.
            if embedding_path is None:  # For computing embeddings on-the-fly, we don't care about the path logic.
                _embedding_fpath = embedding_path
            else:  # Otherwise, store the embeddings for each input inside a folder.
                embedding_folder = os.path.splitext(embedding_path)[0]  # Treat the provided embedding path as folder.
                os.makedirs(embedding_folder, exist_ok=True)
                _embedding_fpath = os.path.join(embedding_folder, f"{input_name}.zarr")  # Create each embedding file.

            # Get the output folder name.
            output_folder = os.path.splitext(output_path)[0]
            os.makedirs(output_folder, exist_ok=True)

            # Next, let's check for the output file to store the segmentation (or tracks).
            if args.tracking:  # For tracking, we store CTC outputs in subfolders, with 'input_name' as folder.
                _output_fpath = os.path.join(output_folder, input_name)
            else:  # Otherwise, store each result inside a folder.
                _output_fpath = os.path.join(output_folder, f"{input_name}.tif")

        instance_seg_function(
            predictor=predictor,
            segmenter=segmenter,
            input_path=input_path,
            output_path=_output_fpath,
            embedding_path=_embedding_fpath,
            key=args.key,
            tile_shape=args.tile_shape,
            halo=args.halo,
            annotate=args.annotate,
            verbose=args.verbose,
            batch_size=args.batch_size,
            **generate_kwargs,
        )
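
The parser above is also exposed on the command line. A minimal CLI sketch, assuming the console script is installed as 'micro_sam.automatic_segmentation' (as is the micro_sam convention for main() entry points) and using hypothetical filenames; all flags are taken from the argparse definition above:

# Run automatic segmentation for a single 2d image with the default model.
$ micro_sam.automatic_segmentation -i image.tif -o segmentation.tif

# Force AMG mode and forward a post-processing parameter to 'generate'.
$ micro_sam.automatic_segmentation -i image.tif -o segmentation.tif --mode amg --pred_iou_thresh 0.88
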
def get_predictor_and_segmenter(
    model_type: str,
    checkpoint: Union[os.PathLike, str, NoneType] = None,
    device: str = None,
    amg: Optional[bool] = None,
    is_tiled: bool = False,
    **kwargs
) -> Tuple[segment_anything.predictor.SamPredictor, Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder]]:

Get the Segment Anything model and class for automatic instance segmentation.

Arguments:
  • model_type: The Segment Anything model choice.
  • checkpoint: The filepath to the stored model checkpoints.
  • device: The torch device. By default, automatically chooses the best available device.
  • amg: Whether to perform automatic segmentation in AMG mode. Otherwise AIS will be used, which requires a special segmentation decoder. If not specified, AIS will be used if it is available, otherwise AMG will be used.
  • is_tiled: Whether to return a segmenter that performs the segmentation in a tiled fashion. By default, set to 'False'.
  • kwargs: Keyword arguments for the automatic mask generation class.

Returns:
  • The Segment Anything model.
  • The automatic instance segmentation class.
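
A minimal usage sketch; the model names here are assumptions (any model listed by micro_sam.util.get_model_names() works):

from micro_sam.automatic_segmentation import get_predictor_and_segmenter

# AIS is selected automatically if the model checkpoint contains a decoder state.
predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm")

# Force AMG mode, e.g. for vanilla SAM models without a segmentation decoder.
predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b", amg=True)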

def automatic_tracking(
    predictor: segment_anything.predictor.SamPredictor,
    segmenter: Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder],
    input_path: Union[os.PathLike, str, numpy.ndarray],
    output_path: Union[os.PathLike, str, NoneType] = None,
    embedding_path: Union[os.PathLike, str, NoneType] = None,
    key: Optional[str] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    annotate: bool = False,
    batch_size: int = 1,
    **generate_kwargs
) -> Tuple[numpy.ndarray, List[Dict]]:

Run automatic tracking for the input timeseries.

Arguments:
  • predictor: The Segment Anything model.
  • segmenter: The automatic instance segmentation class.
  • input_path: The input image file(s). Can either be a single image file (e.g. tif or png), or a container file (e.g. hdf5 or zarr).
  • output_path: The folder where the tracking outputs will be saved in CTC format.
  • embedding_path: The path where the embeddings are cached already / will be saved.
  • key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr) or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
  • tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
  • halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
  • verbose: Verbosity flag. By default, set to 'True'.
  • return_embeddings: Whether to return the precomputed image embeddings. By default, does not return the embeddings.
  • annotate: Whether to activate the annotator to continue the annotation process. By default, does not activate the annotator.
  • batch_size: The batch size for computing image embeddings over tiles / z-planes. By default, computes them sequentially, i.e. one after the other.
  • generate_kwargs: Optional keyword arguments for the generate function of the AMG or AIS class.

Returns:
  • The tracking result as a timeseries, where each object is labeled by its track id.
  • The lineages representing cell divisions, stored as a list of dictionaries.
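
A minimal sketch for tracking a timeseries; the input and output paths are hypothetical, and the model choice is an assumption:

from micro_sam.automatic_segmentation import automatic_tracking, get_predictor_and_segmenter

predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm")
segmentation, lineage = automatic_tracking(
    predictor=predictor,
    segmenter=segmenter,
    input_path="timeseries.tif",    # hypothetical timeseries, stored as a single tif stack
    output_path="tracking_result",  # folder for the outputs in CTC format
)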

def automatic_instance_segmentation(
    predictor: segment_anything.predictor.SamPredictor,
    segmenter: Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder],
    input_path: Union[os.PathLike, str, numpy.ndarray],
    output_path: Union[os.PathLike, str, NoneType] = None,
    embedding_path: Union[os.PathLike, str, NoneType] = None,
    key: Optional[str] = None,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    annotate: bool = False,
    batch_size: int = 1,
    **generate_kwargs
) -> numpy.ndarray:

Run automatic segmentation for the input image.

Arguments:
  • predictor: The Segment Anything model.
  • segmenter: The automatic instance segmentation class.
  • input_path: The input image file(s). Can either be a single image file (e.g. tif or png), or a container file (e.g. hdf5 or zarr).
  • output_path: The output path where the instance segmentations will be saved.
  • embedding_path: The path where the embeddings are cached already / will be saved.
  • key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr) or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
  • ndim: The dimensionality of the data. By default, the dimensionality of the data will be used. If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
  • tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
  • halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
  • verbose: Verbosity flag. By default, set to 'True'.
  • return_embeddings: Whether to return the precomputed image embeddings. By default, does not return the embeddings.
  • annotate: Whether to activate the annotator to continue the annotation process. By default, does not activate the annotator.
  • batch_size: The batch size for computing image embeddings over tiles / z-planes. By default, computes them sequentially, i.e. one after the other.
  • generate_kwargs: Optional keyword arguments for the generate function of the AMG or AIS class.

Returns:
  • The segmentation result.
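
A minimal sketch for 2d instance segmentation on an in-memory image; the random array stands in for real image data and the model choice is an assumption:

import numpy as np
from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter

predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm")
image = np.random.randint(0, 255, size=(512, 512), dtype="uint8")

segmentation = automatic_instance_segmentation(
    predictor=predictor,
    segmenter=segmenter,
    input_path=image,  # an array works directly; a filepath to tif/png/hdf5/zarr also works
    ndim=2,            # derived automatically for 2d grayscale; must be passed explicitly for RGB data
)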