micro_sam.sam_annotator.annotator_2d
from typing import Optional, Tuple, Union

import napari
import numpy as np

import torch

from .. import util
from . import _widgets as widgets
from ._state import AnnotatorState
from ._annotator import _AnnotatorBase
from .util import _initialize_parser, _sync_embedding_widget


class Annotator2d(_AnnotatorBase):
    """Annotator for 2d images.

    Specializes `_AnnotatorBase` for 2d data: it passes `ndim=2` to the base constructor and
    provides the 2d widget set via `_get_widgets` (presumably queried by the base class when
    it builds the GUI).
    """

    def _get_widgets(self):
        """Create the widgets of the 2d annotator.

        Returns:
            Dictionary mapping the widget names "segment", "autosegment", "commit" and "clear"
            to the corresponding widget instances.
        """
        # `with_decoder` is only enabled when the annotator state holds a decoder;
        # `volumetric=False` because this annotator works on 2d data.
        autosegment = widgets.AutoSegmentWidget(
            self._viewer, with_decoder=AnnotatorState().decoder is not None, volumetric=False
        )
        return {
            "segment": widgets.segment(),
            "autosegment": autosegment,
            "commit": widgets.commit(),
            "clear": widgets.clear(),
        }

    def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
        """Create the 2d annotator.

        Args:
            viewer: The napari viewer.
            reset_state: Whether to reset the annotator state after the base class has been
                initialized. `annotator_2d` passes `False`, presumably to keep the predictor
                it has initialized beforehand.
        """
        super().__init__(viewer=viewer, ndim=2)

        # Access the annotator state.
        state = AnnotatorState()

        # Reset the state.
        if reset_state:
            state.reset_state()

        # Register this instance as the active annotator in the state.
        state.annotator = self


def annotator_2d(
    image: np.ndarray,
    embedding_path: Optional[Union[str, util.ImageEmbeddings]] = None,
    segmentation_result: Optional[np.ndarray] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    return_viewer: bool = False,
    viewer: Optional["napari.viewer.Viewer"] = None,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    decoder_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
) -> Optional["napari.viewer.Viewer"]:
    """Start the 2d annotation tool for a given image.

    Args:
        image: The image data. For a 3d array, the last axis (e.g. an RGB channel axis)
            is excluded from the image shape used for annotation.
        embedding_path: Filepath where to save the embeddings
            or the precomputed image embeddings computed by `precompute_image_embeddings`.
        segmentation_result: An initial segmentation to load.
            This can be used to correct segmentations with Segment Anything or to save and load progress.
            The segmentation will be loaded as the 'committed_objects' layer.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None`, then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
            By default, does not return the napari viewer.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster. By default, set to 'False'.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        decoder_path: Path to a custom decoder checkpoint from which to load the `micro-sam` decoder.
        device: The computational device to use for the SAM model.
            By default, automatically chooses the best available device.
        prefer_decoder: Whether to use decoder based instance segmentation if
            the model used has an additional decoder for instance segmentation.
            By default, set to 'True'.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
        Otherwise, `napari.run()` is called and `None` is returned.
    """

    state = AnnotatorState()
    # For a 3d array, the last axis is not part of the 2d image shape.
    state.image_shape = image.shape[:-1] if image.ndim == 3 else image.shape

    state.initialize_predictor(
        image, model_type=model_type, save_path=embedding_path,
        halo=halo, tile_shape=tile_shape, precompute_amg_state=precompute_amg_state,
        ndim=2, checkpoint_path=checkpoint_path, decoder_path=decoder_path,
        device=device, prefer_decoder=prefer_decoder,
        skip_load=False, use_cli=True,
    )

    if viewer is None:
        viewer = napari.Viewer()

    viewer.add_image(image, name="image")
    # The state was configured above, so it must not be reset by the annotator.
    annotator = Annotator2d(viewer, reset_state=False)

    # Trigger layer update of the annotator so that layers have the correct shape.
    # And initialize the 'committed_objects' with the segmentation result if it was given.
    annotator._update_image(segmentation_result=segmentation_result)

    # Add the annotator widget to the viewer and sync widgets.
    viewer.window.add_dock_widget(annotator)
    _sync_embedding_widget(
        widget=state.widgets["embeddings"],
        # With a custom checkpoint, report the model type of the loaded predictor.
        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
        save_path=embedding_path,
        checkpoint_path=checkpoint_path,
        device=device,
        tile_shape=tile_shape,
        halo=halo,
    )

    if return_viewer:
        return viewer

    napari.run()


def main():
    """@private"""
    # Command line entry point: parse the arguments, load the data and start the 2d annotator.
    parser = _initialize_parser(description="Run interactive segmentation for an image.")
    args = parser.parse_args()
    image = util.load_image_data(args.input, key=args.key)

    # The initial segmentation is optional; only load it when a path was given.
    segmentation_result = (
        None if args.segmentation_result is None
        else util.load_image_data(args.segmentation_result, key=args.segmentation_key)
    )

    annotator_2d(
        image, embedding_path=args.embedding_path,
        segmentation_result=segmentation_result,
        model_type=args.model_type, tile_shape=args.tile_shape, halo=args.halo,
        precompute_amg_state=args.precompute_amg_state, checkpoint_path=args.checkpoint,
        decoder_path=args.decoder_path, device=args.device, prefer_decoder=args.prefer_decoder,
    )
class Annotator2d(_AnnotatorBase):
    """Annotator for 2d images.

    Specializes `_AnnotatorBase` for 2d data: it passes `ndim=2` to the base constructor and
    provides the 2d widget set via `_get_widgets` (presumably queried by the base class when
    it builds the GUI).
    """

    def _get_widgets(self):
        """Create the widgets of the 2d annotator.

        Returns:
            Dictionary mapping the widget names "segment", "autosegment", "commit" and "clear"
            to the corresponding widget instances.
        """
        # `with_decoder` is only enabled when the annotator state holds a decoder;
        # `volumetric=False` because this annotator works on 2d data.
        autosegment = widgets.AutoSegmentWidget(
            self._viewer, with_decoder=AnnotatorState().decoder is not None, volumetric=False
        )
        return {
            "segment": widgets.segment(),
            "autosegment": autosegment,
            "commit": widgets.commit(),
            "clear": widgets.clear(),
        }

    def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
        """Create the 2d annotator.

        Args:
            viewer: The napari viewer.
            reset_state: Whether to reset the annotator state after the base class has been
                initialized. `annotator_2d` passes `False`, presumably to keep the predictor
                it has initialized beforehand.
        """
        super().__init__(viewer=viewer, ndim=2)

        # Access the annotator state.
        state = AnnotatorState()

        # Reset the state.
        if reset_state:
            state.reset_state()

        # Register this instance as the active annotator in the state.
        state.annotator = self
Base class for micro_sam annotation plugins.
Implements the logic for the 2d, 3d and tracking annotators. The annotators differ in their data dimensionality and in their widgets.
Annotator2d(viewer: napari.viewer.Viewer, reset_state: bool = True)
def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
    """Create the 2d annotator.

    Args:
        viewer: The napari viewer.
        reset_state: Whether to reset the annotator state after the base class has been
            initialized. `annotator_2d` passes `False`, presumably to keep the predictor
            it has initialized beforehand.
    """
    super().__init__(viewer=viewer, ndim=2)

    # Access the annotator state.
    state = AnnotatorState()

    # Reset the state.
    if reset_state:
        state.reset_state()

    # Register this instance as the active annotator in the state.
    state.annotator = self
Create the annotator GUI.
Arguments:
- viewer: The napari viewer.
- ndim: The number of spatial dimensions of the image data (2 or 3).
def
annotator_2d( image: numpy.ndarray, embedding_path: Union[str, Dict[str, Any], NoneType] = None, segmentation_result: Optional[numpy.ndarray] = None, model_type: str = 'vit_b_lm', tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, return_viewer: bool = False, viewer: Optional[napari.viewer.Viewer] = None, precompute_amg_state: bool = False, checkpoint_path: Optional[str] = None, decoder_path: Optional[str] = None, device: Union[str, torch.device, NoneType] = None, prefer_decoder: bool = True) -> Optional[napari.viewer.Viewer]:
def annotator_2d(
    image: np.ndarray,
    embedding_path: Optional[Union[str, util.ImageEmbeddings]] = None,
    segmentation_result: Optional[np.ndarray] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    return_viewer: bool = False,
    viewer: Optional["napari.viewer.Viewer"] = None,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    decoder_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
) -> Optional["napari.viewer.Viewer"]:
    """Start the 2d annotation tool for a given image.

    Args:
        image: The image data. For a 3d array, the last axis (e.g. an RGB channel axis)
            is excluded from the image shape used for annotation.
        embedding_path: Filepath where to save the embeddings
            or the precomputed image embeddings computed by `precompute_image_embeddings`.
        segmentation_result: An initial segmentation to load.
            This can be used to correct segmentations with Segment Anything or to save and load progress.
            The segmentation will be loaded as the 'committed_objects' layer.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None`, then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
            By default, does not return the napari viewer.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster. By default, set to 'False'.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        decoder_path: Path to a custom decoder checkpoint from which to load the `micro-sam` decoder.
        device: The computational device to use for the SAM model.
            By default, automatically chooses the best available device.
        prefer_decoder: Whether to use decoder based instance segmentation if
            the model used has an additional decoder for instance segmentation.
            By default, set to 'True'.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
        Otherwise, `napari.run()` is called and `None` is returned.
    """

    state = AnnotatorState()
    # For a 3d array, the last axis is not part of the 2d image shape.
    state.image_shape = image.shape[:-1] if image.ndim == 3 else image.shape

    state.initialize_predictor(
        image, model_type=model_type, save_path=embedding_path,
        halo=halo, tile_shape=tile_shape, precompute_amg_state=precompute_amg_state,
        ndim=2, checkpoint_path=checkpoint_path, decoder_path=decoder_path,
        device=device, prefer_decoder=prefer_decoder,
        skip_load=False, use_cli=True,
    )

    if viewer is None:
        viewer = napari.Viewer()

    viewer.add_image(image, name="image")
    # The state was configured above, so it must not be reset by the annotator.
    annotator = Annotator2d(viewer, reset_state=False)

    # Trigger layer update of the annotator so that layers have the correct shape.
    # And initialize the 'committed_objects' with the segmentation result if it was given.
    annotator._update_image(segmentation_result=segmentation_result)

    # Add the annotator widget to the viewer and sync widgets.
    viewer.window.add_dock_widget(annotator)
    _sync_embedding_widget(
        widget=state.widgets["embeddings"],
        # With a custom checkpoint, report the model type of the loaded predictor.
        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
        save_path=embedding_path,
        checkpoint_path=checkpoint_path,
        device=device,
        tile_shape=tile_shape,
        halo=halo,
    )

    if return_viewer:
        return viewer

    napari.run()
Start the 2d annotation tool for a given image.
Arguments:
- image: The image data.
- embedding_path: Filepath where to save the embeddings
or the precomputed image embeddings computed by
`precompute_image_embeddings`.
- segmentation_result: An initial segmentation to load. This can be used to correct segmentations with Segment Anything or to save and load progress. The segmentation will be loaded as the 'committed_objects' layer.
- model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
- tile_shape: Shape of tiles for tiled embedding prediction.
If `None`, then the whole image is passed to Segment Anything.
- halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
- return_viewer: Whether to return the napari viewer to further modify it before starting the tool. By default, does not return the napari viewer.
- viewer: The viewer to which the Segment Anything functionality should be added. This enables using a pre-initialized viewer.
- precompute_amg_state: Whether to precompute the state for automatic mask generation. This will take more time when precomputing embeddings, but will then make automatic mask generation much faster. By default, set to 'False'.
- checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
- decoder_path: Path to a custom decoder checkpoint from which to load the `micro-sam` decoder.
- device: The computational device to use for the SAM model. By default, automatically chooses the best available device.
- prefer_decoder: Whether to use decoder based instance segmentation if the model used has an additional decoder for instance segmentation. By default, set to 'True'.
Returns:
The napari viewer, only returned if
`return_viewer=True`.