micro_sam.automatic_segmentation

  1import os
  2import warnings
  3from glob import glob
  4from tqdm import tqdm
  5from pathlib import Path
  6from functools import partial
  7from typing import Dict, List, Optional, Union, Tuple, Literal
  8
  9import numpy as np
 10import imageio.v3 as imageio
 11
 12from torch_em.data.datasets.util import split_kwargs
 13
 14from . import util
 15from .instance_segmentation import (
 16    get_instance_segmentation_generator, get_decoder, AMGBase,
 17    AutomaticMaskGenerator, TiledAutomaticMaskGenerator,
 18    AutomaticPromptGenerator, TiledAutomaticPromptGenerator,
 19    InstanceSegmentationWithDecoder, TiledInstanceSegmentationWithDecoder,
 20    DEFAULT_SEGMENTATION_MODE_WITH_DECODER,
 21)
 22from .multi_dimensional_segmentation import automatic_3d_segmentation, automatic_tracking_implementation
 23
 24
 25def get_predictor_and_segmenter(
 26    model_type: str,
 27    checkpoint: Optional[Union[os.PathLike, str]] = None,
 28    device: str = None,
 29    segmentation_mode: Optional[Literal["amg", "ais", "apg"]] = None,
 30    is_tiled: bool = False,
 31    predictor=None,
 32    state=None,
 33    **kwargs,
 34) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]:
 35    f"""Get the Segment Anything model and class for automatic instance segmentation.
 36
 37    Args:
 38        model_type: The Segment Anything model choice.
 39        checkpoint: The filepath to the stored model checkpoints.
 40        device: The torch device. By default, automatically chooses the best available device.
 41        segmentation_mode: The segmentation mode. One of 'amg', 'ais', or 'apg'.
 42            By default, '{DEFAULT_SEGMENTATION_MODE_WITH_DECODER}' is used
 43            if a decoder is passed, otherwise 'amg' is used.
 44        is_tiled: Whether to return segmenter for performing segmentation in tiling window style.
 45            By default, set to 'False'.
 46        predictor: The pre-loaded predictor (optional).
 47        state: The pre-loaded state (optional).
 48        kwargs: Keyword arguments for the automatic mask generation class.
 49
 50    Returns:
 51        The Segment Anything model.
 52        The automatic instance segmentation class.
 53    """
 54    # Get the predictor and state for Segment Anything Model.
 55    if predictor is None:
 56        device = util.get_device(device=device)
 57        predictor, state = util.get_sam_model(
 58            model_type=model_type, device=device, checkpoint_path=checkpoint, return_state=True
 59        )
 60    else:
 61        assert state is not None
 62
 63    if segmentation_mode in (None, "auto"):
 64        segmentation_mode = DEFAULT_SEGMENTATION_MODE_WITH_DECODER if "decoder_state" in state else "amg"
 65
 66    if segmentation_mode.lower() == "amg":
 67        decoder = None
 68    else:
 69        if "decoder_state" not in state:
 70            raise RuntimeError(
 71                f"You have passed 'segmentation_mode={segmentation_mode}', but your model does not contain a decoder."
 72            )
 73        decoder_state = state["decoder_state"]
 74        decoder = get_decoder(image_encoder=predictor.model.image_encoder, decoder_state=decoder_state, device=device)
 75
 76    segmenter = get_instance_segmentation_generator(
 77        predictor=predictor, is_tiled=is_tiled, decoder=decoder, segmentation_mode=segmentation_mode, **kwargs
 78    )
 79    return predictor, segmenter
 80
 81
 82def _add_suffix_to_output_path(output_path: Union[str, os.PathLike], suffix: str) -> str:
 83    fpath = Path(output_path).resolve()
 84    fext = fpath.suffix if fpath.suffix else ".tif"
 85    return str(fpath.with_name(f"{fpath.stem}{suffix}{fext}"))
 86
 87
 88def automatic_tracking(
 89    predictor: util.SamPredictor,
 90    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
 91    input_path: Union[Union[os.PathLike, str], np.ndarray],
 92    output_path: Optional[Union[os.PathLike, str]] = None,
 93    embedding_path: Optional[Union[os.PathLike, str]] = None,
 94    key: Optional[str] = None,
 95    tile_shape: Optional[Tuple[int, int]] = None,
 96    halo: Optional[Tuple[int, int]] = None,
 97    verbose: bool = True,
 98    return_embeddings: bool = False,
 99    annotate: bool = False,
100    batch_size: int = 1,
101    **generate_kwargs
102) -> Tuple[np.ndarray, List[Dict]]:
103    """Run automatic tracking for the input timeseries.
104
105    Args:
106        predictor: The Segment Anything model.
107        segmenter: The automatic instance segmentation class.
108        input_path: input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
109            or a container file (e.g. hdf5 or zarr).
110        output_path: The folder where the tracking outputs will be saved in CTC format.
111        embedding_path: The path where the embeddings are cached already / will be saved.
112        key: The key to the input file. This is needed for container files (eg. hdf5 or zarr)
113            or to load several images as 3d volume. Provide a glob patterm, eg. "*.tif", for this case.
114        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
115        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
116        verbose: Verbosity flag. By default, set to 'True'.
117        return_embeddings: Whether to return the precomputed image embeddings.
118            By default, does not return the embeddings.
119        annotate: Whether to activate the annotator for continue annotation process.
120            By default, does not activate the annotator.
121        batch_size: The batch size to compute image embeddings over tiles / z-planes.
122            By default, does it sequentially, i.e. one after the other.
123        generate_kwargs: optional keyword arguments for the generate function of the AMG, APG, or AIS class.
124
125    Returns:
126        The tracking result as a timeseries, where each object is labeled by its track id.
127        The lineages representing cell divisions, stored as a dictionary.
128    """
129    # Load the input image file.
130    # We assume that it has to be read from file if it is a str or pathlike.
131    # Otherwise we assume it is a numpy array like object.
132    image_data = util.load_image_data(input_path, key) if isinstance(input_path, (str, os.PathLike)) else input_path
133
134    if (image_data.ndim != 3) and (image_data.ndim != 4 and image_data.shape[-1] != 3):
135        raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}")
136
137    gap_closing, min_time_extent = generate_kwargs.get("gap_closing"), generate_kwargs.get("min_time_extent")
138    segmentation, lineage, image_embeddings = automatic_tracking_implementation(
139        image_data,
140        predictor,
141        segmenter,
142        embedding_path=embedding_path,
143        gap_closing=gap_closing,
144        min_time_extent=min_time_extent,
145        tile_shape=tile_shape,
146        halo=halo,
147        verbose=verbose,
148        batch_size=batch_size,
149        return_embeddings=True,
150        output_folder=output_path,
151        **generate_kwargs,
152    )
153
154    if annotate:
155        # TODO We need to support initialization of the tracking annotator with the tracking result for this.
156        raise NotImplementedError("Annotation after running the automated tracking is currently not supported.")
157
158    if return_embeddings:
159        return segmentation, lineage, image_embeddings
160    else:
161        return segmentation, lineage
162
163
def automatic_instance_segmentation(
    predictor: util.SamPredictor,
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    input_path: Union[Union[os.PathLike, str], np.ndarray],
    output_path: Optional[Union[os.PathLike, str]] = None,
    embedding_path: Optional[Union[os.PathLike, str]] = None,
    mask_path: Optional[Union[Union[os.PathLike, str], np.ndarray]] = None,
    key: Optional[str] = None,
    mask_key: Optional[str] = None,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    annotate: bool = False,
    batch_size: int = 1,
    **generate_kwargs
) -> np.ndarray:
    """Run automatic segmentation for the input image.

    Args:
        predictor: The Segment Anything model.
        segmenter: The automatic instance segmentation class.
        input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
            or a container file (e.g. hdf5 or zarr). An already loaded array can also be passed.
        output_path: The output path where the instance segmentations will be saved.
            Its extension is replaced by '.tif'. If this file already exists, nothing is computed
            and None is returned.
        embedding_path: The path where the embeddings are cached already / will be saved.
        mask_path: The path to an optional foreground mask. Areas outside of the foreground will not be processed.
            Foreground masks are currently only supported for 2d segmentation.
        key: The key to the input file. This is needed for container files (eg. hdf5 or zarr)
            or to load several images as 3d volume. Provide a glob pattern, eg. "*.tif", for this case.
        mask_key: The key to the (optional) foreground mask.
        ndim: The dimensionality of the data. By default the dimensionality of the data will be used.
            If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to return the precomputed image embeddings.
            By default, does not return the embeddings.
        annotate: Whether to activate the annotator for continue annotation process.
            By default, does not activate the annotator.
        batch_size: The batch size to compute image embeddings over tiles / z-planes.
            By default, does it sequentially, i.e. one after the other.
        generate_kwargs: optional keyword arguments for the generate function of the AMG, APG, or AIS class.

    Returns:
        The segmentation result.
        The image embeddings, only if 'return_embeddings' is True.

    Raises:
        ValueError: If the shape of the input does not match the (derived) dimensionality.
        NotImplementedError: If a foreground mask is passed for 3d segmentation.
    """
    # Avoid overwriting already stored segmentations.
    if output_path is not None:
        output_path = Path(output_path).with_suffix(".tif")
        if os.path.exists(output_path):
            print(f"The segmentation results are already stored at '{os.path.abspath(output_path)}'.")
            return

    # We assume that it has to be read from file if it is a str or pathlike.
    # Otherwise we assume it is a numpy array like object.
    image_data = util.load_image_data(input_path, key) if isinstance(input_path, (str, os.PathLike)) else input_path

    ndim = image_data.ndim if ndim is None else ndim

    # Load the mask defining foreground if it was given.
    if mask_path is None:
        mask = None
    else:
        mask = util.load_image_data(mask_path, mask_key) if isinstance(mask_path, (str, os.PathLike)) else mask_path

    if ndim == 2:
        # Accept a 2d image, or a 2d image with a trailing channel axis of size 1 or 3 (RGB).
        is_image = image_data.ndim == 2
        is_image_with_channels = image_data.ndim == 3 and image_data.shape[-1] in (1, 3)
        if not (is_image or is_image_with_channels):
            raise ValueError(f"The inputs does not match the shape expectation of 2d inputs: {image_data.shape}")

        # Precompute the image embeddings.
        image_embeddings = util.precompute_image_embeddings(
            predictor=predictor,
            input_=image_data,
            save_path=embedding_path,
            ndim=ndim,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
            batch_size=batch_size,
            mask=mask,
        )
        initialize_kwargs = dict(image=image_data, image_embeddings=image_embeddings, verbose=verbose)
        if mask is not None:
            initialize_kwargs["mask"] = mask

        # If we run AIS with tiling then we use the same tile shape for the watershed postprocessing.
        # In this case, we also add the batch size to the initialize kwargs,
        # so that the segmentation decoder can be applied in a batched fashion.
        if isinstance(segmenter, InstanceSegmentationWithDecoder) and tile_shape is not None:
            if not isinstance(segmenter, TiledAutomaticPromptGenerator):
                generate_kwargs.update({"tile_shape": tile_shape, "halo": halo})
            initialize_kwargs["batch_size"] = batch_size

        segmenter.initialize(**initialize_kwargs)
        instances = segmenter.generate(**generate_kwargs)

    else:
        # Accept a 3d volume, or a 3d volume with a trailing channel axis of size 1 or 3 (RGB).
        is_volume = image_data.ndim == 3
        is_volume_with_channels = image_data.ndim == 4 and image_data.shape[-1] in (1, 3)
        if not (is_volume or is_volume_with_channels):
            raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}")
        if mask is not None:
            raise NotImplementedError("Using a foreground mask is currently only supported for 2d segmentation.")

        instances, image_embeddings = automatic_3d_segmentation(
            volume=image_data,
            predictor=predictor,
            segmentor=segmenter,
            embedding_path=embedding_path,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
            return_embeddings=True,
            batch_size=batch_size,
            **generate_kwargs
        )

    # Before starting to annotate, if at all desired, store the automatic segmentations in the first stage.
    if output_path is not None:
        _output_path = _add_suffix_to_output_path(output_path, "_automatic") if annotate else output_path
        imageio.imwrite(_output_path, instances, compression="zlib")
        if verbose:
            print(f"The automatic segmentation results are stored at '{os.path.abspath(_output_path)}'.")

    # Allow opening the automatic segmentation in the annotator for further annotation, if desired.
    if annotate:
        from micro_sam.sam_annotator import annotator_2d, annotator_3d
        annotator_function = annotator_2d if ndim == 2 else annotator_3d

        viewer = annotator_function(
            image=image_data,
            model_type=predictor.model_name,
            embedding_path=image_embeddings,  # Providing the precomputed image embeddings.
            segmentation_result=instances,  # Initializes the automatic segmentation to the annotator.
            tile_shape=tile_shape,
            halo=halo,
            return_viewer=True,  # Returns the viewer, which allows the user to store the updated segmentations.
        )

        # Start the GUI here
        import napari
        napari.run()

        # We extract the segmentation in "committed_objects" layer, where the user either:
        # a) Performed interactive segmentation / corrections and committed them, OR
        # b) Did not do anything and closed the annotator, i.e. keeps the segmentations as it is.
        instances = viewer.layers["committed_objects"].data

        # Save the instance segmentation, if 'output_path' provided.
        if output_path is not None:
            imageio.imwrite(output_path, instances, compression="zlib")
            if verbose:
                print(f"The final segmentation results are stored at '{os.path.abspath(output_path)}'.")

    if return_embeddings:
        return instances, image_embeddings
    return instances
321
322
def _get_inputs_from_paths(paths, pattern):
    """Collect the input filepaths from files and / or directories.

    Args:
        paths: A single path or a list of paths. A path to a file is used as is;
            any other path is treated as a directory and searched for files matching 'pattern'.
        pattern: The glob pattern for selecting files inside directories, e.g. '*.tif'.
            Only required if one of the paths is not a file.

    Returns:
        The list of filepaths. Files found in a directory are sorted, so that the processing order is deterministic.

    Raises:
        ValueError: If a directory is given without a pattern.
    """
    # Wrap a single path (string or path-like) so that it can be iterated like a list of paths.
    if isinstance(paths, (str, os.PathLike)):
        paths = [paths]

    fpaths = []
    for path in paths:
        if os.path.isfile(path):  # It is just one filepath.
            fpaths.append(path)
            continue

        # Otherwise, if the path is a directory, fetch all inputs matching the pattern.
        if pattern is None:
            raise ValueError(
                f"You must provide a pattern to search for files in the directory: '{os.path.abspath(path)}'."
            )
        # The order of glob results depends on the file system, so sort them for reproducible runs.
        fpaths.extend(sorted(glob(os.path.join(path, pattern))))

    return fpaths
339
340
def _convert_extra_argval(value: str):
    """Convert the string value of an additional command line argument to a Python value.

    Integers and floats are returned as int / float, 'true' / 'false' (any case) as bool and
    'none' (any case) as None. All other values are returned unchanged as strings.
    """
    for convert in (int, float):
        try:
            return convert(value)
        except ValueError:
            pass
    lowered = value.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    if lowered == "none":
        return None
    return value


def _parse_extra_kwargs(parameter_args: List[str], parser) -> Dict:
    """Parse the unknown command line arguments into keyword arguments for the segmentation classes.

    Supports '--name value' and '--name=value'. Dashes inside the name are mapped to underscores,
    so '--pred-iou-thresh' and '--pred_iou_thresh' are equivalent. Malformed input is reported via 'parser.error'.
    """
    extra_kwargs = {}
    arg_iter = iter(parameter_args)
    for arg in arg_iter:
        if not arg.startswith("-"):
            parser.error(f"Unexpected argument '{arg}'. Additional arguments must be passed as '--name value'.")
        name, has_inline_value, value = arg.lstrip("-").partition("=")
        if not has_inline_value:
            value = next(arg_iter, None)
            if value is None:
                parser.error(f"No value was passed for the argument '{arg}'.")
        extra_kwargs[name.replace("-", "_")] = _convert_extra_argval(value)
    return extra_kwargs


def _get_output_and_embedding_paths(input_path, output_path, embedding_path, has_one_input, tracking):
    """Derive where to store the result and the embeddings for one input, creating folders as needed.

    Returns:
        The output path: a '.tif' file for segmentation, or a folder for tracking outputs in CTC format.
        The embedding path for this input, or None if embeddings are not saved.
    """
    if has_one_input:
        # For a single input, the embedding path is used as is (either a folder or a zarr file would work).
        output_root, output_ext = os.path.splitext(output_path)
        if tracking:
            # For tracking, the output path has to be a folder, i.e. without an extension.
            if output_ext:
                warnings.warn(
                    f"The output folder has an extension '{output_ext}'. "
                    "We remove it and treat it as a folder to store tracking outputs in CTC format."
                )
            os.makedirs(output_root, exist_ok=True)
            return output_root, embedding_path

        # Otherwise, store the result directly in the provided filepath, ensuring the extension .tif.
        # Only the parent folder has to exist; a folder named after the output file is not needed.
        output_fpath = f"{output_root}.tif"
        output_parent = os.path.dirname(output_fpath)
        if output_parent:
            os.makedirs(output_parent, exist_ok=True)
        return output_fpath, embedding_path

    # For multiple inputs, results and embeddings are named after the input file (without extension).
    input_name = str(Path(input_path).stem)

    if embedding_path is None:  # Embeddings are computed on-the-fly and not saved.
        input_embedding_path = None
    else:  # Treat the provided embedding path as a folder and store one zarr file per input in it.
        embedding_folder = os.path.splitext(embedding_path)[0]
        os.makedirs(embedding_folder, exist_ok=True)
        input_embedding_path = os.path.join(embedding_folder, f"{input_name}.zarr")

    output_folder = os.path.splitext(output_path)[0]
    os.makedirs(output_folder, exist_ok=True)
    if tracking:  # CTC outputs are stored in one subfolder per input.
        return os.path.join(output_folder, input_name), input_embedding_path
    return os.path.join(output_folder, f"{input_name}.tif"), input_embedding_path


def main():
    """@private"""
    # Command line entry point for automatic segmentation or tracking.
    import argparse

    available_models = ", ".join(util.get_model_names())

    parser = argparse.ArgumentParser(
        description="Run automatic segmentation or tracking for 2d, 3d or timeseries data.\n"
        "Either a single input file or multiple input files are supported. You can specify multiple files "
        "by either providing multiple filepaths to the '-i/--input_path' argument, or by providing an argument "
        "to '--pattern' to use a wildcard pattern ('*') for selecting multiple files.\n"
        "NOTE: for automatic 3d segmentation or tracking the data has to be stored as volume / timeseries, "
        "stacking individual tif images is not supported.\n"
        "Segmentation is performed using one of the three modes supported by micro_sam: \n"
        "automatic instance segmentation (AIS), automatic prompt generation (APG) or automatic mask generation (AMG).\n"  # noqa
        "In addition to the options listed below, "
        "you can also pass additional arguments for these three segmentation modes:\n"
        "For AIS: '--center_distance_threshold', '--boundary_distance_threshold' and other arguments of `InstanceSegmentationWithDecoder.generate`.\n"  # noqa
        "For APG: '--center_distance_threshold', '--boundary_distance_threshold' and other arguments of `AutomaticPromptGenerator.generate`.\n"  # noqa
        "For AMG: '--pred_iou_thresh', '--stability_score_thresh' and other arguments of `AutomaticMaskGenerator.generate`."  # noqa
    )
    parser.add_argument(
        "-i", "--input_path", required=True, type=str, nargs="+",
        help="The filepath(s) to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "  # noqa
        "or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter."
    )
    parser.add_argument(
        "-o", "--output_path", required=True, type=str,
        help="The filepath to store the results. If multiple inputs are provided, "
        "this should be a folder. For a single image, you should provide the path to a tif file for the output segmentation. "  # noqa
        "NOTE: Segmentation results are stored as tif files, tracking results in the CTC file format."
    )
    parser.add_argument(
        "-e", "--embedding_path", default=None, type=str,
        help="An optional path where the embeddings will be saved. If multiple inputs are provided, "
        "this should be a folder. Otherwise you can store embeddings in a single zarr file."
    )
    parser.add_argument(
        "--pattern", type=str, help="Pattern / wildcard for selecting files in a folder. To select all files use '*'."
    )
    parser.add_argument(
        "-k", "--key", default=None, type=str,
        help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, "
        "for an image stack it is a wild-card, e.g. '*.png' and for mrc it is 'data'."
    )
    parser.add_argument(
        "-m", "--model_type", default=util._DEFAULT_MODEL, type=str,
        help=f"The segment anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint", default=None, type=str, help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, help="The halo for using tiled prediction.", default=None
    )
    parser.add_argument(
        "-n", "--ndim", default=None, type=int,
        help="The number of spatial dimensions in the data. Please specify this if your data has a channel dimension."
    )
    parser.add_argument(
        "--mode", default="auto", type=str,
        help="The choice of automatic segmentation mode. Either 'auto', 'amg', 'apg', or 'ais' (case-insensitive)."
    )
    parser.add_argument(
        "--annotate", action="store_true",
        help="Whether to continue annotation after the automatic segmentation is generated."
    )
    parser.add_argument(
        "-d", "--device", default=None, type=str,
        help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only macOS). "
        "By default the most performant available device will be selected."
    )
    parser.add_argument(
        "--batch_size", type=int, default=1,
        help="The batch size for computing image embeddings over tiles or z-planes. "
        "By default, computes the image embeddings for one tile / z-plane at a time."
    )
    parser.add_argument(
        "--tracking", action="store_true", help="Run automatic tracking instead of instance segmentation. "
        "NOTE: It is only supported for timeseries inputs."
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="Whether to allow verbosity of outputs."
    )

    args, parameter_args = parser.parse_known_args()

    # Additional arguments correspond to the automatic segmentation parameters
    # (eg. 'center_distance_threshold' in AIS) and are forwarded to the segmentation class.
    extra_kwargs = _parse_extra_kwargs(parameter_args, parser)

    # Collect the inputs first, so that invalid inputs are reported before the model is loaded.
    input_paths = _get_inputs_from_paths(args.input_path, args.pattern)
    if len(input_paths) == 0:
        parser.error("'micro-sam' could not find any input files for the given '--input_path' and '--pattern'.")

    # Normalize the mode so that the choice is case-insensitive.
    mode = (args.mode or "auto").lower()
    if mode == "auto":
        # We have to load the state to see if we have a decoder in this case.
        device = util.get_device(device=args.device)
        predictor, state = util.get_sam_model(
            model_type=args.model_type, device=device, checkpoint_path=args.checkpoint, return_state=True
        )
        mode = DEFAULT_SEGMENTATION_MODE_WITH_DECODER if "decoder_state" in state else "amg"
    else:
        predictor, state = None, None

    # Maps each mode to its (untiled, tiled) segmentation class.
    is_tiled = args.tile_shape is not None
    segmenter_classes = {
        "amg": (AutomaticMaskGenerator, TiledAutomaticMaskGenerator),
        "ais": (InstanceSegmentationWithDecoder, TiledInstanceSegmentationWithDecoder),
        "apg": (AutomaticPromptGenerator, TiledAutomaticPromptGenerator),
    }
    class_pair = segmenter_classes.get(mode.lower())
    if class_pair is None:
        raise ValueError(f"Invalid segmentation_mode: {mode}. Choose one of 'amg', 'ais', or 'apg'.")
    segmenter_class = class_pair[1] if is_tiled else class_pair[0]

    # Separate extra arguments as per where they should be passed in the automatic segmentation class.
    # eg. for AMG, 'points_per_side' is expected by '__init__',
    # and 'stability_score_thresh' is expected in 'generate' method.
    init_kwargs, generate_kwargs = split_kwargs(segmenter_class, **extra_kwargs)

    predictor, segmenter = get_predictor_and_segmenter(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        device=args.device,
        segmentation_mode=mode,
        is_tiled=is_tiled,
        predictor=predictor,
        state=state,
        **init_kwargs,
    )

    run_function = automatic_tracking if args.tracking else partial(
        automatic_instance_segmentation, ndim=args.ndim
    )
    has_one_input = len(input_paths) == 1
    task_name = "tracking" if args.tracking else "segmentation"

    # Run automatic segmentation (or tracking) per image.
    for input_path in tqdm(input_paths, desc=f"Run automatic {task_name}"):
        output_fpath, embedding_fpath = _get_output_and_embedding_paths(
            input_path, args.output_path, args.embedding_path, has_one_input, args.tracking
        )
        run_function(
            predictor=predictor,
            segmenter=segmenter,
            input_path=input_path,
            output_path=output_fpath,
            embedding_path=embedding_fpath,
            key=args.key,
            tile_shape=args.tile_shape,
            halo=args.halo,
            annotate=args.annotate,
            verbose=args.verbose,
            batch_size=args.batch_size,
            **generate_kwargs,
        )
def get_predictor_and_segmenter( model_type: str, checkpoint: Union[str, os.PathLike, NoneType] = None, device: str = None, segmentation_mode: Optional[Literal['amg', 'ais', 'apg']] = None, is_tiled: bool = False, predictor=None, state=None, **kwargs) -> Tuple[mobile_sam.predictor.SamPredictor, Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder]]:
def get_predictor_and_segmenter(
    model_type: str,
    checkpoint: Optional[Union[os.PathLike, str]] = None,
    device: str = None,
    segmentation_mode: Optional[Literal["amg", "ais", "apg"]] = None,
    is_tiled: bool = False,
    predictor=None,
    state=None,
    **kwargs,
) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]:
    """Get the Segment Anything model and class for automatic instance segmentation.

    Args:
        model_type: The Segment Anything model choice.
        checkpoint: The filepath to the stored model checkpoints.
        device: The torch device. By default, automatically chooses the best available device.
        segmentation_mode: The segmentation mode. One of 'amg', 'ais', or 'apg' (case-insensitive).
            By default (or if 'auto' is passed), `DEFAULT_SEGMENTATION_MODE_WITH_DECODER` is used
            if the model contains a decoder, otherwise 'amg' is used.
        is_tiled: Whether to return segmenter for performing segmentation in tiling window style.
            By default, set to 'False'.
        predictor: The pre-loaded predictor (optional). If given, 'state' must be given as well.
        state: The pre-loaded state (optional).
        kwargs: Keyword arguments for the automatic mask generation class.

    Returns:
        The Segment Anything model.
        The automatic instance segmentation class.

    Raises:
        ValueError: If 'predictor' is passed without 'state'.
        RuntimeError: If a decoder-based segmentation mode is requested but the model has no decoder.
    """
    # NOTE: this must be a plain string literal; an f-string is not stored as the function's docstring.

    # Get the predictor and state for Segment Anything Model.
    if predictor is None:
        device = util.get_device(device=device)
        predictor, state = util.get_sam_model(
            model_type=model_type, device=device, checkpoint_path=checkpoint, return_state=True
        )
    elif state is None:
        # The state is needed to check for (and to load) the decoder weights.
        raise ValueError("The 'state' must be passed together with a pre-loaded 'predictor'.")

    # Normalize the mode once, so that e.g. 'AUTO' or 'AIS' are handled the same way as 'auto' / 'ais'.
    segmentation_mode = "auto" if segmentation_mode is None else segmentation_mode.lower()
    if segmentation_mode == "auto":
        segmentation_mode = DEFAULT_SEGMENTATION_MODE_WITH_DECODER if "decoder_state" in state else "amg"

    if segmentation_mode == "amg":
        decoder = None
    else:
        if "decoder_state" not in state:
            raise RuntimeError(
                f"You have passed 'segmentation_mode={segmentation_mode}', but your model does not contain a decoder."
            )
        decoder_state = state["decoder_state"]
        # NOTE(review): if a pre-loaded predictor is given and device is None, get_decoder chooses the device;
        # confirm it matches the device of the predictor.
        decoder = get_decoder(image_encoder=predictor.model.image_encoder, decoder_state=decoder_state, device=device)

    segmenter = get_instance_segmentation_generator(
        predictor=predictor, is_tiled=is_tiled, decoder=decoder, segmentation_mode=segmentation_mode, **kwargs
    )
    return predictor, segmenter
def automatic_tracking( predictor: mobile_sam.predictor.SamPredictor, segmenter: Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder], input_path: Union[os.PathLike, str, numpy.ndarray], output_path: Union[str, os.PathLike, NoneType] = None, embedding_path: Union[str, os.PathLike, NoneType] = None, key: Optional[str] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = True, return_embeddings: bool = False, annotate: bool = False, batch_size: int = 1, **generate_kwargs) -> Tuple[numpy.ndarray, List[Dict]]:
 89def automatic_tracking(
 90    predictor: util.SamPredictor,
 91    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
 92    input_path: Union[Union[os.PathLike, str], np.ndarray],
 93    output_path: Optional[Union[os.PathLike, str]] = None,
 94    embedding_path: Optional[Union[os.PathLike, str]] = None,
 95    key: Optional[str] = None,
 96    tile_shape: Optional[Tuple[int, int]] = None,
 97    halo: Optional[Tuple[int, int]] = None,
 98    verbose: bool = True,
 99    return_embeddings: bool = False,
100    annotate: bool = False,
101    batch_size: int = 1,
102    **generate_kwargs
103) -> Tuple[np.ndarray, List[Dict]]:
104    """Run automatic tracking for the input timeseries.
105
106    Args:
107        predictor: The Segment Anything model.
108        segmenter: The automatic instance segmentation class.
109        input_path: input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
110            or a container file (e.g. hdf5 or zarr).
111        output_path: The folder where the tracking outputs will be saved in CTC format.
112        embedding_path: The path where the embeddings are cached already / will be saved.
113        key: The key to the input file. This is needed for container files (eg. hdf5 or zarr)
114            or to load several images as 3d volume. Provide a glob patterm, eg. "*.tif", for this case.
115        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
116        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
117        verbose: Verbosity flag. By default, set to 'True'.
118        return_embeddings: Whether to return the precomputed image embeddings.
119            By default, does not return the embeddings.
120        annotate: Whether to activate the annotator for continue annotation process.
121            By default, does not activate the annotator.
122        batch_size: The batch size to compute image embeddings over tiles / z-planes.
123            By default, does it sequentially, i.e. one after the other.
124        generate_kwargs: optional keyword arguments for the generate function of the AMG, APG, or AIS class.
125
126    Returns:
127        The tracking result as a timeseries, where each object is labeled by its track id.
128        The lineages representing cell divisions, stored as a dictionary.
129    """
130    # Load the input image file.
131    # We assume that it has to be read from file if it is a str or pathlike.
132    # Otherwise we assume it is a numpy array like object.
133    image_data = util.load_image_data(input_path, key) if isinstance(input_path, (str, os.PathLike)) else input_path
134
135    if (image_data.ndim != 3) and (image_data.ndim != 4 and image_data.shape[-1] != 3):
136        raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}")
137
138    gap_closing, min_time_extent = generate_kwargs.get("gap_closing"), generate_kwargs.get("min_time_extent")
139    segmentation, lineage, image_embeddings = automatic_tracking_implementation(
140        image_data,
141        predictor,
142        segmenter,
143        embedding_path=embedding_path,
144        gap_closing=gap_closing,
145        min_time_extent=min_time_extent,
146        tile_shape=tile_shape,
147        halo=halo,
148        verbose=verbose,
149        batch_size=batch_size,
150        return_embeddings=True,
151        output_folder=output_path,
152        **generate_kwargs,
153    )
154
155    if annotate:
156        # TODO We need to support initialization of the tracking annotator with the tracking result for this.
157        raise NotImplementedError("Annotation after running the automated tracking is currently not supported.")
158
159    if return_embeddings:
160        return segmentation, lineage, image_embeddings
161    else:
162        return segmentation, lineage

Run automatic tracking for the input timeseries.

Arguments:
  • predictor: The Segment Anything model.
  • segmenter: The automatic instance segmentation class.
  • input_path: The input image file(s). Can be either a single image file (e.g. tif or png) or a container file (e.g. hdf5 or zarr).
  • output_path: The folder where the tracking outputs will be saved in CTC format.
  • embedding_path: The path where the embeddings are cached already / will be saved.
  • key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr) or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
  • tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
  • halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
  • verbose: Verbosity flag. By default, set to 'True'.
  • return_embeddings: Whether to return the precomputed image embeddings. By default, does not return the embeddings.
  • annotate: Whether to activate the annotator to continue the annotation process. This is currently not supported for tracking. By default, does not activate the annotator.
  • batch_size: The batch size to compute image embeddings over tiles / z-planes. By default, does it sequentially, i.e. one after the other.
  • generate_kwargs: Optional keyword arguments for the generate function of the AMG, APG, or AIS class.
Returns:

The tracking result as a timeseries, where each object is labeled by its track id. The lineages representing cell divisions, stored as a dictionary.

def automatic_instance_segmentation( predictor: mobile_sam.predictor.SamPredictor, segmenter: Union[micro_sam.instance_segmentation.AMGBase, micro_sam.instance_segmentation.InstanceSegmentationWithDecoder], input_path: Union[os.PathLike, str, numpy.ndarray], output_path: Union[str, os.PathLike, NoneType] = None, embedding_path: Union[str, os.PathLike, NoneType] = None, mask_path: Union[os.PathLike, str, numpy.ndarray, NoneType] = None, key: Optional[str] = None, mask_key: Optional[str] = None, ndim: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = True, return_embeddings: bool = False, annotate: bool = False, batch_size: int = 1, **generate_kwargs) -> numpy.ndarray:
def automatic_instance_segmentation(
    predictor: util.SamPredictor,
    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
    input_path: Union[Union[os.PathLike, str], np.ndarray],
    output_path: Optional[Union[os.PathLike, str]] = None,
    embedding_path: Optional[Union[os.PathLike, str]] = None,
    mask_path: Optional[Union[Union[os.PathLike, str], np.ndarray]] = None,
    key: Optional[str] = None,
    mask_key: Optional[str] = None,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    return_embeddings: bool = False,
    annotate: bool = False,
    batch_size: int = 1,
    **generate_kwargs
) -> np.ndarray:
    """Run automatic segmentation for the input image.

    Args:
        predictor: The Segment Anything model.
        segmenter: The automatic instance segmentation class.
        input_path: The input image file(s). Can be either a single image file (e.g. tif or png),
            or a container file (e.g. hdf5 or zarr). Inputs that are not a str or path-like object
            are used directly as array-like image data.
        output_path: The output path where the instance segmentations will be saved.
            Its file suffix is always replaced by '.tif'.
        embedding_path: The path where the embeddings are cached already / will be saved.
        mask_path: The path to an optional foreground mask. Areas outside of the foreground will not be processed.
            Foreground masks are currently only supported for 2d segmentation.
        key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
            or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
        mask_key: The key to the (optional) foreground mask.
        ndim: The dimensionality of the data. By default the dimensionality of the data will be used.
            If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Verbosity flag. By default, set to 'True'.
        return_embeddings: Whether to also return the precomputed image embeddings.
            By default, does not return the embeddings.
        annotate: Whether to open the annotator with the automatic segmentation to continue the annotation process.
            If 'output_path' is given, the automatic segmentation is first saved to a path with an '_automatic'
            suffix, and the segmentation after annotation is saved to 'output_path'.
            By default, does not activate the annotator.
        batch_size: The batch size to compute image embeddings over tiles / z-planes.
            By default, does it sequentially, i.e. one after the other.
        generate_kwargs: Optional keyword arguments for the generate function of the AMG, APG, or AIS class.

    Returns:
        The segmentation result. If 'return_embeddings' is set, a tuple of the segmentation result
        and the image embeddings is returned instead.
        If a segmentation is already stored at 'output_path', nothing is computed and None is returned.

    Raises:
        ValueError: If the input data does not match the shape expectation for the requested dimensionality.
        NotImplementedError: If a foreground mask is passed for 3d segmentation.
    """
    # Avoid overwriting already stored segmentations.
    if output_path is not None:
        output_path = Path(output_path).with_suffix(".tif")
        if os.path.exists(output_path):
            print(f"The segmentation results are already stored at '{os.path.abspath(output_path)}'.")
            return

    # We assume that it has to be read from file if it is a str or pathlike.
    # Otherwise we assume it is a numpy array like object.
    image_data = util.load_image_data(input_path, key) if isinstance(input_path, (str, os.PathLike)) else input_path

    # Without an explicit 'ndim' the array dimensionality is used, so RGB data needs 'ndim' to be set.
    ndim = image_data.ndim if ndim is None else ndim

    # Load the mask defining foreground if it was given.
    if mask_path is None:
        mask = None
    else:
        mask = util.load_image_data(mask_path, mask_key) if isinstance(mask_path, (str, os.PathLike)) else mask_path

    if ndim == 2:
        # NOTE(review): as written, this only rejects inputs that are neither 2d nor 3d and whose last axis
        # is not of size 3; tighten if only (Y, X) and (Y, X, 3) are meant to be accepted.
        if (image_data.ndim != 2) and (image_data.ndim != 3 and image_data.shape[-1] != 3):
            raise ValueError(f"The inputs does not match the shape expectation of 2d inputs: {image_data.shape}")

        # Precompute the image embeddings.
        image_embeddings = util.precompute_image_embeddings(
            predictor=predictor,
            input_=image_data,
            save_path=embedding_path,
            ndim=ndim,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
            batch_size=batch_size,
            mask=mask,
        )
        initialize_kwargs = dict(image=image_data, image_embeddings=image_embeddings, verbose=verbose)
        if mask is not None:
            initialize_kwargs["mask"] = mask

        # If we run AIS with tiling then we use the same tile shape for the watershed postprocessing.
        # In this case, we also add the batch size to the initialize kwargs,
        # so that the segmentation decoder can be applied in a batched fashion.
        if isinstance(segmenter, InstanceSegmentationWithDecoder) and tile_shape is not None:
            if not isinstance(segmenter, TiledAutomaticPromptGenerator):
                generate_kwargs.update({"tile_shape": tile_shape, "halo": halo})
            initialize_kwargs["batch_size"] = batch_size

        segmenter.initialize(**initialize_kwargs)
        instances = segmenter.generate(**generate_kwargs)

    else:
        # NOTE(review): same permissive check as above, for (Z, Y, X) and (Z, Y, X, 3) inputs.
        if (image_data.ndim != 3) and (image_data.ndim != 4 and image_data.shape[-1] != 3):
            raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}")
        if mask is not None:
            raise NotImplementedError("Using a foreground mask is currently only supported for 2d segmentation.")

        instances, image_embeddings = automatic_3d_segmentation(
            volume=image_data,
            predictor=predictor,
            segmentor=segmenter,
            embedding_path=embedding_path,
            tile_shape=tile_shape,
            halo=halo,
            verbose=verbose,
            return_embeddings=True,
            batch_size=batch_size,
            **generate_kwargs
        )

    # Before starting to annotate, if at all desired, store the automatic segmentations in the first stage.
    if output_path is not None:
        _output_path = _add_suffix_to_output_path(output_path, "_automatic") if annotate else output_path
        imageio.imwrite(_output_path, instances, compression="zlib")
        if verbose:
            print(f"The automatic segmentation results are stored at '{os.path.abspath(_output_path)}'.")

    # Allow opening the automatic segmentation in the annotator for further annotation, if desired.
    if annotate:
        from micro_sam.sam_annotator import annotator_2d, annotator_3d
        annotator_function = annotator_2d if ndim == 2 else annotator_3d

        viewer = annotator_function(
            image=image_data,
            model_type=predictor.model_name,
            embedding_path=image_embeddings,  # Providing the precomputed image embeddings.
            segmentation_result=instances,  # Initializes the automatic segmentation to the annotator.
            tile_shape=tile_shape,
            halo=halo,
            return_viewer=True,  # Returns the viewer, which allows the user to store the updated segmentations.
        )

        # Start the GUI here; this blocks until the viewer is closed.
        import napari
        napari.run()

        # We extract the segmentation in "committed_objects" layer, where the user either:
        # a) Performed interactive segmentation / corrections and committed them, OR
        # b) Did not do anything and closed the annotator, i.e. keeps the segmentations as it is.
        instances = viewer.layers["committed_objects"].data

        # Save the instance segmentation, if 'output_path' provided.
        if output_path is not None:
            imageio.imwrite(output_path, instances, compression="zlib")
            if verbose:
                print(f"The final segmentation results are stored at '{os.path.abspath(output_path)}'.")

    if return_embeddings:
        return instances, image_embeddings
    return instances

Run automatic segmentation for the input image.

Arguments:
  • predictor: The Segment Anything model.
  • segmenter: The automatic instance segmentation class.
  • input_path: The input image file(s). Can be either a single image file (e.g. tif or png) or a container file (e.g. hdf5 or zarr).
  • output_path: The output path where the instance segmentations will be saved.
  • embedding_path: The path where the embeddings are cached already / will be saved.
  • mask_path: The path to an optional foreground mask. Areas outside of the foreground will not be processed.
  • key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr) or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
  • mask_key: The key to the (optional) foreground mask.
  • ndim: The dimensionality of the data. By default the dimensionality of the data will be used. If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB.
  • tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
  • halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
  • verbose: Verbosity flag. By default, set to 'True'.
  • return_embeddings: Whether to return the precomputed image embeddings. By default, does not return the embeddings.
  • annotate: Whether to activate the annotator to continue the annotation process. By default, does not activate the annotator.
  • batch_size: The batch size to compute image embeddings over tiles / z-planes. By default, does it sequentially, i.e. one after the other.
  • generate_kwargs: Optional keyword arguments for the generate function of the AMG, APG, or AIS class.
Returns:

The segmentation result.