micro_sam.sam_annotator.annotator_3d

  1from typing import Optional, Tuple, Union
  2
  3import napari
  4import numpy as np
  5
  6import torch
  7
  8from .. import util
  9from . import _widgets as widgets
 10from ._state import AnnotatorState
 11from ._annotator import _AnnotatorBase
 12from .util import _initialize_parser, _sync_embedding_widget, _load_amg_state, _load_is_state
 13
 14
 15class Annotator3d(_AnnotatorBase):
 16    def _get_widgets(self):
 17        autosegment = widgets.AutoSegmentWidget(self._viewer, with_decoder=self._with_decoder, volumetric=True)
 18        segment_nd = widgets.SegmentNDWidget(self._viewer, tracking=False)
 19        return {
 20            "segment": widgets.segment_slice(),
 21            "segment_nd": segment_nd,
 22            "autosegment": autosegment,
 23            "commit": widgets.commit(),
 24            "clear": widgets.clear_volume(),
 25        }
 26
 27    def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
 28        self._with_decoder = AnnotatorState().decoder is not None
 29        super().__init__(viewer=viewer, ndim=3)
 30
 31        # Set the expected annotator class to the state.
 32        state = AnnotatorState()
 33
 34        # Reset the state.
 35        if reset_state:
 36            state.reset_state()
 37
 38        state.annotator = self
 39
 40    def _update_image(self, segmentation_result=None):
 41        super()._update_image(segmentation_result=segmentation_result)
 42        # Load the amg state from the embedding path.
 43        state = AnnotatorState()
 44        if self._with_decoder:
 45            state.amg_state = _load_is_state(state.embedding_path)
 46        else:
 47            state.amg_state = _load_amg_state(state.embedding_path)
 48
 49
 50def annotator_3d(
 51    image: np.ndarray,
 52    embedding_path: Optional[Union[str, util.ImageEmbeddings]] = None,
 53    segmentation_result: Optional[np.ndarray] = None,
 54    model_type: str = util._DEFAULT_MODEL,
 55    tile_shape: Optional[Tuple[int, int]] = None,
 56    halo: Optional[Tuple[int, int]] = None,
 57    return_viewer: bool = False,
 58    viewer: Optional["napari.viewer.Viewer"] = None,
 59    precompute_amg_state: bool = False,
 60    checkpoint_path: Optional[str] = None,
 61    decoder_path: Optional[str] = None,
 62    device: Optional[Union[str, torch.device]] = None,
 63    prefer_decoder: bool = True,
 64) -> Optional["napari.viewer.Viewer"]:
 65    """Start the 3d annotation tool for a given image volume.
 66
 67    Args:
 68        image: The volumetric image data.
 69        embedding_path: Filepath where to save the embeddings
 70            or the precompted image embeddings computed by `precompute_image_embeddings`.
 71        segmentation_result: An initial segmentation to load.
 72            This can be used to correct segmentations with Segment Anything or to save and load progress.
 73            The segmentation will be loaded as the 'committed_objects' layer.
 74        model_type: The Segment Anything model to use. For details on the available models check out
 75            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
 76        tile_shape: Shape of tiles for tiled embedding prediction.
 77            If `None` then the whole image is passed to Segment Anything.
 78        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
 79        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
 80            By default, does not return the napari viewer.
 81        viewer: The viewer to which the Segment Anything functionality should be added.
 82            This enables using a pre-initialized viewer.
 83        precompute_amg_state: Whether to precompute the state for automatic mask generation.
 84            This will take more time when precomputing embeddings, but will then make
 85            automatic mask generation much faster. By default, set to 'False'.
 86        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
 87        decoder_path: Path to a custom decoder checkpoint from which to load the 'micro-sam` decoder.
 88        device: The computational device to use for the SAM model.
 89            By default, automatically chooses the best available device.
 90        prefer_decoder: Whether to use decoder based instance segmentation if
 91            the model used has an additional decoder for instance segmentation.
 92            By default, set to 'True'.
 93
 94    Returns:
 95        The napari viewer, only returned if `return_viewer=True`.
 96    """
 97
 98    # Initialize the predictor state.
 99    state = AnnotatorState()
100    state.image_shape = image.shape[:-1] if image.ndim == 4 else image.shape
101    state.initialize_predictor(
102        image, model_type=model_type, save_path=embedding_path,
103        halo=halo, tile_shape=tile_shape, ndim=3, precompute_amg_state=precompute_amg_state,
104        checkpoint_path=checkpoint_path, decoder_path=decoder_path,
105        device=device, prefer_decoder=prefer_decoder,
106        use_cli=True,
107    )
108
109    if viewer is None:
110        viewer = napari.Viewer()
111
112    viewer.add_image(image, name="image")
113    annotator = Annotator3d(viewer, reset_state=False)
114
115    # Trigger layer update of the annotator so that layers have the correct shape.
116    # And initialize the 'committed_objects' with the segmentation result if it was given.
117    annotator._update_image(segmentation_result=segmentation_result)
118
119    # Add the annotator widget to the viewer and sync widgets.
120    viewer.window.add_dock_widget(annotator)
121    _sync_embedding_widget(
122        widget=state.widgets["embeddings"],
123        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
124        save_path=embedding_path,
125        checkpoint_path=checkpoint_path,
126        device=device,
127        tile_shape=tile_shape,
128        halo=halo
129    )
130
131    if return_viewer:
132        return viewer
133
134    napari.run()
135
136
def main():
    """@private"""
    # Command line entry point: parse the arguments, load the data and start the 3d annotator.
    parser = _initialize_parser(description="Run interactive segmentation for an image volume.")
    args = parser.parse_args()
    image = util.load_image_data(args.input, key=args.key)

    # An initial segmentation is optional; only load it when a path was given.
    segmentation_result = (
        None
        if args.segmentation_result is None
        else util.load_image_data(args.segmentation_result, key=args.segmentation_key)
    )

    annotator_3d(
        image,
        embedding_path=args.embedding_path,
        segmentation_result=segmentation_result,
        model_type=args.model_type,
        tile_shape=args.tile_shape,
        halo=args.halo,
        checkpoint_path=args.checkpoint,
        device=args.device,
        precompute_amg_state=args.precompute_amg_state,
        prefer_decoder=args.prefer_decoder,
        decoder_path=args.decoder_path,
    )
class Annotator3d(micro_sam.sam_annotator._annotator._AnnotatorBase):
class Annotator3d(_AnnotatorBase):
    """Annotator for volumetric (3d) image data.

    Provides the `segment`, `segment_nd`, `autosegment`, `commit` and `clear` widgets
    and registers itself as the active annotator in `AnnotatorState`.
    """

    def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
        """Create the 3d annotator.

        Args:
            viewer: The napari viewer.
            reset_state: Whether to reset the shared `AnnotatorState` after the base class is set up.
        """
        # The decoder flag has to be known before the base class is initialized, because
        # `_get_widgets` reads it (presumably invoked from the base initializer -- keep this order).
        self._with_decoder = AnnotatorState().decoder is not None
        super().__init__(viewer=viewer, ndim=3)

        annotator_state = AnnotatorState()
        if reset_state:
            annotator_state.reset_state()

        # Register this instance as the annotator the state refers to.
        annotator_state.annotator = self

    def _get_widgets(self):
        """Create the widgets used by the 3d annotator.

        Returns:
            Dict mapping widget names to the created widget objects.
        """
        # The viewer-bound widgets are constructed first; the remaining ones are created
        # while filling the dict, so the construction order matches the key order below.
        auto_segment_widget = widgets.AutoSegmentWidget(
            self._viewer, with_decoder=self._with_decoder, volumetric=True
        )
        nd_segment_widget = widgets.SegmentNDWidget(self._viewer, tracking=False)

        widget_map = {}
        widget_map["segment"] = widgets.segment_slice()
        widget_map["segment_nd"] = nd_segment_widget
        widget_map["autosegment"] = auto_segment_widget
        widget_map["commit"] = widgets.commit()
        widget_map["clear"] = widgets.clear_volume()
        return widget_map

    def _update_image(self, segmentation_result=None):
        """Update the annotator for the current image and reload the automatic segmentation state.

        Args:
            segmentation_result: Optional initial segmentation, forwarded to the base class.
        """
        super()._update_image(segmentation_result=segmentation_result)

        # Pick the loader that matches the segmentation mode: the instance-segmentation
        # state when a decoder is available, the AMG state otherwise.
        annotator_state = AnnotatorState()
        load_state = _load_is_state if self._with_decoder else _load_amg_state
        annotator_state.amg_state = load_state(annotator_state.embedding_path)

Base class for micro_sam annotation plugins.

Implements the logic for the 2d, 3d and tracking annotators. The annotators differ in their data dimensionality and in their widgets.

Annotator3d(viewer: napari.viewer.Viewer, reset_state: bool = True)
28    def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
29        self._with_decoder = AnnotatorState().decoder is not None
30        super().__init__(viewer=viewer, ndim=3)
31
32        # Set the expected annotator class to the state.
33        state = AnnotatorState()
34
35        # Reset the state.
36        if reset_state:
37            state.reset_state()
38
39        state.annotator = self

Create the annotator GUI.

Arguments:
  • viewer: The napari viewer.
  • ndim: The number of spatial dimensions of the image data (2 or 3).
def annotator_3d( image: numpy.ndarray, embedding_path: Union[str, Dict[str, Any], NoneType] = None, segmentation_result: Optional[numpy.ndarray] = None, model_type: str = 'vit_b_lm', tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, return_viewer: bool = False, viewer: Optional[napari.viewer.Viewer] = None, precompute_amg_state: bool = False, checkpoint_path: Optional[str] = None, decoder_path: Optional[str] = None, device: Union[str, torch.device, NoneType] = None, prefer_decoder: bool = True) -> Optional[napari.viewer.Viewer]:
 51def annotator_3d(
 52    image: np.ndarray,
 53    embedding_path: Optional[Union[str, util.ImageEmbeddings]] = None,
 54    segmentation_result: Optional[np.ndarray] = None,
 55    model_type: str = util._DEFAULT_MODEL,
 56    tile_shape: Optional[Tuple[int, int]] = None,
 57    halo: Optional[Tuple[int, int]] = None,
 58    return_viewer: bool = False,
 59    viewer: Optional["napari.viewer.Viewer"] = None,
 60    precompute_amg_state: bool = False,
 61    checkpoint_path: Optional[str] = None,
 62    decoder_path: Optional[str] = None,
 63    device: Optional[Union[str, torch.device]] = None,
 64    prefer_decoder: bool = True,
 65) -> Optional["napari.viewer.Viewer"]:
 66    """Start the 3d annotation tool for a given image volume.
 67
 68    Args:
 69        image: The volumetric image data.
 70        embedding_path: Filepath where to save the embeddings
 71            or the precompted image embeddings computed by `precompute_image_embeddings`.
 72        segmentation_result: An initial segmentation to load.
 73            This can be used to correct segmentations with Segment Anything or to save and load progress.
 74            The segmentation will be loaded as the 'committed_objects' layer.
 75        model_type: The Segment Anything model to use. For details on the available models check out
 76            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
 77        tile_shape: Shape of tiles for tiled embedding prediction.
 78            If `None` then the whole image is passed to Segment Anything.
 79        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
 80        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
 81            By default, does not return the napari viewer.
 82        viewer: The viewer to which the Segment Anything functionality should be added.
 83            This enables using a pre-initialized viewer.
 84        precompute_amg_state: Whether to precompute the state for automatic mask generation.
 85            This will take more time when precomputing embeddings, but will then make
 86            automatic mask generation much faster. By default, set to 'False'.
 87        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
 88        decoder_path: Path to a custom decoder checkpoint from which to load the 'micro-sam` decoder.
 89        device: The computational device to use for the SAM model.
 90            By default, automatically chooses the best available device.
 91        prefer_decoder: Whether to use decoder based instance segmentation if
 92            the model used has an additional decoder for instance segmentation.
 93            By default, set to 'True'.
 94
 95    Returns:
 96        The napari viewer, only returned if `return_viewer=True`.
 97    """
 98
 99    # Initialize the predictor state.
100    state = AnnotatorState()
101    state.image_shape = image.shape[:-1] if image.ndim == 4 else image.shape
102    state.initialize_predictor(
103        image, model_type=model_type, save_path=embedding_path,
104        halo=halo, tile_shape=tile_shape, ndim=3, precompute_amg_state=precompute_amg_state,
105        checkpoint_path=checkpoint_path, decoder_path=decoder_path,
106        device=device, prefer_decoder=prefer_decoder,
107        use_cli=True,
108    )
109
110    if viewer is None:
111        viewer = napari.Viewer()
112
113    viewer.add_image(image, name="image")
114    annotator = Annotator3d(viewer, reset_state=False)
115
116    # Trigger layer update of the annotator so that layers have the correct shape.
117    # And initialize the 'committed_objects' with the segmentation result if it was given.
118    annotator._update_image(segmentation_result=segmentation_result)
119
120    # Add the annotator widget to the viewer and sync widgets.
121    viewer.window.add_dock_widget(annotator)
122    _sync_embedding_widget(
123        widget=state.widgets["embeddings"],
124        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
125        save_path=embedding_path,
126        checkpoint_path=checkpoint_path,
127        device=device,
128        tile_shape=tile_shape,
129        halo=halo
130    )
131
132    if return_viewer:
133        return viewer
134
135    napari.run()

Start the 3d annotation tool for a given image volume.

Arguments:
  • image: The volumetric image data.
  • embedding_path: Filepath where to save the embeddings or the precomputed image embeddings computed by precompute_image_embeddings.
  • segmentation_result: An initial segmentation to load. This can be used to correct segmentations with Segment Anything or to save and load progress. The segmentation will be loaded as the 'committed_objects' layer.
  • model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
  • tile_shape: Shape of tiles for tiled embedding prediction. If None then the whole image is passed to Segment Anything.
  • halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
  • return_viewer: Whether to return the napari viewer to further modify it before starting the tool. By default, does not return the napari viewer.
  • viewer: The viewer to which the Segment Anything functionality should be added. This enables using a pre-initialized viewer.
  • precompute_amg_state: Whether to precompute the state for automatic mask generation. This will take more time when precomputing embeddings, but will then make automatic mask generation much faster. By default, set to 'False'.
  • checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
  • decoder_path: Path to a custom decoder checkpoint from which to load the 'micro-sam' decoder.
  • device: The computational device to use for the SAM model. By default, automatically chooses the best available device.
  • prefer_decoder: Whether to use decoder based instance segmentation if the model used has an additional decoder for instance segmentation. By default, set to 'True'.
Returns:

The napari viewer, only returned if return_viewer=True.