micro_sam.sam_annotator.annotator_2d

  1from typing import Optional, Tuple, Union
  2
  3import napari
  4import numpy as np
  5
  6import torch
  7
  8from .. import util
  9from . import _widgets as widgets
 10from ._state import AnnotatorState
 11from ._annotator import _AnnotatorBase
 12from .util import _initialize_parser, _sync_embedding_widget
 13
 14
 15class Annotator2d(_AnnotatorBase):
 16    def _get_widgets(self):
 17        autosegment = widgets.AutoSegmentWidget(
 18            self._viewer, with_decoder=AnnotatorState().decoder is not None, volumetric=False
 19        )
 20        return {
 21            "segment": widgets.segment(),
 22            "autosegment": autosegment,
 23            "commit": widgets.commit(),
 24            "clear": widgets.clear(),
 25        }
 26
 27    def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
 28        super().__init__(viewer=viewer, ndim=2)
 29
 30        # Set the expected annotator class to the state.
 31        state = AnnotatorState()
 32
 33        # Reset the state.
 34        if reset_state:
 35            state.reset_state()
 36
 37        state.annotator = self
 38
 39
 40def annotator_2d(
 41    image: np.ndarray,
 42    embedding_path: Optional[Union[str, util.ImageEmbeddings]] = None,
 43    segmentation_result: Optional[np.ndarray] = None,
 44    model_type: str = util._DEFAULT_MODEL,
 45    tile_shape: Optional[Tuple[int, int]] = None,
 46    halo: Optional[Tuple[int, int]] = None,
 47    return_viewer: bool = False,
 48    viewer: Optional["napari.viewer.Viewer"] = None,
 49    precompute_amg_state: bool = False,
 50    checkpoint_path: Optional[str] = None,
 51    decoder_path: Optional[str] = None,
 52    device: Optional[Union[str, torch.device]] = None,
 53    prefer_decoder: bool = True,
 54) -> Optional["napari.viewer.Viewer"]:
 55    """Start the 2d annotation tool for a given image.
 56
 57    Args:
 58        image: The image data.
 59        embedding_path: Filepath where to save the embeddings
 60            or the precompted image embeddings computed by `precompute_image_embeddings`.
 61        segmentation_result: An initial segmentation to load.
 62            This can be used to correct segmentations with Segment Anything or to save and load progress.
 63            The segmentation will be loaded as the 'committed_objects' layer.
 64        model_type: The Segment Anything model to use. For details on the available models check out
 65            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
 66        tile_shape: Shape of tiles for tiled embedding prediction.
 67            If `None` then the whole image is passed to Segment Anything.
 68        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
 69        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
 70            By default, does not return the napari viewer.
 71        viewer: The viewer to which the Segment Anything functionality should be added.
 72            This enables using a pre-initialized viewer.
 73        precompute_amg_state: Whether to precompute the state for automatic mask generation.
 74            This will take more time when precomputing embeddings, but will then make
 75            automatic mask generation much faster. By default, set to 'False'.
 76        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
 77        decoder_path: Path to a custom decoder checkpoint from which to load the 'micro-sam` decoder.
 78        device: The computational device to use for the SAM model.
 79            By default, automatically chooses the best available device.
 80        prefer_decoder: Whether to use decoder based instance segmentation if
 81            the model used has an additional decoder for instance segmentation.
 82            By default, set to 'True'.
 83
 84    Returns:
 85        The napari viewer, only returned if `return_viewer=True`.
 86    """
 87
 88    state = AnnotatorState()
 89    state.image_shape = image.shape[:-1] if image.ndim == 3 else image.shape
 90
 91    state.initialize_predictor(
 92        image, model_type=model_type, save_path=embedding_path,
 93        halo=halo, tile_shape=tile_shape, precompute_amg_state=precompute_amg_state,
 94        ndim=2, checkpoint_path=checkpoint_path, decoder_path=decoder_path,
 95        device=device, prefer_decoder=prefer_decoder,
 96        skip_load=False, use_cli=True,
 97    )
 98
 99    if viewer is None:
100        viewer = napari.Viewer()
101
102    viewer.add_image(image, name="image")
103    annotator = Annotator2d(viewer, reset_state=False)
104
105    # Trigger layer update of the annotator so that layers have the correct shape.
106    # And initialize the 'committed_objects' with the segmentation result if it was given.
107    annotator._update_image(segmentation_result=segmentation_result)
108
109    # Add the annotator widget to the viewer and sync widgets.
110    viewer.window.add_dock_widget(annotator)
111    _sync_embedding_widget(
112        widget=state.widgets["embeddings"],
113        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
114        save_path=embedding_path,
115        checkpoint_path=checkpoint_path,
116        device=device,
117        tile_shape=tile_shape,
118        halo=halo,
119    )
120
121    if return_viewer:
122        return viewer
123
124    napari.run()
125
126
def main():
    """@private"""
    # Command-line entry point: parse the arguments, load the data and start the 2d annotator.
    cli_parser = _initialize_parser(description="Run interactive segmentation for an image.")
    cli_args = cli_parser.parse_args()
    image = util.load_image_data(cli_args.input, key=cli_args.key)

    # The initial segmentation is optional; only load it when a path was given.
    segmentation_result = None
    if cli_args.segmentation_result is not None:
        segmentation_result = util.load_image_data(
            cli_args.segmentation_result, key=cli_args.segmentation_key
        )

    annotator_2d(
        image,
        embedding_path=cli_args.embedding_path,
        segmentation_result=segmentation_result,
        model_type=cli_args.model_type,
        tile_shape=cli_args.tile_shape,
        halo=cli_args.halo,
        precompute_amg_state=cli_args.precompute_amg_state,
        checkpoint_path=cli_args.checkpoint,
        decoder_path=cli_args.decoder_path,
        device=cli_args.device,
        prefer_decoder=cli_args.prefer_decoder,
    )
class Annotator2d(micro_sam.sam_annotator._annotator._AnnotatorBase):
class Annotator2d(_AnnotatorBase):
    """Annotation widget for 2d images.

    Specializes `_AnnotatorBase` with `ndim=2` and contributes the interactive
    segmentation, automatic segmentation, commit and clear widgets.
    """

    def _get_widgets(self):
        """Return the widgets of the 2d annotator, keyed by name.

        The automatic segmentation widget is set up for non-volumetric data and only
        offers decoder-based segmentation when `AnnotatorState` holds a decoder.
        """
        # The auto-segmentation widget is constructed before the other widgets;
        # this creation order is kept in case widget construction has side effects.
        decoder_available = AnnotatorState().decoder is not None
        auto_segment_widget = widgets.AutoSegmentWidget(
            self._viewer, with_decoder=decoder_available, volumetric=False
        )
        widget_by_name = {}
        widget_by_name["segment"] = widgets.segment()
        widget_by_name["autosegment"] = auto_segment_widget
        widget_by_name["commit"] = widgets.commit()
        widget_by_name["clear"] = widgets.clear()
        return widget_by_name

    def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
        """Create the 2d annotator.

        Args:
            viewer: The napari viewer the annotator is attached to.
            reset_state: Whether to reset the `AnnotatorState` before registering
                this annotator in it.
        """
        super().__init__(viewer=viewer, ndim=2)

        # NOTE(review): AnnotatorState() is presumably a singleton shared by all annotator code.
        shared_state = AnnotatorState()
        if reset_state:
            shared_state.reset_state()

        # Register this instance as the expected annotator on the state.
        shared_state.annotator = self

Base class for micro_sam annotation plugins.

Implements the logic for the 2d, 3d and tracking annotators. The annotators differ in their data dimensionality and in their widgets.

Annotator2d(viewer: napari.viewer.Viewer, reset_state: bool = True)
def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
    """Create the 2d annotator (constructor of `Annotator2d`).

    Args:
        viewer: The napari viewer the annotator is attached to.
        reset_state: Whether to reset the `AnnotatorState` before registering
            this annotator in it.
    """
    super().__init__(viewer=viewer, ndim=2)

    # NOTE(review): AnnotatorState() is presumably a singleton shared by all annotator code.
    shared_state = AnnotatorState()
    if reset_state:
        shared_state.reset_state()

    # Register this instance as the expected annotator on the state.
    shared_state.annotator = self

Create the annotator GUI.

Arguments:
  • viewer: The napari viewer.
  • ndim: The number of spatial dimensions of the image data (2 or 3).
def annotator_2d( image: numpy.ndarray, embedding_path: Union[str, Dict[str, Any], NoneType] = None, segmentation_result: Optional[numpy.ndarray] = None, model_type: str = 'vit_b_lm', tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, return_viewer: bool = False, viewer: Optional[napari.viewer.Viewer] = None, precompute_amg_state: bool = False, checkpoint_path: Optional[str] = None, decoder_path: Optional[str] = None, device: Union[str, torch.device, NoneType] = None, prefer_decoder: bool = True) -> Optional[napari.viewer.Viewer]:
 41def annotator_2d(
 42    image: np.ndarray,
 43    embedding_path: Optional[Union[str, util.ImageEmbeddings]] = None,
 44    segmentation_result: Optional[np.ndarray] = None,
 45    model_type: str = util._DEFAULT_MODEL,
 46    tile_shape: Optional[Tuple[int, int]] = None,
 47    halo: Optional[Tuple[int, int]] = None,
 48    return_viewer: bool = False,
 49    viewer: Optional["napari.viewer.Viewer"] = None,
 50    precompute_amg_state: bool = False,
 51    checkpoint_path: Optional[str] = None,
 52    decoder_path: Optional[str] = None,
 53    device: Optional[Union[str, torch.device]] = None,
 54    prefer_decoder: bool = True,
 55) -> Optional["napari.viewer.Viewer"]:
 56    """Start the 2d annotation tool for a given image.
 57
 58    Args:
 59        image: The image data.
 60        embedding_path: Filepath where to save the embeddings
 61            or the precompted image embeddings computed by `precompute_image_embeddings`.
 62        segmentation_result: An initial segmentation to load.
 63            This can be used to correct segmentations with Segment Anything or to save and load progress.
 64            The segmentation will be loaded as the 'committed_objects' layer.
 65        model_type: The Segment Anything model to use. For details on the available models check out
 66            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
 67        tile_shape: Shape of tiles for tiled embedding prediction.
 68            If `None` then the whole image is passed to Segment Anything.
 69        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
 70        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
 71            By default, does not return the napari viewer.
 72        viewer: The viewer to which the Segment Anything functionality should be added.
 73            This enables using a pre-initialized viewer.
 74        precompute_amg_state: Whether to precompute the state for automatic mask generation.
 75            This will take more time when precomputing embeddings, but will then make
 76            automatic mask generation much faster. By default, set to 'False'.
 77        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
 78        decoder_path: Path to a custom decoder checkpoint from which to load the 'micro-sam` decoder.
 79        device: The computational device to use for the SAM model.
 80            By default, automatically chooses the best available device.
 81        prefer_decoder: Whether to use decoder based instance segmentation if
 82            the model used has an additional decoder for instance segmentation.
 83            By default, set to 'True'.
 84
 85    Returns:
 86        The napari viewer, only returned if `return_viewer=True`.
 87    """
 88
 89    state = AnnotatorState()
 90    state.image_shape = image.shape[:-1] if image.ndim == 3 else image.shape
 91
 92    state.initialize_predictor(
 93        image, model_type=model_type, save_path=embedding_path,
 94        halo=halo, tile_shape=tile_shape, precompute_amg_state=precompute_amg_state,
 95        ndim=2, checkpoint_path=checkpoint_path, decoder_path=decoder_path,
 96        device=device, prefer_decoder=prefer_decoder,
 97        skip_load=False, use_cli=True,
 98    )
 99
100    if viewer is None:
101        viewer = napari.Viewer()
102
103    viewer.add_image(image, name="image")
104    annotator = Annotator2d(viewer, reset_state=False)
105
106    # Trigger layer update of the annotator so that layers have the correct shape.
107    # And initialize the 'committed_objects' with the segmentation result if it was given.
108    annotator._update_image(segmentation_result=segmentation_result)
109
110    # Add the annotator widget to the viewer and sync widgets.
111    viewer.window.add_dock_widget(annotator)
112    _sync_embedding_widget(
113        widget=state.widgets["embeddings"],
114        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
115        save_path=embedding_path,
116        checkpoint_path=checkpoint_path,
117        device=device,
118        tile_shape=tile_shape,
119        halo=halo,
120    )
121
122    if return_viewer:
123        return viewer
124
125    napari.run()

Start the 2d annotation tool for a given image.

Arguments:
  • image: The image data.
  • embedding_path: Filepath where to save the embeddings or the precomputed image embeddings computed by precompute_image_embeddings.
  • segmentation_result: An initial segmentation to load. This can be used to correct segmentations with Segment Anything or to save and load progress. The segmentation will be loaded as the 'committed_objects' layer.
  • model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
  • tile_shape: Shape of tiles for tiled embedding prediction. If None then the whole image is passed to Segment Anything.
  • halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
  • return_viewer: Whether to return the napari viewer to further modify it before starting the tool. By default, does not return the napari viewer.
  • viewer: The viewer to which the Segment Anything functionality should be added. This enables using a pre-initialized viewer.
  • precompute_amg_state: Whether to precompute the state for automatic mask generation. This will take more time when precomputing embeddings, but will then make automatic mask generation much faster. By default, set to 'False'.
  • checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
  • decoder_path: Path to a custom decoder checkpoint from which to load the 'micro-sam' decoder.
  • device: The computational device to use for the SAM model. By default, automatically chooses the best available device.
  • prefer_decoder: Whether to use decoder based instance segmentation if the model used has an additional decoder for instance segmentation. By default, set to 'True'.
Returns:

The napari viewer, only returned if return_viewer=True.