micro_sam.sam_annotator.annotator_3d
from typing import Optional, Tuple, Union

import napari
import numpy as np

import torch

from ._annotator import _AnnotatorBase
from ._state import AnnotatorState
from . import _widgets as widgets
from .util import _initialize_parser, _sync_embedding_widget, _load_amg_state, _load_is_state
from .. import util


class Annotator3d(_AnnotatorBase):
    """Annotator plugin for volumetric (3d) image data."""

    def _get_widgets(self):
        """Create the widgets of the 3d annotator.

        Returns:
            A dict mapping widget names to the widget instances.
        """
        autosegment = widgets.AutoSegmentWidget(self._viewer, with_decoder=self._with_decoder, volumetric=True)
        segment_nd = widgets.SegmentNDWidget(self._viewer, tracking=False)
        return {
            "segment": widgets.segment_slice(),
            "segment_nd": segment_nd,
            "autosegment": autosegment,
            "commit": widgets.commit(),
            "clear": widgets.clear_volume(),
        }

    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
        """Create the 3d annotator GUI.

        Args:
            viewer: The napari viewer.
        """
        # Must be set before the base initializer runs: `_get_widgets` reads this flag,
        # presumably from within `_AnnotatorBase.__init__`.
        self._with_decoder = AnnotatorState().decoder is not None
        super().__init__(viewer=viewer, ndim=3)

    def _update_image(self, segmentation_result=None):
        """Update the annotator layers for the current image and load the segmentation state.

        Args:
            segmentation_result: Optional initial segmentation, forwarded to the base implementation.
        """
        super()._update_image(segmentation_result=segmentation_result)
        # Load the state for automatic segmentation from the embedding path.
        # With a decoder, `_load_is_state` is used (presumably the instance-segmentation state);
        # otherwise the AMG state is loaded. Both are stored in `state.amg_state`.
        state = AnnotatorState()
        if self._with_decoder:
            state.amg_state = _load_is_state(state.embedding_path)
        else:
            state.amg_state = _load_amg_state(state.embedding_path)


def annotator_3d(
    image: np.ndarray,
    embedding_path: Optional[str] = None,
    segmentation_result: Optional[np.ndarray] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    return_viewer: bool = False,
    viewer: Optional["napari.viewer.Viewer"] = None,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
) -> Optional["napari.viewer.Viewer"]:
    """Start the 3d annotation tool for a given image volume.

    Args:
        image: The volumetric image data.
        embedding_path: Filepath for saving the precomputed embeddings.
        segmentation_result: An initial segmentation to load.
            This can be used to correct segmentations with Segment Anything or to save and load progress.
            The segmentation will be loaded as the 'committed_objects' layer.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        device: The computational device to use for the SAM model.
        prefer_decoder: Whether to use decoder based instance segmentation if
            the model used has an additional decoder for instance segmentation.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
    """

    # Initialize the predictor state.
    state = AnnotatorState()
    # For 4d input the last axis is excluded from the image shape
    # (presumably a trailing channel axis, e.g. RGB).
    state.image_shape = image.shape[:-1] if image.ndim == 4 else image.shape
    state.initialize_predictor(
        image, model_type=model_type, save_path=embedding_path,
        halo=halo, tile_shape=tile_shape, ndim=3, precompute_amg_state=precompute_amg_state,
        checkpoint_path=checkpoint_path, device=device, prefer_decoder=prefer_decoder,
    )

    if viewer is None:
        viewer = napari.Viewer()

    viewer.add_image(image, name="image")
    annotator = Annotator3d(viewer)

    # Trigger layer update of the annotator so that layers have the correct shape.
    # And initialize the 'committed_objects' with the segmentation result if it was given.
    annotator._update_image(segmentation_result=segmentation_result)

    # Add the annotator widget to the viewer and sync widgets.
    viewer.window.add_dock_widget(annotator)
    _sync_embedding_widget(
        state.widgets["embeddings"], model_type,
        save_path=embedding_path, checkpoint_path=checkpoint_path,
        device=device, tile_shape=tile_shape, halo=halo
    )

    if return_viewer:
        return viewer

    # Blocks until the napari event loop exits.
    napari.run()


# Command line entry point: parses the arguments, loads the image volume (and optional
# initial segmentation) and starts the 3d annotator. The docstring marks it private for the API docs.
def main():
    """@private"""
    parser = _initialize_parser(description="Run interactive segmentation for an image volume.")
    args = parser.parse_args()
    image = util.load_image_data(args.input, key=args.key)

    if args.segmentation_result is None:
        segmentation_result = None
    else:
        segmentation_result = util.load_image_data(args.segmentation_result, key=args.segmentation_key)

    annotator_3d(
        image, embedding_path=args.embedding_path,
        segmentation_result=segmentation_result,
        model_type=args.model_type, tile_shape=args.tile_shape, halo=args.halo,
        checkpoint_path=args.checkpoint, device=args.device,
        precompute_amg_state=args.precompute_amg_state, prefer_decoder=args.prefer_decoder,
    )
class Annotator3d(_AnnotatorBase):
    """Annotator plugin for volumetric (3d) image data."""

    def _get_widgets(self):
        """Assemble the widgets used by the 3d annotator.

        Returns:
            A dict from widget name to widget instance.
        """
        # The two widget objects are created up front (same order as the
        # function-style widgets below) and then registered under their names.
        auto_seg_widget = widgets.AutoSegmentWidget(
            self._viewer, with_decoder=self._with_decoder, volumetric=True
        )
        nd_seg_widget = widgets.SegmentNDWidget(self._viewer, tracking=False)

        widget_map = {}
        widget_map["segment"] = widgets.segment_slice()
        widget_map["segment_nd"] = nd_seg_widget
        widget_map["autosegment"] = auto_seg_widget
        widget_map["commit"] = widgets.commit()
        widget_map["clear"] = widgets.clear_volume()
        return widget_map

    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
        """Create the 3d annotator GUI.

        Args:
            viewer: The napari viewer.
        """
        # `_get_widgets` reads `_with_decoder`, so it is stored before the base
        # initializer (which presumably builds the widgets) is invoked.
        has_decoder = AnnotatorState().decoder is not None
        self._with_decoder = has_decoder
        super().__init__(viewer=viewer, ndim=3)

    def _update_image(self, segmentation_result=None):
        """Refresh the annotator layers, then load the stored segmentation state.

        Args:
            segmentation_result: Optional initial segmentation, forwarded to the base class.
        """
        super()._update_image(segmentation_result=segmentation_result)
        annotator_state = AnnotatorState()
        # Pick the loader that matches the model: `_load_is_state` when a decoder
        # is present (presumably instance-segmentation state), the AMG loader otherwise.
        load_state = _load_is_state if self._with_decoder else _load_amg_state
        annotator_state.amg_state = load_state(annotator_state.embedding_path)
Base class for micro_sam annotation plugins.
Implements the logic for the 2d, 3d and tracking annotators. The annotators differ in their data dimensionality and in their widgets.
Annotator3d(viewer: napari.viewer.Viewer)
def __init__(self, viewer: "napari.viewer.Viewer") -> None:
    """Create the 3d annotator GUI.

    Args:
        viewer: The napari viewer.
    """
    # Record decoder availability first: the widget factory reads `_with_decoder`,
    # presumably while the base initializer runs, so it has to exist beforehand.
    has_decoder = AnnotatorState().decoder is not None
    self._with_decoder = has_decoder
    super().__init__(ndim=3, viewer=viewer)
Create the annotator GUI.
Arguments:
- viewer: The napari viewer.
- ndim: The number of spatial dimensions of the image data (2 or 3).
Inherited Members
- PyQt5.QtWidgets.QScrollArea
- alignment
- ensureVisible
- ensureWidgetVisible
- event
- eventFilter
- focusNextPrevChild
- resizeEvent
- scrollContentsBy
- setAlignment
- setWidget
- setWidgetResizable
- sizeHint
- takeWidget
- viewportSizeHint
- widget
- widgetResizable
- PyQt5.QtWidgets.QAbstractScrollArea
- SizeAdjustPolicy
- addScrollBarWidget
- contextMenuEvent
- cornerWidget
- dragEnterEvent
- dragLeaveEvent
- dragMoveEvent
- dropEvent
- horizontalScrollBar
- horizontalScrollBarPolicy
- keyPressEvent
- maximumViewportSize
- minimumSizeHint
- mouseDoubleClickEvent
- mouseMoveEvent
- mousePressEvent
- mouseReleaseEvent
- paintEvent
- scrollBarWidgets
- setCornerWidget
- setHorizontalScrollBar
- setHorizontalScrollBarPolicy
- setSizeAdjustPolicy
- setVerticalScrollBar
- setVerticalScrollBarPolicy
- setViewport
- setViewportMargins
- setupViewport
- sizeAdjustPolicy
- verticalScrollBar
- verticalScrollBarPolicy
- viewport
- viewportEvent
- viewportMargins
- wheelEvent
- AdjustIgnored
- AdjustToContents
- AdjustToContentsOnFirstShow
- PyQt5.QtWidgets.QFrame
- Shadow
- Shape
- StyleMask
- changeEvent
- drawFrame
- frameRect
- frameShadow
- frameShape
- frameStyle
- frameWidth
- initStyleOption
- lineWidth
- midLineWidth
- setFrameRect
- setFrameShadow
- setFrameShape
- setFrameStyle
- setLineWidth
- setMidLineWidth
- Box
- HLine
- NoFrame
- Panel
- Plain
- Raised
- Shadow_Mask
- Shape_Mask
- StyledPanel
- Sunken
- VLine
- WinPanel
- PyQt5.QtWidgets.QWidget
- RenderFlag
- RenderFlags
- acceptDrops
- accessibleDescription
- accessibleName
- actionEvent
- actions
- activateWindow
- addAction
- addActions
- adjustSize
- autoFillBackground
- backgroundRole
- baseSize
- childAt
- childrenRect
- childrenRegion
- clearFocus
- clearMask
- close
- closeEvent
- contentsMargins
- contentsRect
- contextMenuPolicy
- create
- createWindowContainer
- cursor
- destroy
- devType
- effectiveWinId
- ensurePolished
- enterEvent
- find
- focusInEvent
- focusNextChild
- focusOutEvent
- focusPolicy
- focusPreviousChild
- focusProxy
- focusWidget
- font
- fontInfo
- fontMetrics
- foregroundRole
- frameGeometry
- frameSize
- geometry
- getContentsMargins
- grab
- grabGesture
- grabKeyboard
- grabMouse
- grabShortcut
- graphicsEffect
- graphicsProxyWidget
- hasFocus
- hasHeightForWidth
- hasMouseTracking
- hasTabletTracking
- height
- heightForWidth
- hide
- hideEvent
- initPainter
- inputMethodEvent
- inputMethodHints
- inputMethodQuery
- insertAction
- insertActions
- isActiveWindow
- isAncestorOf
- isEnabled
- isEnabledTo
- isFullScreen
- isHidden
- isLeftToRight
- isMaximized
- isMinimized
- isModal
- isRightToLeft
- isVisible
- isVisibleTo
- isWindow
- isWindowModified
- keyReleaseEvent
- keyboardGrabber
- layout
- layoutDirection
- leaveEvent
- locale
- lower
- mapFrom
- mapFromGlobal
- mapFromParent
- mapTo
- mapToGlobal
- mapToParent
- mask
- maximumHeight
- maximumSize
- maximumWidth
- metric
- minimumHeight
- minimumSize
- minimumWidth
- mouseGrabber
- move
- moveEvent
- nativeEvent
- nativeParentWidget
- nextInFocusChain
- normalGeometry
- overrideWindowFlags
- overrideWindowState
- paintEngine
- palette
- parentWidget
- pos
- previousInFocusChain
- raise_
- rect
- releaseKeyboard
- releaseMouse
- releaseShortcut
- removeAction
- render
- repaint
- resize
- restoreGeometry
- saveGeometry
- screen
- scroll
- setAcceptDrops
- setAccessibleDescription
- setAccessibleName
- setAttribute
- setAutoFillBackground
- setBackgroundRole
- setBaseSize
- setContentsMargins
- setContextMenuPolicy
- setCursor
- setDisabled
- setEnabled
- setFixedHeight
- setFixedSize
- setFixedWidth
- setFocus
- setFocusPolicy
- setFocusProxy
- setFont
- setForegroundRole
- setGeometry
- setGraphicsEffect
- setHidden
- setInputMethodHints
- setLayout
- setLayoutDirection
- setLocale
- setMask
- setMaximumHeight
- setMaximumSize
- setMaximumWidth
- setMinimumHeight
- setMinimumSize
- setMinimumWidth
- setMouseTracking
- setPalette
- setParent
- setShortcutAutoRepeat
- setShortcutEnabled
- setSizeIncrement
- setSizePolicy
- setStatusTip
- setStyle
- setStyleSheet
- setTabOrder
- setTabletTracking
- setToolTip
- setToolTipDuration
- setUpdatesEnabled
- setVisible
- setWhatsThis
- setWindowFilePath
- setWindowFlag
- setWindowFlags
- setWindowIcon
- setWindowIconText
- setWindowModality
- setWindowModified
- setWindowOpacity
- setWindowRole
- setWindowState
- setWindowTitle
- show
- showEvent
- showFullScreen
- showMaximized
- showMinimized
- showNormal
- size
- sizeIncrement
- sizePolicy
- stackUnder
- statusTip
- style
- styleSheet
- tabletEvent
- testAttribute
- toolTip
- toolTipDuration
- underMouse
- ungrabGesture
- unsetCursor
- unsetLayoutDirection
- unsetLocale
- update
- updateGeometry
- updateMicroFocus
- updatesEnabled
- visibleRegion
- whatsThis
- width
- winId
- window
- windowFilePath
- windowFlags
- windowHandle
- windowIcon
- windowIconText
- windowModality
- windowOpacity
- windowRole
- windowState
- windowTitle
- windowType
- x
- y
- DrawChildren
- DrawWindowBackground
- IgnoreMask
- windowIconTextChanged
- windowIconChanged
- windowTitleChanged
- customContextMenuRequested
- PyQt5.QtCore.QObject
- blockSignals
- childEvent
- children
- connectNotify
- customEvent
- deleteLater
- disconnect
- disconnectNotify
- dumpObjectInfo
- dumpObjectTree
- dynamicPropertyNames
- findChild
- findChildren
- inherits
- installEventFilter
- isSignalConnected
- isWidgetType
- isWindowType
- killTimer
- metaObject
- moveToThread
- objectName
- parent
- property
- pyqtConfigure
- receivers
- removeEventFilter
- sender
- senderSignalIndex
- setObjectName
- setProperty
- signalsBlocked
- startTimer
- thread
- timerEvent
- tr
- staticMetaObject
- objectNameChanged
- destroyed
- PyQt5.QtGui.QPaintDevice
- PaintDeviceMetric
- colorCount
- depth
- devicePixelRatio
- devicePixelRatioF
- devicePixelRatioFScale
- heightMM
- logicalDpiX
- logicalDpiY
- paintingActive
- physicalDpiX
- physicalDpiY
- widthMM
- PdmDepth
- PdmDevicePixelRatio
- PdmDevicePixelRatioScaled
- PdmDpiX
- PdmDpiY
- PdmHeight
- PdmHeightMM
- PdmNumColors
- PdmPhysicalDpiX
- PdmPhysicalDpiY
- PdmWidth
- PdmWidthMM
def
annotator_3d( image: numpy.ndarray, embedding_path: Optional[str] = None, segmentation_result: Optional[numpy.ndarray] = None, model_type: str = 'vit_l', tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, return_viewer: bool = False, viewer: Optional[napari.viewer.Viewer] = None, precompute_amg_state: bool = False, checkpoint_path: Optional[str] = None, device: Union[str, torch.device, NoneType] = None, prefer_decoder: bool = True) -> Optional[napari.viewer.Viewer]:
def annotator_3d(
    image: np.ndarray,
    embedding_path: Optional[str] = None,
    segmentation_result: Optional[np.ndarray] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    return_viewer: bool = False,
    viewer: Optional["napari.viewer.Viewer"] = None,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
) -> Optional["napari.viewer.Viewer"]:
    """Start the 3d annotation tool for a given image volume.

    Args:
        image: The volumetric image data.
        embedding_path: Filepath for saving the precomputed embeddings.
        segmentation_result: An initial segmentation to load.
            This can be used to correct segmentations with Segment Anything or to save and load progress.
            The segmentation will be loaded as the 'committed_objects' layer.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        device: The computational device to use for the SAM model.
        prefer_decoder: Whether to use decoder based instance segmentation if
            the model used has an additional decoder for instance segmentation.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
    """

    # Initialize the predictor state.
    state = AnnotatorState()
    # For 4d input the last axis is excluded from the image shape
    # (presumably a trailing channel axis, e.g. RGB).
    state.image_shape = image.shape[:-1] if image.ndim == 4 else image.shape
    state.initialize_predictor(
        image, model_type=model_type, save_path=embedding_path,
        halo=halo, tile_shape=tile_shape, ndim=3, precompute_amg_state=precompute_amg_state,
        checkpoint_path=checkpoint_path, device=device, prefer_decoder=prefer_decoder,
    )

    if viewer is None:
        viewer = napari.Viewer()

    viewer.add_image(image, name="image")
    annotator = Annotator3d(viewer)

    # Trigger layer update of the annotator so that layers have the correct shape.
    # And initialize the 'committed_objects' with the segmentation result if it was given.
    annotator._update_image(segmentation_result=segmentation_result)

    # Add the annotator widget to the viewer and sync widgets.
    viewer.window.add_dock_widget(annotator)
    _sync_embedding_widget(
        state.widgets["embeddings"], model_type,
        save_path=embedding_path, checkpoint_path=checkpoint_path,
        device=device, tile_shape=tile_shape, halo=halo
    )

    if return_viewer:
        return viewer

    # Blocks until the napari event loop exits.
    napari.run()
Start the 3d annotation tool for a given image volume.
Arguments:
- image: The volumetric image data.
- embedding_path: Filepath for saving the precomputed embeddings.
- segmentation_result: An initial segmentation to load. This can be used to correct segmentations with Segment Anything or to save and load progress. The segmentation will be loaded as the 'committed_objects' layer.
- model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
- tile_shape: Shape of tiles for tiled embedding prediction. If `None` then the whole image is passed to Segment Anything.
- halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
- return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
- viewer: The viewer to which the SegmentAnything functionality should be added. This enables using a pre-initialized viewer.
- precompute_amg_state: Whether to precompute the state for automatic mask generation. This will take more time when precomputing embeddings, but will then make automatic mask generation much faster.
- checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
- device: The computational device to use for the SAM model.
- prefer_decoder: Whether to use decoder based instance segmentation if the model used has an additional decoder for instance segmentation.
Returns:
The napari viewer, only returned if `return_viewer=True`.