micro_sam.sam_annotator.annotator_2d
from typing import Optional, Tuple, Union

import napari
import numpy as np

import torch

from .. import util
from . import _widgets as widgets
from ._state import AnnotatorState
from ._annotator import _AnnotatorBase
from .util import _initialize_parser, _sync_embedding_widget


class Annotator2d(_AnnotatorBase):
    """The 2d annotator: the base annotator set up with `ndim=2` and the 2d widgets."""

    def _get_widgets(self):
        """Create the widgets of the 2d annotator and return them keyed by name."""
        # Tell the auto-segmentation widget whether a decoder is set in the state.
        has_decoder = AnnotatorState().decoder is not None
        autosegment_widget = widgets.AutoSegmentWidget(
            self._viewer, with_decoder=has_decoder, volumetric=False
        )
        widget_dict = {
            "segment": widgets.segment(),
            "autosegment": autosegment_widget,
            "commit": widgets.commit(),
            "clear": widgets.clear(),
        }
        return widget_dict

    def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
        """Create the 2d annotator GUI.

        Args:
            viewer: The napari viewer.
            reset_state: Whether to reset the annotator state once the base class is initialized.
        """
        super().__init__(viewer=viewer, ndim=2)

        annotator_state = AnnotatorState()
        if reset_state:
            annotator_state.reset_state()

        # Register this instance as the annotator of the state.
        annotator_state.annotator = self


def annotator_2d(
    image: np.ndarray,
    embedding_path: Optional[Union[str, util.ImageEmbeddings]] = None,
    segmentation_result: Optional[np.ndarray] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    return_viewer: bool = False,
    viewer: Optional["napari.viewer.Viewer"] = None,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
) -> Optional["napari.viewer.Viewer"]:
    """Start the 2d annotation tool for a given image.

    Args:
        image: The image data.
        embedding_path: Filepath where to save the embeddings
            or the precomputed image embeddings computed by `precompute_image_embeddings`.
        segmentation_result: An initial segmentation to load.
            This can be used to correct segmentations with Segment Anything or to save and load progress.
            The segmentation will be loaded as the 'committed_objects' layer.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
            By default, does not return the napari viewer.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster. By default, set to 'False'.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        device: The computational device to use for the SAM model.
            By default, automatically chooses the best available device.
        prefer_decoder: Whether to use decoder based instance segmentation if
            the model used has an additional decoder for instance segmentation.
            By default, set to 'True'.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
    """
    state = AnnotatorState()

    # For 3-dimensional input the last axis is dropped from the stored shape
    # (presumably the channel axis - confirm against the callers).
    if image.ndim == 3:
        state.image_shape = image.shape[:-1]
    else:
        state.image_shape = image.shape

    predictor_kwargs = dict(
        model_type=model_type,
        save_path=embedding_path,
        halo=halo,
        tile_shape=tile_shape,
        precompute_amg_state=precompute_amg_state,
        ndim=2,
        checkpoint_path=checkpoint_path,
        device=device,
        prefer_decoder=prefer_decoder,
        skip_load=False,
        use_cli=True,
    )
    state.initialize_predictor(image, **predictor_kwargs)

    # Only create a new viewer if the caller did not pass one.
    if viewer is None:
        viewer = napari.Viewer()
    viewer.add_image(image, name="image")

    annotator = Annotator2d(viewer, reset_state=False)

    # Trigger layer update of the annotator so that layers have the correct shape.
    # And initialize the 'committed_objects' with the segmentation result if it was given.
    annotator._update_image(segmentation_result=segmentation_result)

    # Add the annotator widget to the viewer and sync widgets.
    viewer.window.add_dock_widget(annotator)

    embedding_widget = state.widgets["embeddings"]
    # With a custom checkpoint the model type is read from the predictor instead of the argument.
    if checkpoint_path is None:
        widget_model_type = model_type
    else:
        widget_model_type = state.predictor.model_type

    _sync_embedding_widget(
        widget=embedding_widget,
        model_type=widget_model_type,
        save_path=embedding_path,
        checkpoint_path=checkpoint_path,
        device=device,
        tile_shape=tile_shape,
        halo=halo,
    )

    if not return_viewer:
        napari.run()
        return None

    return viewer


def main():
    """@private"""
    # Command line entry point: parse the arguments, load the data and start the annotator.
    parser = _initialize_parser(description="Run interactive segmentation for an image.")
    args = parser.parse_args()
    image = util.load_image_data(args.input, key=args.key)

    # The initial segmentation is optional and only loaded if it was passed.
    segmentation_result = None
    if args.segmentation_result is not None:
        segmentation_result = util.load_image_data(args.segmentation_result, key=args.segmentation_key)

    annotator_2d(
        image,
        embedding_path=args.embedding_path,
        segmentation_result=segmentation_result,
        model_type=args.model_type,
        tile_shape=args.tile_shape,
        halo=args.halo,
        precompute_amg_state=args.precompute_amg_state,
        checkpoint_path=args.checkpoint,
        device=args.device,
        prefer_decoder=args.prefer_decoder,
    )
class Annotator2d(_AnnotatorBase):
    """The 2d annotator: the base annotator set up with `ndim=2` and the 2d widgets."""

    def _get_widgets(self):
        """Create the widgets of the 2d annotator and return them keyed by name."""
        # Tell the auto-segmentation widget whether a decoder is set in the state.
        has_decoder = AnnotatorState().decoder is not None
        autosegment_widget = widgets.AutoSegmentWidget(
            self._viewer, with_decoder=has_decoder, volumetric=False
        )
        widget_dict = {
            "segment": widgets.segment(),
            "autosegment": autosegment_widget,
            "commit": widgets.commit(),
            "clear": widgets.clear(),
        }
        return widget_dict

    def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
        """Create the 2d annotator GUI.

        Args:
            viewer: The napari viewer.
            reset_state: Whether to reset the annotator state once the base class is initialized.
        """
        super().__init__(viewer=viewer, ndim=2)

        annotator_state = AnnotatorState()
        if reset_state:
            annotator_state.reset_state()

        # Register this instance as the annotator of the state.
        annotator_state.annotator = self
Base class for micro_sam annotation plugins.
Implements the logic for the 2d, 3d and tracking annotator. The annotators differ in their data dimensionality and the widgets.
Annotator2d(viewer: napari.viewer.Viewer, reset_state: bool = True)
def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
    """Create the 2d annotator GUI.

    Args:
        viewer: The napari viewer.
        reset_state: Whether to reset the annotator state once the base class is initialized.
    """
    super().__init__(viewer=viewer, ndim=2)

    annotator_state = AnnotatorState()
    if reset_state:
        annotator_state.reset_state()

    # Register this instance as the annotator of the state.
    annotator_state.annotator = self
Create the annotator GUI.
Arguments:
- viewer: The napari viewer.
- ndim: The number of spatial dimensions of the image data (2 or 3).
Inherited Members
- PyQt5.QtWidgets.QScrollArea
- alignment
- ensureVisible
- ensureWidgetVisible
- event
- eventFilter
- focusNextPrevChild
- resizeEvent
- scrollContentsBy
- setAlignment
- setWidget
- setWidgetResizable
- sizeHint
- takeWidget
- viewportSizeHint
- widget
- widgetResizable
- PyQt5.QtWidgets.QAbstractScrollArea
- SizeAdjustPolicy
- addScrollBarWidget
- contextMenuEvent
- cornerWidget
- dragEnterEvent
- dragLeaveEvent
- dragMoveEvent
- dropEvent
- horizontalScrollBar
- horizontalScrollBarPolicy
- keyPressEvent
- maximumViewportSize
- minimumSizeHint
- mouseDoubleClickEvent
- mouseMoveEvent
- mousePressEvent
- mouseReleaseEvent
- paintEvent
- scrollBarWidgets
- setCornerWidget
- setHorizontalScrollBar
- setHorizontalScrollBarPolicy
- setSizeAdjustPolicy
- setVerticalScrollBar
- setVerticalScrollBarPolicy
- setViewport
- setViewportMargins
- setupViewport
- sizeAdjustPolicy
- verticalScrollBar
- verticalScrollBarPolicy
- viewport
- viewportEvent
- viewportMargins
- wheelEvent
- AdjustIgnored
- AdjustToContents
- AdjustToContentsOnFirstShow
- PyQt5.QtWidgets.QFrame
- Shadow
- Shape
- StyleMask
- changeEvent
- drawFrame
- frameRect
- frameShadow
- frameShape
- frameStyle
- frameWidth
- initStyleOption
- lineWidth
- midLineWidth
- setFrameRect
- setFrameShadow
- setFrameShape
- setFrameStyle
- setLineWidth
- setMidLineWidth
- Box
- HLine
- NoFrame
- Panel
- Plain
- Raised
- Shadow_Mask
- Shape_Mask
- StyledPanel
- Sunken
- VLine
- WinPanel
- PyQt5.QtWidgets.QWidget
- RenderFlag
- RenderFlags
- acceptDrops
- accessibleDescription
- accessibleName
- actionEvent
- actions
- activateWindow
- addAction
- addActions
- adjustSize
- autoFillBackground
- backgroundRole
- baseSize
- childAt
- childrenRect
- childrenRegion
- clearFocus
- clearMask
- close
- closeEvent
- contentsMargins
- contentsRect
- contextMenuPolicy
- create
- createWindowContainer
- cursor
- destroy
- devType
- effectiveWinId
- ensurePolished
- enterEvent
- find
- focusInEvent
- focusNextChild
- focusOutEvent
- focusPolicy
- focusPreviousChild
- focusProxy
- focusWidget
- font
- fontInfo
- fontMetrics
- foregroundRole
- frameGeometry
- frameSize
- geometry
- getContentsMargins
- grab
- grabGesture
- grabKeyboard
- grabMouse
- grabShortcut
- graphicsEffect
- graphicsProxyWidget
- hasFocus
- hasHeightForWidth
- hasMouseTracking
- hasTabletTracking
- height
- heightForWidth
- hide
- hideEvent
- initPainter
- inputMethodEvent
- inputMethodHints
- inputMethodQuery
- insertAction
- insertActions
- isActiveWindow
- isAncestorOf
- isEnabled
- isEnabledTo
- isFullScreen
- isHidden
- isLeftToRight
- isMaximized
- isMinimized
- isModal
- isRightToLeft
- isVisible
- isVisibleTo
- isWindow
- isWindowModified
- keyReleaseEvent
- keyboardGrabber
- layout
- layoutDirection
- leaveEvent
- locale
- lower
- mapFrom
- mapFromGlobal
- mapFromParent
- mapTo
- mapToGlobal
- mapToParent
- mask
- maximumHeight
- maximumSize
- maximumWidth
- metric
- minimumHeight
- minimumSize
- minimumWidth
- mouseGrabber
- move
- moveEvent
- nativeEvent
- nativeParentWidget
- nextInFocusChain
- normalGeometry
- overrideWindowFlags
- overrideWindowState
- paintEngine
- palette
- parentWidget
- pos
- previousInFocusChain
- raise_
- rect
- releaseKeyboard
- releaseMouse
- releaseShortcut
- removeAction
- render
- repaint
- resize
- restoreGeometry
- saveGeometry
- screen
- scroll
- setAcceptDrops
- setAccessibleDescription
- setAccessibleName
- setAttribute
- setAutoFillBackground
- setBackgroundRole
- setBaseSize
- setContentsMargins
- setContextMenuPolicy
- setCursor
- setDisabled
- setEnabled
- setFixedHeight
- setFixedSize
- setFixedWidth
- setFocus
- setFocusPolicy
- setFocusProxy
- setFont
- setForegroundRole
- setGeometry
- setGraphicsEffect
- setHidden
- setInputMethodHints
- setLayout
- setLayoutDirection
- setLocale
- setMask
- setMaximumHeight
- setMaximumSize
- setMaximumWidth
- setMinimumHeight
- setMinimumSize
- setMinimumWidth
- setMouseTracking
- setPalette
- setParent
- setShortcutAutoRepeat
- setShortcutEnabled
- setSizeIncrement
- setSizePolicy
- setStatusTip
- setStyle
- setStyleSheet
- setTabOrder
- setTabletTracking
- setToolTip
- setToolTipDuration
- setUpdatesEnabled
- setVisible
- setWhatsThis
- setWindowFilePath
- setWindowFlag
- setWindowFlags
- setWindowIcon
- setWindowIconText
- setWindowModality
- setWindowModified
- setWindowOpacity
- setWindowRole
- setWindowState
- setWindowTitle
- show
- showEvent
- showFullScreen
- showMaximized
- showMinimized
- showNormal
- size
- sizeIncrement
- sizePolicy
- stackUnder
- statusTip
- style
- styleSheet
- tabletEvent
- testAttribute
- toolTip
- toolTipDuration
- underMouse
- ungrabGesture
- unsetCursor
- unsetLayoutDirection
- unsetLocale
- update
- updateGeometry
- updateMicroFocus
- updatesEnabled
- visibleRegion
- whatsThis
- width
- winId
- window
- windowFilePath
- windowFlags
- windowHandle
- windowIcon
- windowIconText
- windowModality
- windowOpacity
- windowRole
- windowState
- windowTitle
- windowType
- x
- y
- DrawChildren
- DrawWindowBackground
- IgnoreMask
- windowIconTextChanged
- windowIconChanged
- windowTitleChanged
- customContextMenuRequested
- PyQt5.QtCore.QObject
- blockSignals
- childEvent
- children
- connectNotify
- customEvent
- deleteLater
- disconnect
- disconnectNotify
- dumpObjectInfo
- dumpObjectTree
- dynamicPropertyNames
- findChild
- findChildren
- inherits
- installEventFilter
- isSignalConnected
- isWidgetType
- isWindowType
- killTimer
- metaObject
- moveToThread
- objectName
- parent
- property
- pyqtConfigure
- receivers
- removeEventFilter
- sender
- senderSignalIndex
- setObjectName
- setProperty
- signalsBlocked
- startTimer
- thread
- timerEvent
- tr
- staticMetaObject
- objectNameChanged
- destroyed
- PyQt5.QtGui.QPaintDevice
- PaintDeviceMetric
- colorCount
- depth
- devicePixelRatio
- devicePixelRatioF
- devicePixelRatioFScale
- heightMM
- logicalDpiX
- logicalDpiY
- paintingActive
- physicalDpiX
- physicalDpiY
- widthMM
- PdmDepth
- PdmDevicePixelRatio
- PdmDevicePixelRatioScaled
- PdmDpiX
- PdmDpiY
- PdmHeight
- PdmHeightMM
- PdmNumColors
- PdmPhysicalDpiX
- PdmPhysicalDpiY
- PdmWidth
- PdmWidthMM
def
annotator_2d( image: numpy.ndarray, embedding_path: Union[str, Dict[str, Any], NoneType] = None, segmentation_result: Optional[numpy.ndarray] = None, model_type: str = 'vit_b_lm', tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, return_viewer: bool = False, viewer: Optional[napari.viewer.Viewer] = None, precompute_amg_state: bool = False, checkpoint_path: Optional[str] = None, device: Union[str, torch.device, NoneType] = None, prefer_decoder: bool = True) -> Optional[napari.viewer.Viewer]:
def annotator_2d(
    image: np.ndarray,
    embedding_path: Optional[Union[str, util.ImageEmbeddings]] = None,
    segmentation_result: Optional[np.ndarray] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    return_viewer: bool = False,
    viewer: Optional["napari.viewer.Viewer"] = None,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
    prefer_decoder: bool = True,
) -> Optional["napari.viewer.Viewer"]:
    """Start the 2d annotation tool for a given image.

    Args:
        image: The image data.
        embedding_path: Filepath where to save the embeddings
            or the precomputed image embeddings computed by `precompute_image_embeddings`.
        segmentation_result: An initial segmentation to load.
            This can be used to correct segmentations with Segment Anything or to save and load progress.
            The segmentation will be loaded as the 'committed_objects' layer.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
            By default, does not return the napari viewer.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster. By default, set to 'False'.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        device: The computational device to use for the SAM model.
            By default, automatically chooses the best available device.
        prefer_decoder: Whether to use decoder based instance segmentation if
            the model used has an additional decoder for instance segmentation.
            By default, set to 'True'.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
    """
    state = AnnotatorState()

    # For 3-dimensional input the last axis is dropped from the stored shape
    # (presumably the channel axis - confirm against the callers).
    if image.ndim == 3:
        state.image_shape = image.shape[:-1]
    else:
        state.image_shape = image.shape

    predictor_kwargs = dict(
        model_type=model_type,
        save_path=embedding_path,
        halo=halo,
        tile_shape=tile_shape,
        precompute_amg_state=precompute_amg_state,
        ndim=2,
        checkpoint_path=checkpoint_path,
        device=device,
        prefer_decoder=prefer_decoder,
        skip_load=False,
        use_cli=True,
    )
    state.initialize_predictor(image, **predictor_kwargs)

    # Only create a new viewer if the caller did not pass one.
    if viewer is None:
        viewer = napari.Viewer()
    viewer.add_image(image, name="image")

    annotator = Annotator2d(viewer, reset_state=False)

    # Trigger layer update of the annotator so that layers have the correct shape.
    # And initialize the 'committed_objects' with the segmentation result if it was given.
    annotator._update_image(segmentation_result=segmentation_result)

    # Add the annotator widget to the viewer and sync widgets.
    viewer.window.add_dock_widget(annotator)

    embedding_widget = state.widgets["embeddings"]
    # With a custom checkpoint the model type is read from the predictor instead of the argument.
    if checkpoint_path is None:
        widget_model_type = model_type
    else:
        widget_model_type = state.predictor.model_type

    _sync_embedding_widget(
        widget=embedding_widget,
        model_type=widget_model_type,
        save_path=embedding_path,
        checkpoint_path=checkpoint_path,
        device=device,
        tile_shape=tile_shape,
        halo=halo,
    )

    if not return_viewer:
        napari.run()
        return None

    return viewer
Start the 2d annotation tool for a given image.
Arguments:
- image: The image data.
- embedding_path: Filepath where to save the embeddings or the precomputed image embeddings computed by `precompute_image_embeddings`.
- segmentation_result: An initial segmentation to load. This can be used to correct segmentations with Segment Anything or to save and load progress. The segmentation will be loaded as the 'committed_objects' layer.
- model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
- tile_shape: Shape of tiles for tiled embedding prediction. If `None` then the whole image is passed to Segment Anything.
- halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
- return_viewer: Whether to return the napari viewer to further modify it before starting the tool. By default, does not return the napari viewer.
- viewer: The viewer to which the Segment Anything functionality should be added. This enables using a pre-initialized viewer.
- precompute_amg_state: Whether to precompute the state for automatic mask generation. This will take more time when precomputing embeddings, but will then make automatic mask generation much faster. By default, set to 'False'.
- checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
- device: The computational device to use for the SAM model. By default, automatically chooses the best available device.
- prefer_decoder: Whether to use decoder based instance segmentation if the model used has an additional decoder for instance segmentation. By default, set to 'True'.
Returns:
The napari viewer, only returned if `return_viewer=True`.