micro_sam.sam_annotator.annotator_3d
"""Napari-based annotation tool for volumetric (3d) image data."""

from typing import Optional, Tuple, Union

import napari
import numpy as np

import torch

from .. import util
from . import _widgets as widgets
from ._state import AnnotatorState
from ._annotator import _AnnotatorBase
from .util import _initialize_parser, _sync_embedding_widget, _load_amg_state, _load_is_state


class Annotator3d(_AnnotatorBase):
    """Annotator for volumetric data.

    Configures `_AnnotatorBase` with `ndim=3` and supplies the widget set
    used for 3d annotation.
    """

    def _get_widgets(self):
        """Build the widgets of the 3d annotator, keyed by their role."""
        # The two viewer-bound widgets are created first, before the
        # function-style widgets, to keep the construction order stable.
        auto_widget = widgets.AutoSegmentWidget(
            self._viewer, with_decoder=self._with_decoder, volumetric=True
        )
        nd_widget = widgets.SegmentNDWidget(self._viewer, tracking=False)
        return dict(
            segment=widgets.segment_slice(),
            segment_nd=nd_widget,
            autosegment=auto_widget,
            commit=widgets.commit(),
            clear=widgets.clear_volume(),
        )

    def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
        """Create the 3d annotator GUI.

        Args:
            viewer: The napari viewer.
            reset_state: Whether to call `reset_state()` on the annotator state
                after the base class has been initialized.
        """
        # Must be set before the base initializer runs: `_get_widgets` reads it.
        # NOTE(review): presumably the base initializer calls `_get_widgets` - confirm.
        decoder = AnnotatorState().decoder
        self._with_decoder = decoder is not None
        super().__init__(viewer=viewer, ndim=3)

        annotator_state = AnnotatorState()
        if reset_state:
            annotator_state.reset_state()

        # Register this instance as the annotator of the state.
        annotator_state.annotator = self

    def _update_image(self, segmentation_result=None):
        """Update the image layers and reload the automatic segmentation state.

        Args:
            segmentation_result: Optional segmentation forwarded to the base class update.
        """
        super()._update_image(segmentation_result=segmentation_result)

        # Pick the loader that matches the segmentation mode; both read from
        # the embedding path and the result is stored in `amg_state`.
        annotator_state = AnnotatorState()
        load_state = _load_is_state if self._with_decoder else _load_amg_state
        annotator_state.amg_state = load_state(annotator_state.embedding_path)


def annotator_3d(
    image: np.ndarray,
    embedding_path: Optional[Union[str, util.ImageEmbeddings]] = None,
    segmentation_result: Optional[np.ndarray] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    return_viewer: bool = False,
    viewer: Optional["napari.viewer.Viewer"] = None,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    decoder_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
) -> Optional["napari.viewer.Viewer"]:
    """Start the 3d annotation tool for a given image volume.

    Args:
        image: The volumetric image data.
        embedding_path: Filepath where to save the embeddings
            or the precomputed image embeddings computed by `precompute_image_embeddings`.
        segmentation_result: An initial segmentation to load.
            This can be used to correct segmentations with Segment Anything or to save and load progress.
            The segmentation will be loaded as the 'committed_objects' layer.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
            By default, does not return the napari viewer.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster. By default, set to 'False'.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        decoder_path: Path to a custom decoder checkpoint from which to load the 'micro-sam' decoder.
        device: The computational device to use for the SAM model.
            By default, automatically chooses the best available device.
        prefer_decoder: Whether to use decoder based instance segmentation if
            the model used has an additional decoder for instance segmentation.
            By default, set to 'True'.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
    """
    state = AnnotatorState()

    # For 4d input the trailing axis is not part of the stored image shape.
    # NOTE(review): presumably that axis holds the channels - confirm.
    if image.ndim == 4:
        state.image_shape = image.shape[:-1]
    else:
        state.image_shape = image.shape

    # Set up the predictor (and embeddings) in the shared state.
    state.initialize_predictor(
        image,
        model_type=model_type,
        save_path=embedding_path,
        halo=halo,
        tile_shape=tile_shape,
        ndim=3,
        precompute_amg_state=precompute_amg_state,
        checkpoint_path=checkpoint_path,
        decoder_path=decoder_path,
        device=device,
        prefer_decoder=prefer_decoder,
        use_cli=True,
    )

    if viewer is None:
        viewer = napari.Viewer()
    viewer.add_image(image, name="image")

    # The predictor is already initialized above, so the state must not be reset here.
    annotator = Annotator3d(viewer, reset_state=False)

    # Update the layers so that they have the correct shape, and initialize
    # 'committed_objects' with the segmentation result if one was given.
    annotator._update_image(segmentation_result=segmentation_result)

    # Dock the annotator and mirror the settings in the embedding widget.
    viewer.window.add_dock_widget(annotator)
    embedding_widget = state.widgets["embeddings"]
    if checkpoint_path is None:
        widget_model_type = model_type
    else:
        # With a custom checkpoint the model type is taken from the predictor.
        widget_model_type = state.predictor.model_type
    _sync_embedding_widget(
        widget=embedding_widget,
        model_type=widget_model_type,
        save_path=embedding_path,
        checkpoint_path=checkpoint_path,
        device=device,
        tile_shape=tile_shape,
        halo=halo,
    )

    if not return_viewer:
        napari.run()
        return None
    return viewer


def main():
    """@private"""
    # Command line entry point: parse the arguments, load the data and start the tool.
    parser = _initialize_parser(description="Run interactive segmentation for an image volume.")
    args = parser.parse_args()
    image = util.load_image_data(args.input, key=args.key)

    # An initial segmentation is only loaded when one was passed.
    segmentation_result = None
    if args.segmentation_result is not None:
        segmentation_result = util.load_image_data(args.segmentation_result, key=args.segmentation_key)

    annotator_3d(
        image,
        embedding_path=args.embedding_path,
        segmentation_result=segmentation_result,
        model_type=args.model_type,
        tile_shape=args.tile_shape,
        halo=args.halo,
        checkpoint_path=args.checkpoint,
        device=args.device,
        precompute_amg_state=args.precompute_amg_state,
        prefer_decoder=args.prefer_decoder,
        decoder_path=args.decoder_path,
    )
class Annotator3d(_AnnotatorBase):
    """Annotator for volumetric data.

    Configures `_AnnotatorBase` with `ndim=3` and supplies the widget set
    used for 3d annotation.
    """

    def _get_widgets(self):
        """Build the widgets of the 3d annotator, keyed by their role."""
        # The two viewer-bound widgets are created first, before the
        # function-style widgets, to keep the construction order stable.
        auto_widget = widgets.AutoSegmentWidget(
            self._viewer, with_decoder=self._with_decoder, volumetric=True
        )
        nd_widget = widgets.SegmentNDWidget(self._viewer, tracking=False)
        return dict(
            segment=widgets.segment_slice(),
            segment_nd=nd_widget,
            autosegment=auto_widget,
            commit=widgets.commit(),
            clear=widgets.clear_volume(),
        )

    def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
        """Create the 3d annotator GUI.

        Args:
            viewer: The napari viewer.
            reset_state: Whether to call `reset_state()` on the annotator state
                after the base class has been initialized.
        """
        # Must be set before the base initializer runs: `_get_widgets` reads it.
        # NOTE(review): presumably the base initializer calls `_get_widgets` - confirm.
        decoder = AnnotatorState().decoder
        self._with_decoder = decoder is not None
        super().__init__(viewer=viewer, ndim=3)

        annotator_state = AnnotatorState()
        if reset_state:
            annotator_state.reset_state()

        # Register this instance as the annotator of the state.
        annotator_state.annotator = self

    def _update_image(self, segmentation_result=None):
        """Update the image layers and reload the automatic segmentation state.

        Args:
            segmentation_result: Optional segmentation forwarded to the base class update.
        """
        super()._update_image(segmentation_result=segmentation_result)

        # Pick the loader that matches the segmentation mode; both read from
        # the embedding path and the result is stored in `amg_state`.
        annotator_state = AnnotatorState()
        load_state = _load_is_state if self._with_decoder else _load_amg_state
        annotator_state.amg_state = load_state(annotator_state.embedding_path)
Base class for micro_sam annotation plugins.
Implements the logic for the 2d, 3d and tracking annotator. The annotators differ in their data dimensionality and the widgets.
Annotator3d(viewer: napari.viewer.Viewer, reset_state: bool = True)
def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
    """Create the 3d annotator GUI.

    Args:
        viewer: The napari viewer.
        reset_state: Whether to call `reset_state()` on the annotator state
            after the base class has been initialized.
    """
    # Must be set before the base initializer runs: `_get_widgets` reads it.
    # NOTE(review): presumably the base initializer calls `_get_widgets` - confirm.
    decoder = AnnotatorState().decoder
    self._with_decoder = decoder is not None
    super().__init__(viewer=viewer, ndim=3)

    annotator_state = AnnotatorState()
    if reset_state:
        annotator_state.reset_state()

    # Register this instance as the annotator of the state.
    annotator_state.annotator = self
Create the annotator GUI.
Arguments:
- viewer: The napari viewer.
- ndim: The number of spatial dimensions of the image data (2 or 3).
def
annotator_3d( image: numpy.ndarray, embedding_path: Union[str, Dict[str, Any], NoneType] = None, segmentation_result: Optional[numpy.ndarray] = None, model_type: str = 'vit_b_lm', tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, return_viewer: bool = False, viewer: Optional[napari.viewer.Viewer] = None, precompute_amg_state: bool = False, checkpoint_path: Optional[str] = None, decoder_path: Optional[str] = None, device: Union[str, torch.device, NoneType] = None, prefer_decoder: bool = True) -> Optional[napari.viewer.Viewer]:
def annotator_3d(
    image: np.ndarray,
    embedding_path: Optional[Union[str, util.ImageEmbeddings]] = None,
    segmentation_result: Optional[np.ndarray] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    return_viewer: bool = False,
    viewer: Optional["napari.viewer.Viewer"] = None,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    decoder_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
) -> Optional["napari.viewer.Viewer"]:
    """Start the 3d annotation tool for a given image volume.

    Args:
        image: The volumetric image data.
        embedding_path: Filepath where to save the embeddings
            or the precomputed image embeddings computed by `precompute_image_embeddings`.
        segmentation_result: An initial segmentation to load.
            This can be used to correct segmentations with Segment Anything or to save and load progress.
            The segmentation will be loaded as the 'committed_objects' layer.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
            By default, does not return the napari viewer.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster. By default, set to 'False'.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        decoder_path: Path to a custom decoder checkpoint from which to load the 'micro-sam' decoder.
        device: The computational device to use for the SAM model.
            By default, automatically chooses the best available device.
        prefer_decoder: Whether to use decoder based instance segmentation if
            the model used has an additional decoder for instance segmentation.
            By default, set to 'True'.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
    """
    state = AnnotatorState()

    # For 4d input the trailing axis is not part of the stored image shape.
    # NOTE(review): presumably that axis holds the channels - confirm.
    if image.ndim == 4:
        state.image_shape = image.shape[:-1]
    else:
        state.image_shape = image.shape

    # Set up the predictor (and embeddings) in the shared state.
    state.initialize_predictor(
        image,
        model_type=model_type,
        save_path=embedding_path,
        halo=halo,
        tile_shape=tile_shape,
        ndim=3,
        precompute_amg_state=precompute_amg_state,
        checkpoint_path=checkpoint_path,
        decoder_path=decoder_path,
        device=device,
        prefer_decoder=prefer_decoder,
        use_cli=True,
    )

    if viewer is None:
        viewer = napari.Viewer()
    viewer.add_image(image, name="image")

    # The predictor is already initialized above, so the state must not be reset here.
    annotator = Annotator3d(viewer, reset_state=False)

    # Update the layers so that they have the correct shape, and initialize
    # 'committed_objects' with the segmentation result if one was given.
    annotator._update_image(segmentation_result=segmentation_result)

    # Dock the annotator and mirror the settings in the embedding widget.
    viewer.window.add_dock_widget(annotator)
    embedding_widget = state.widgets["embeddings"]
    if checkpoint_path is None:
        widget_model_type = model_type
    else:
        # With a custom checkpoint the model type is taken from the predictor.
        widget_model_type = state.predictor.model_type
    _sync_embedding_widget(
        widget=embedding_widget,
        model_type=widget_model_type,
        save_path=embedding_path,
        checkpoint_path=checkpoint_path,
        device=device,
        tile_shape=tile_shape,
        halo=halo,
    )

    if not return_viewer:
        napari.run()
        return None
    return viewer
Start the 3d annotation tool for a given image volume.
Arguments:
- image: The volumetric image data.
- embedding_path: Filepath where to save the embeddings
or the precomputed image embeddings computed by
precompute_image_embeddings. - segmentation_result: An initial segmentation to load. This can be used to correct segmentations with Segment Anything or to save and load progress. The segmentation will be loaded as the 'committed_objects' layer.
- model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
- tile_shape: Shape of tiles for tiled embedding prediction.
If
`None` then the whole image is passed to Segment Anything. - halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
- return_viewer: Whether to return the napari viewer to further modify it before starting the tool. By default, does not return the napari viewer.
- viewer: The viewer to which the Segment Anything functionality should be added. This enables using a pre-initialized viewer.
- precompute_amg_state: Whether to precompute the state for automatic mask generation. This will take more time when precomputing embeddings, but will then make automatic mask generation much faster. By default, set to 'False'.
- checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
- decoder_path: Path to a custom decoder checkpoint from which to load the 'micro-sam' decoder.
- device: The computational device to use for the SAM model. By default, automatically chooses the best available device.
- prefer_decoder: Whether to use decoder based instance segmentation if the model used has an additional decoder for instance segmentation. By default, set to 'True'.
Returns:
The napari viewer, only returned if
return_viewer=True.