micro_sam.sam_annotator.annotator_3d
"""The 3d annotation tool of micro_sam: interactive volumetric segmentation with Segment Anything in napari."""

from typing import Optional, Tuple, Union

import napari
import numpy as np

import torch

from .. import util
from . import _widgets as widgets
from ._state import AnnotatorState
from ._annotator import _AnnotatorBase
from .util import _initialize_parser, _sync_embedding_widget, _load_amg_state, _load_is_state


class Annotator3d(_AnnotatorBase):
    """The annotator widget for interactive 3d segmentation in napari."""

    def _get_widgets(self) -> dict:
        """Return the child widgets specific to the 3d annotator."""
        # The auto-segmentation widget is configured for volumetric data and uses
        # the extra instance-segmentation decoder when the loaded model has one.
        autosegment = widgets.AutoSegmentWidget(self._viewer, with_decoder=self._with_decoder, volumetric=True)
        segment_nd = widgets.SegmentNDWidget(self._viewer, tracking=False)
        return {
            "segment": widgets.segment_slice(),
            "segment_nd": segment_nd,
            "autosegment": autosegment,
            "commit": widgets.commit(),
            "clear": widgets.clear_volume(),
        }

    def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
        """Create the annotator GUI.

        Args:
            viewer: The napari viewer.
            reset_state: Whether to reset the global AnnotatorState singleton.
        """
        # Must be set before super().__init__, since the base constructor calls _get_widgets,
        # which reads self._with_decoder.
        self._with_decoder = AnnotatorState().decoder is not None
        super().__init__(viewer=viewer, ndim=3)

        # Set the expected annotator class to the state.
        state = AnnotatorState()

        # Reset the state.
        if reset_state:
            state.reset_state()

        state.annotator = self

    def _update_image(self, segmentation_result=None):
        """Refresh the layers for a new image and restore any cached automatic-segmentation state."""
        super()._update_image(segmentation_result=segmentation_result)
        # Load the amg state from the embedding path.
        state = AnnotatorState()
        if self._with_decoder:
            state.amg_state = _load_is_state(state.embedding_path)
        else:
            state.amg_state = _load_amg_state(state.embedding_path)


def annotator_3d(
    image: np.ndarray,
    embedding_path: Optional[Union[str, util.ImageEmbeddings]] = None,
    segmentation_result: Optional[np.ndarray] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    return_viewer: bool = False,
    viewer: Optional["napari.viewer.Viewer"] = None,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
) -> Optional["napari.viewer.Viewer"]:
    """Start the 3d annotation tool for a given image volume.

    Args:
        image: The volumetric image data.
        embedding_path: Filepath where to save the embeddings
            or the precomputed image embeddings computed by `precompute_image_embeddings`.
        segmentation_result: An initial segmentation to load.
            This can be used to correct segmentations with Segment Anything or to save and load progress.
            The segmentation will be loaded as the 'committed_objects' layer.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
            By default, does not return the napari viewer.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster. By default, set to 'False'.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        device: The computational device to use for the SAM model.
            By default, automatically chooses the best available device.
        prefer_decoder: Whether to use decoder based instance segmentation if
            the model used has an additional decoder for instance segmentation.
            By default, set to 'True'.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
    """

    # Initialize the predictor state.
    state = AnnotatorState()
    # A 4d input is assumed to carry channels in the last axis; the spatial shape excludes it.
    state.image_shape = image.shape[:-1] if image.ndim == 4 else image.shape
    state.initialize_predictor(
        image, model_type=model_type, save_path=embedding_path,
        halo=halo, tile_shape=tile_shape, ndim=3, precompute_amg_state=precompute_amg_state,
        checkpoint_path=checkpoint_path, device=device, prefer_decoder=prefer_decoder,
        use_cli=True,
    )

    if viewer is None:
        viewer = napari.Viewer()

    viewer.add_image(image, name="image")
    # reset_state=False: the predictor state initialized above must be kept.
    annotator = Annotator3d(viewer, reset_state=False)

    # Trigger layer update of the annotator so that layers have the correct shape.
    # And initialize the 'committed_objects' with the segmentation result if it was given.
    annotator._update_image(segmentation_result=segmentation_result)

    # Add the annotator widget to the viewer and sync widgets.
    viewer.window.add_dock_widget(annotator)
    _sync_embedding_widget(
        widget=state.widgets["embeddings"],
        # With a custom checkpoint the effective model type comes from the loaded predictor.
        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
        save_path=embedding_path,
        checkpoint_path=checkpoint_path,
        device=device,
        tile_shape=tile_shape,
        halo=halo
    )

    if return_viewer:
        return viewer

    napari.run()


def main():
    """@private"""
    # CLI entry point: parse arguments, load data, and launch the 3d annotator.
    parser = _initialize_parser(description="Run interactive segmentation for an image volume.")
    args = parser.parse_args()
    image = util.load_image_data(args.input, key=args.key)

    if args.segmentation_result is None:
        segmentation_result = None
    else:
        segmentation_result = util.load_image_data(args.segmentation_result, key=args.segmentation_key)

    annotator_3d(
        image, embedding_path=args.embedding_path,
        segmentation_result=segmentation_result,
        model_type=args.model_type, tile_shape=args.tile_shape, halo=args.halo,
        checkpoint_path=args.checkpoint, device=args.device,
        precompute_amg_state=args.precompute_amg_state, prefer_decoder=args.prefer_decoder,
    )
class Annotator3d(_AnnotatorBase):
    """The annotator widget for interactive 3d segmentation in napari."""

    def _get_widgets(self) -> dict:
        """Return the child widgets specific to the 3d annotator."""
        # The auto-segmentation widget is configured for volumetric data and uses
        # the extra instance-segmentation decoder when the loaded model has one.
        autosegment = widgets.AutoSegmentWidget(self._viewer, with_decoder=self._with_decoder, volumetric=True)
        segment_nd = widgets.SegmentNDWidget(self._viewer, tracking=False)
        return {
            "segment": widgets.segment_slice(),
            "segment_nd": segment_nd,
            "autosegment": autosegment,
            "commit": widgets.commit(),
            "clear": widgets.clear_volume(),
        }

    def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
        """Create the annotator GUI.

        Args:
            viewer: The napari viewer.
            reset_state: Whether to reset the global AnnotatorState singleton.
        """
        # Must be set before super().__init__, since the base constructor calls _get_widgets,
        # which reads self._with_decoder.
        self._with_decoder = AnnotatorState().decoder is not None
        super().__init__(viewer=viewer, ndim=3)

        # Set the expected annotator class to the state.
        state = AnnotatorState()

        # Reset the state.
        if reset_state:
            state.reset_state()

        state.annotator = self

    def _update_image(self, segmentation_result=None):
        """Refresh the layers for a new image and restore any cached automatic-segmentation state."""
        super()._update_image(segmentation_result=segmentation_result)
        # Load the amg state from the embedding path.
        state = AnnotatorState()
        if self._with_decoder:
            state.amg_state = _load_is_state(state.embedding_path)
        else:
            state.amg_state = _load_amg_state(state.embedding_path)
Base class for micro_sam annotation plugins.
Implements the logic for the 2d, 3d and tracking annotator. The annotators differ in their data dimensionality and the widgets.
Annotator3d(viewer: napari.viewer.Viewer, reset_state: bool = True)
def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
    """Create the annotator GUI.

    Args:
        viewer: The napari viewer.
        reset_state: Whether to reset the global AnnotatorState singleton.
    """
    # Must be set before super().__init__, since the base constructor calls _get_widgets,
    # which reads self._with_decoder.
    self._with_decoder = AnnotatorState().decoder is not None
    super().__init__(viewer=viewer, ndim=3)

    # Set the expected annotator class to the state.
    state = AnnotatorState()

    # Reset the state.
    if reset_state:
        state.reset_state()

    state.annotator = self
Create the annotator GUI.
Arguments:
- viewer: The napari viewer.
- ndim: The number of spatial dimension of the image data (2 or 3).
Inherited Members
- PyQt5.QtWidgets.QScrollArea
- alignment
- ensureVisible
- ensureWidgetVisible
- event
- eventFilter
- focusNextPrevChild
- resizeEvent
- scrollContentsBy
- setAlignment
- setWidget
- setWidgetResizable
- sizeHint
- takeWidget
- viewportSizeHint
- widget
- widgetResizable
- PyQt5.QtWidgets.QAbstractScrollArea
- SizeAdjustPolicy
- addScrollBarWidget
- contextMenuEvent
- cornerWidget
- dragEnterEvent
- dragLeaveEvent
- dragMoveEvent
- dropEvent
- horizontalScrollBar
- horizontalScrollBarPolicy
- keyPressEvent
- maximumViewportSize
- minimumSizeHint
- mouseDoubleClickEvent
- mouseMoveEvent
- mousePressEvent
- mouseReleaseEvent
- paintEvent
- scrollBarWidgets
- setCornerWidget
- setHorizontalScrollBar
- setHorizontalScrollBarPolicy
- setSizeAdjustPolicy
- setVerticalScrollBar
- setVerticalScrollBarPolicy
- setViewport
- setViewportMargins
- setupViewport
- sizeAdjustPolicy
- verticalScrollBar
- verticalScrollBarPolicy
- viewport
- viewportEvent
- viewportMargins
- wheelEvent
- AdjustIgnored
- AdjustToContents
- AdjustToContentsOnFirstShow
- PyQt5.QtWidgets.QFrame
- Shadow
- Shape
- StyleMask
- changeEvent
- drawFrame
- frameRect
- frameShadow
- frameShape
- frameStyle
- frameWidth
- initStyleOption
- lineWidth
- midLineWidth
- setFrameRect
- setFrameShadow
- setFrameShape
- setFrameStyle
- setLineWidth
- setMidLineWidth
- Box
- HLine
- NoFrame
- Panel
- Plain
- Raised
- Shadow_Mask
- Shape_Mask
- StyledPanel
- Sunken
- VLine
- WinPanel
- PyQt5.QtWidgets.QWidget
- RenderFlag
- RenderFlags
- acceptDrops
- accessibleDescription
- accessibleName
- actionEvent
- actions
- activateWindow
- addAction
- addActions
- adjustSize
- autoFillBackground
- backgroundRole
- baseSize
- childAt
- childrenRect
- childrenRegion
- clearFocus
- clearMask
- close
- closeEvent
- contentsMargins
- contentsRect
- contextMenuPolicy
- create
- createWindowContainer
- cursor
- destroy
- devType
- effectiveWinId
- ensurePolished
- enterEvent
- find
- focusInEvent
- focusNextChild
- focusOutEvent
- focusPolicy
- focusPreviousChild
- focusProxy
- focusWidget
- font
- fontInfo
- fontMetrics
- foregroundRole
- frameGeometry
- frameSize
- geometry
- getContentsMargins
- grab
- grabGesture
- grabKeyboard
- grabMouse
- grabShortcut
- graphicsEffect
- graphicsProxyWidget
- hasFocus
- hasHeightForWidth
- hasMouseTracking
- hasTabletTracking
- height
- heightForWidth
- hide
- hideEvent
- initPainter
- inputMethodEvent
- inputMethodHints
- inputMethodQuery
- insertAction
- insertActions
- isActiveWindow
- isAncestorOf
- isEnabled
- isEnabledTo
- isFullScreen
- isHidden
- isLeftToRight
- isMaximized
- isMinimized
- isModal
- isRightToLeft
- isVisible
- isVisibleTo
- isWindow
- isWindowModified
- keyReleaseEvent
- keyboardGrabber
- layout
- layoutDirection
- leaveEvent
- locale
- lower
- mapFrom
- mapFromGlobal
- mapFromParent
- mapTo
- mapToGlobal
- mapToParent
- mask
- maximumHeight
- maximumSize
- maximumWidth
- metric
- minimumHeight
- minimumSize
- minimumWidth
- mouseGrabber
- move
- moveEvent
- nativeEvent
- nativeParentWidget
- nextInFocusChain
- normalGeometry
- overrideWindowFlags
- overrideWindowState
- paintEngine
- palette
- parentWidget
- pos
- previousInFocusChain
- raise_
- rect
- releaseKeyboard
- releaseMouse
- releaseShortcut
- removeAction
- render
- repaint
- resize
- restoreGeometry
- saveGeometry
- screen
- scroll
- setAcceptDrops
- setAccessibleDescription
- setAccessibleName
- setAttribute
- setAutoFillBackground
- setBackgroundRole
- setBaseSize
- setContentsMargins
- setContextMenuPolicy
- setCursor
- setDisabled
- setEnabled
- setFixedHeight
- setFixedSize
- setFixedWidth
- setFocus
- setFocusPolicy
- setFocusProxy
- setFont
- setForegroundRole
- setGeometry
- setGraphicsEffect
- setHidden
- setInputMethodHints
- setLayout
- setLayoutDirection
- setLocale
- setMask
- setMaximumHeight
- setMaximumSize
- setMaximumWidth
- setMinimumHeight
- setMinimumSize
- setMinimumWidth
- setMouseTracking
- setPalette
- setParent
- setShortcutAutoRepeat
- setShortcutEnabled
- setSizeIncrement
- setSizePolicy
- setStatusTip
- setStyle
- setStyleSheet
- setTabOrder
- setTabletTracking
- setToolTip
- setToolTipDuration
- setUpdatesEnabled
- setVisible
- setWhatsThis
- setWindowFilePath
- setWindowFlag
- setWindowFlags
- setWindowIcon
- setWindowIconText
- setWindowModality
- setWindowModified
- setWindowOpacity
- setWindowRole
- setWindowState
- setWindowTitle
- show
- showEvent
- showFullScreen
- showMaximized
- showMinimized
- showNormal
- size
- sizeIncrement
- sizePolicy
- stackUnder
- statusTip
- style
- styleSheet
- tabletEvent
- testAttribute
- toolTip
- toolTipDuration
- underMouse
- ungrabGesture
- unsetCursor
- unsetLayoutDirection
- unsetLocale
- update
- updateGeometry
- updateMicroFocus
- updatesEnabled
- visibleRegion
- whatsThis
- width
- winId
- window
- windowFilePath
- windowFlags
- windowHandle
- windowIcon
- windowIconText
- windowModality
- windowOpacity
- windowRole
- windowState
- windowTitle
- windowType
- x
- y
- DrawChildren
- DrawWindowBackground
- IgnoreMask
- windowIconTextChanged
- windowIconChanged
- windowTitleChanged
- customContextMenuRequested
- PyQt5.QtCore.QObject
- blockSignals
- childEvent
- children
- connectNotify
- customEvent
- deleteLater
- disconnect
- disconnectNotify
- dumpObjectInfo
- dumpObjectTree
- dynamicPropertyNames
- findChild
- findChildren
- inherits
- installEventFilter
- isSignalConnected
- isWidgetType
- isWindowType
- killTimer
- metaObject
- moveToThread
- objectName
- parent
- property
- pyqtConfigure
- receivers
- removeEventFilter
- sender
- senderSignalIndex
- setObjectName
- setProperty
- signalsBlocked
- startTimer
- thread
- timerEvent
- tr
- staticMetaObject
- objectNameChanged
- destroyed
- PyQt5.QtGui.QPaintDevice
- PaintDeviceMetric
- colorCount
- depth
- devicePixelRatio
- devicePixelRatioF
- devicePixelRatioFScale
- heightMM
- logicalDpiX
- logicalDpiY
- paintingActive
- physicalDpiX
- physicalDpiY
- widthMM
- PdmDepth
- PdmDevicePixelRatio
- PdmDevicePixelRatioScaled
- PdmDpiX
- PdmDpiY
- PdmHeight
- PdmHeightMM
- PdmNumColors
- PdmPhysicalDpiX
- PdmPhysicalDpiY
- PdmWidth
- PdmWidthMM
def
annotator_3d( image: numpy.ndarray, embedding_path: Union[str, Dict[str, Any], NoneType] = None, segmentation_result: Optional[numpy.ndarray] = None, model_type: str = 'vit_b_lm', tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, return_viewer: bool = False, viewer: Optional[napari.viewer.Viewer] = None, precompute_amg_state: bool = False, checkpoint_path: Optional[str] = None, device: Union[str, torch.device, NoneType] = None, prefer_decoder: bool = True) -> Optional[napari.viewer.Viewer]:
def annotator_3d(
    image: np.ndarray,
    embedding_path: Optional[Union[str, util.ImageEmbeddings]] = None,
    segmentation_result: Optional[np.ndarray] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    return_viewer: bool = False,
    viewer: Optional["napari.viewer.Viewer"] = None,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
) -> Optional["napari.viewer.Viewer"]:
    """Start the 3d annotation tool for a given image volume.

    Args:
        image: The volumetric image data.
        embedding_path: Filepath where to save the embeddings
            or the precomputed image embeddings computed by `precompute_image_embeddings`.
        segmentation_result: An initial segmentation to load.
            This can be used to correct segmentations with Segment Anything or to save and load progress.
            The segmentation will be loaded as the 'committed_objects' layer.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
            By default, does not return the napari viewer.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster. By default, set to 'False'.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        device: The computational device to use for the SAM model.
            By default, automatically chooses the best available device.
        prefer_decoder: Whether to use decoder based instance segmentation if
            the model used has an additional decoder for instance segmentation.
            By default, set to 'True'.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
    """

    # Initialize the predictor state.
    state = AnnotatorState()
    # A 4d input is assumed to carry channels in the last axis; the spatial shape excludes it.
    state.image_shape = image.shape[:-1] if image.ndim == 4 else image.shape
    state.initialize_predictor(
        image, model_type=model_type, save_path=embedding_path,
        halo=halo, tile_shape=tile_shape, ndim=3, precompute_amg_state=precompute_amg_state,
        checkpoint_path=checkpoint_path, device=device, prefer_decoder=prefer_decoder,
        use_cli=True,
    )

    if viewer is None:
        viewer = napari.Viewer()

    viewer.add_image(image, name="image")
    # reset_state=False: the predictor state initialized above must be kept.
    annotator = Annotator3d(viewer, reset_state=False)

    # Trigger layer update of the annotator so that layers have the correct shape.
    # And initialize the 'committed_objects' with the segmentation result if it was given.
    annotator._update_image(segmentation_result=segmentation_result)

    # Add the annotator widget to the viewer and sync widgets.
    viewer.window.add_dock_widget(annotator)
    _sync_embedding_widget(
        widget=state.widgets["embeddings"],
        # With a custom checkpoint the effective model type comes from the loaded predictor.
        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
        save_path=embedding_path,
        checkpoint_path=checkpoint_path,
        device=device,
        tile_shape=tile_shape,
        halo=halo
    )

    if return_viewer:
        return viewer

    napari.run()
Start the 3d annotation tool for a given image volume.
Arguments:
- image: The volumetric image data.
- embedding_path: Filepath where to save the embeddings
or the precomputed image embeddings computed by
precompute_image_embeddings
. - segmentation_result: An initial segmentation to load. This can be used to correct segmentations with Segment Anything or to save and load progress. The segmentation will be loaded as the 'committed_objects' layer.
- model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
- tile_shape: Shape of tiles for tiled embedding prediction.
If
None
then the whole image is passed to Segment Anything. - halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
- return_viewer: Whether to return the napari viewer to further modify it before starting the tool. By default, does not return the napari viewer.
- viewer: The viewer to which the Segment Anything functionality should be added. This enables using a pre-initialized viewer.
- precompute_amg_state: Whether to precompute the state for automatic mask generation. This will take more time when precomputing embeddings, but will then make automatic mask generation much faster. By default, set to 'False'.
- checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
- device: The computational device to use for the SAM model. By default, automatically chooses the best available device.
- prefer_decoder: Whether to use decoder based instance segmentation if the model used has an additional decoder for instance segmentation. By default, set to 'True'.
Returns:
The napari viewer, only returned if
return_viewer=True
.