micro_sam.sam_annotator.annotator_3d

  1from typing import Optional, Tuple, Union
  2
  3import napari
  4import numpy as np
  5
  6import torch
  7
  8from .. import util
  9from . import _widgets as widgets
 10from ._state import AnnotatorState
 11from ._annotator import _AnnotatorBase
 12from .util import _initialize_parser, _sync_embedding_widget, _load_amg_state, _load_is_state
 13
 14
 15class Annotator3d(_AnnotatorBase):
 16    def _get_widgets(self):
 17        autosegment = widgets.AutoSegmentWidget(self._viewer, with_decoder=self._with_decoder, volumetric=True)
 18        segment_nd = widgets.SegmentNDWidget(self._viewer, tracking=False)
 19        return {
 20            "segment": widgets.segment_slice(),
 21            "segment_nd": segment_nd,
 22            "autosegment": autosegment,
 23            "commit": widgets.commit(),
 24            "clear": widgets.clear_volume(),
 25        }
 26
 27    def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
 28        self._with_decoder = AnnotatorState().decoder is not None
 29        super().__init__(viewer=viewer, ndim=3)
 30
 31        # Set the expected annotator class to the state.
 32        state = AnnotatorState()
 33
 34        # Reset the state.
 35        if reset_state:
 36            state.reset_state()
 37
 38        state.annotator = self
 39
 40    def _update_image(self, segmentation_result=None):
 41        super()._update_image(segmentation_result=segmentation_result)
 42        # Load the amg state from the embedding path.
 43        state = AnnotatorState()
 44        if self._with_decoder:
 45            state.amg_state = _load_is_state(state.embedding_path)
 46        else:
 47            state.amg_state = _load_amg_state(state.embedding_path)
 48
 49
 50def annotator_3d(
 51    image: np.ndarray,
 52    embedding_path: Optional[Union[str, util.ImageEmbeddings]] = None,
 53    segmentation_result: Optional[np.ndarray] = None,
 54    model_type: str = util._DEFAULT_MODEL,
 55    tile_shape: Optional[Tuple[int, int]] = None,
 56    halo: Optional[Tuple[int, int]] = None,
 57    return_viewer: bool = False,
 58    viewer: Optional["napari.viewer.Viewer"] = None,
 59    precompute_amg_state: bool = False,
 60    checkpoint_path: Optional[str] = None,
 61    device: Optional[Union[str, torch.device]] = None,
 62    prefer_decoder: bool = True,
 63) -> Optional["napari.viewer.Viewer"]:
 64    """Start the 3d annotation tool for a given image volume.
 65
 66    Args:
 67        image: The volumetric image data.
 68        embedding_path: Filepath where to save the embeddings
 69            or the precompted image embeddings computed by `precompute_image_embeddings`.
 70        segmentation_result: An initial segmentation to load.
 71            This can be used to correct segmentations with Segment Anything or to save and load progress.
 72            The segmentation will be loaded as the 'committed_objects' layer.
 73        model_type: The Segment Anything model to use. For details on the available models check out
 74            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
 75        tile_shape: Shape of tiles for tiled embedding prediction.
 76            If `None` then the whole image is passed to Segment Anything.
 77        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
 78        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
 79            By default, does not return the napari viewer.
 80        viewer: The viewer to which the Segment Anything functionality should be added.
 81            This enables using a pre-initialized viewer.
 82        precompute_amg_state: Whether to precompute the state for automatic mask generation.
 83            This will take more time when precomputing embeddings, but will then make
 84            automatic mask generation much faster. By default, set to 'False'.
 85        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
 86        device: The computational device to use for the SAM model.
 87            By default, automatically chooses the best available device.
 88        prefer_decoder: Whether to use decoder based instance segmentation if
 89            the model used has an additional decoder for instance segmentation.
 90            By default, set to 'True'.
 91
 92    Returns:
 93        The napari viewer, only returned if `return_viewer=True`.
 94    """
 95
 96    # Initialize the predictor state.
 97    state = AnnotatorState()
 98    state.image_shape = image.shape[:-1] if image.ndim == 4 else image.shape
 99    state.initialize_predictor(
100        image, model_type=model_type, save_path=embedding_path,
101        halo=halo, tile_shape=tile_shape, ndim=3, precompute_amg_state=precompute_amg_state,
102        checkpoint_path=checkpoint_path, device=device, prefer_decoder=prefer_decoder,
103        use_cli=True,
104    )
105
106    if viewer is None:
107        viewer = napari.Viewer()
108
109    viewer.add_image(image, name="image")
110    annotator = Annotator3d(viewer, reset_state=False)
111
112    # Trigger layer update of the annotator so that layers have the correct shape.
113    # And initialize the 'committed_objects' with the segmentation result if it was given.
114    annotator._update_image(segmentation_result=segmentation_result)
115
116    # Add the annotator widget to the viewer and sync widgets.
117    viewer.window.add_dock_widget(annotator)
118    _sync_embedding_widget(
119        widget=state.widgets["embeddings"],
120        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
121        save_path=embedding_path,
122        checkpoint_path=checkpoint_path,
123        device=device,
124        tile_shape=tile_shape,
125        halo=halo
126    )
127
128    if return_viewer:
129        return viewer
130
131    napari.run()
132
133
def main():
    """@private"""
    # Command line entry point: parse the arguments, load the data and start the 3d annotator.
    cli_parser = _initialize_parser(description="Run interactive segmentation for an image volume.")
    cli_args = cli_parser.parse_args()

    volume = util.load_image_data(cli_args.input, key=cli_args.key)
    # The initial segmentation is optional; it is only loaded when a path was given.
    initial_segmentation = (
        None
        if cli_args.segmentation_result is None
        else util.load_image_data(cli_args.segmentation_result, key=cli_args.segmentation_key)
    )

    annotator_3d(
        volume,
        embedding_path=cli_args.embedding_path,
        segmentation_result=initial_segmentation,
        model_type=cli_args.model_type,
        tile_shape=cli_args.tile_shape,
        halo=cli_args.halo,
        checkpoint_path=cli_args.checkpoint,
        device=cli_args.device,
        precompute_amg_state=cli_args.precompute_amg_state,
        prefer_decoder=cli_args.prefer_decoder,
    )
class Annotator3d(micro_sam.sam_annotator._annotator._AnnotatorBase):
class Annotator3d(_AnnotatorBase):
    """Annotator for volumetric (3d) image data.

    Provides widgets for slice-wise, volumetric and automatic segmentation, and
    reloads the automatic segmentation state from the embedding path whenever
    the image is updated.
    """

    def _get_widgets(self):
        """Build the widgets of the 3d annotator.

        Returns:
            A dict mapping the keys "segment", "segment_nd", "autosegment",
            "commit" and "clear" to the corresponding widgets.
        """
        # The class-based widgets are instantiated first; the function-based
        # widgets are created (in key order) while the mapping is built.
        auto_seg_widget = widgets.AutoSegmentWidget(
            self._viewer, with_decoder=self._with_decoder, volumetric=True
        )
        nd_widget = widgets.SegmentNDWidget(self._viewer, tracking=False)
        return dict(
            segment=widgets.segment_slice(),
            segment_nd=nd_widget,
            autosegment=auto_seg_widget,
            commit=widgets.commit(),
            clear=widgets.clear_volume(),
        )

    def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
        """Create the 3d annotator.

        Args:
            viewer: The napari viewer the annotator is attached to.
            reset_state: Whether to reset the annotator state. Pass False when the
                state (e.g. the predictor) has already been initialized.
        """
        # Set before calling the base initializer, because `_get_widgets` reads it
        # (presumably invoked from the base class __init__ — verify there).
        self._with_decoder = AnnotatorState().decoder is not None
        super().__init__(viewer=viewer, ndim=3)

        annotator_state = AnnotatorState()
        if reset_state:
            annotator_state.reset_state()

        # Register this instance as the active annotator in the state.
        annotator_state.annotator = self

    def _update_image(self, segmentation_result=None):
        """Update the layers for the current image and reload the automatic segmentation state.

        Args:
            segmentation_result: Optional initial segmentation, forwarded to the base class.
        """
        super()._update_image(segmentation_result=segmentation_result)
        annotator_state = AnnotatorState()
        # Decoder-based instance segmentation and automatic mask generation
        # use different loaders for their cached state.
        load_state = _load_is_state if self._with_decoder else _load_amg_state
        annotator_state.amg_state = load_state(annotator_state.embedding_path)

Base class for micro_sam annotation plugins.

Implements the logic for the 2d, 3d and tracking annotators. The annotators differ in their data dimensionality and in their widgets.

Annotator3d(viewer: napari.viewer.Viewer, reset_state: bool = True)
28    def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
29        self._with_decoder = AnnotatorState().decoder is not None
30        super().__init__(viewer=viewer, ndim=3)
31
32        # Set the expected annotator class to the state.
33        state = AnnotatorState()
34
35        # Reset the state.
36        if reset_state:
37            state.reset_state()
38
39        state.annotator = self

Create the annotator GUI.

Arguments:
  • viewer: The napari viewer.
  • ndim: The number of spatial dimensions of the image data (2 or 3).
Inherited Members
PyQt5.QtWidgets.QScrollArea
alignment
ensureVisible
ensureWidgetVisible
event
eventFilter
focusNextPrevChild
resizeEvent
scrollContentsBy
setAlignment
setWidget
setWidgetResizable
sizeHint
takeWidget
viewportSizeHint
widget
widgetResizable
PyQt5.QtWidgets.QAbstractScrollArea
SizeAdjustPolicy
addScrollBarWidget
contextMenuEvent
cornerWidget
dragEnterEvent
dragLeaveEvent
dragMoveEvent
dropEvent
horizontalScrollBar
horizontalScrollBarPolicy
keyPressEvent
maximumViewportSize
minimumSizeHint
mouseDoubleClickEvent
mouseMoveEvent
mousePressEvent
mouseReleaseEvent
paintEvent
scrollBarWidgets
setCornerWidget
setHorizontalScrollBar
setHorizontalScrollBarPolicy
setSizeAdjustPolicy
setVerticalScrollBar
setVerticalScrollBarPolicy
setViewport
setViewportMargins
setupViewport
sizeAdjustPolicy
verticalScrollBar
verticalScrollBarPolicy
viewport
viewportEvent
viewportMargins
wheelEvent
AdjustIgnored
AdjustToContents
AdjustToContentsOnFirstShow
PyQt5.QtWidgets.QFrame
Shadow
Shape
StyleMask
changeEvent
drawFrame
frameRect
frameShadow
frameShape
frameStyle
frameWidth
initStyleOption
lineWidth
midLineWidth
setFrameRect
setFrameShadow
setFrameShape
setFrameStyle
setLineWidth
setMidLineWidth
Box
HLine
NoFrame
Panel
Plain
Raised
Shadow_Mask
Shape_Mask
StyledPanel
Sunken
VLine
WinPanel
PyQt5.QtWidgets.QWidget
RenderFlag
RenderFlags
acceptDrops
accessibleDescription
accessibleName
actionEvent
actions
activateWindow
addAction
addActions
adjustSize
autoFillBackground
backgroundRole
baseSize
childAt
childrenRect
childrenRegion
clearFocus
clearMask
close
closeEvent
contentsMargins
contentsRect
contextMenuPolicy
create
createWindowContainer
cursor
destroy
devType
effectiveWinId
ensurePolished
enterEvent
find
focusInEvent
focusNextChild
focusOutEvent
focusPolicy
focusPreviousChild
focusProxy
focusWidget
font
fontInfo
fontMetrics
foregroundRole
frameGeometry
frameSize
geometry
getContentsMargins
grab
grabGesture
grabKeyboard
grabMouse
grabShortcut
graphicsEffect
graphicsProxyWidget
hasFocus
hasHeightForWidth
hasMouseTracking
hasTabletTracking
height
heightForWidth
hide
hideEvent
initPainter
inputMethodEvent
inputMethodHints
inputMethodQuery
insertAction
insertActions
isActiveWindow
isAncestorOf
isEnabled
isEnabledTo
isFullScreen
isHidden
isLeftToRight
isMaximized
isMinimized
isModal
isRightToLeft
isVisible
isVisibleTo
isWindow
isWindowModified
keyReleaseEvent
keyboardGrabber
layout
layoutDirection
leaveEvent
locale
lower
mapFrom
mapFromGlobal
mapFromParent
mapTo
mapToGlobal
mapToParent
mask
maximumHeight
maximumSize
maximumWidth
metric
minimumHeight
minimumSize
minimumWidth
mouseGrabber
move
moveEvent
nativeEvent
nativeParentWidget
nextInFocusChain
normalGeometry
overrideWindowFlags
overrideWindowState
paintEngine
palette
parentWidget
pos
previousInFocusChain
raise_
rect
releaseKeyboard
releaseMouse
releaseShortcut
removeAction
render
repaint
resize
restoreGeometry
saveGeometry
screen
scroll
setAcceptDrops
setAccessibleDescription
setAccessibleName
setAttribute
setAutoFillBackground
setBackgroundRole
setBaseSize
setContentsMargins
setContextMenuPolicy
setCursor
setDisabled
setEnabled
setFixedHeight
setFixedSize
setFixedWidth
setFocus
setFocusPolicy
setFocusProxy
setFont
setForegroundRole
setGeometry
setGraphicsEffect
setHidden
setInputMethodHints
setLayout
setLayoutDirection
setLocale
setMask
setMaximumHeight
setMaximumSize
setMaximumWidth
setMinimumHeight
setMinimumSize
setMinimumWidth
setMouseTracking
setPalette
setParent
setShortcutAutoRepeat
setShortcutEnabled
setSizeIncrement
setSizePolicy
setStatusTip
setStyle
setStyleSheet
setTabOrder
setTabletTracking
setToolTip
setToolTipDuration
setUpdatesEnabled
setVisible
setWhatsThis
setWindowFilePath
setWindowFlag
setWindowFlags
setWindowIcon
setWindowIconText
setWindowModality
setWindowModified
setWindowOpacity
setWindowRole
setWindowState
setWindowTitle
sharedPainter
show
showEvent
showFullScreen
showMaximized
showMinimized
showNormal
size
sizeIncrement
sizePolicy
stackUnder
statusTip
style
styleSheet
tabletEvent
testAttribute
toolTip
toolTipDuration
underMouse
ungrabGesture
unsetCursor
unsetLayoutDirection
unsetLocale
update
updateGeometry
updateMicroFocus
updatesEnabled
visibleRegion
whatsThis
width
winId
window
windowFilePath
windowFlags
windowHandle
windowIcon
windowIconText
windowModality
windowOpacity
windowRole
windowState
windowTitle
windowType
x
y
DrawChildren
DrawWindowBackground
IgnoreMask
windowIconTextChanged
windowIconChanged
windowTitleChanged
customContextMenuRequested
PyQt5.QtCore.QObject
blockSignals
childEvent
children
connectNotify
customEvent
deleteLater
disconnect
disconnectNotify
dumpObjectInfo
dumpObjectTree
dynamicPropertyNames
findChild
findChildren
inherits
installEventFilter
isSignalConnected
isWidgetType
isWindowType
killTimer
metaObject
moveToThread
objectName
parent
property
pyqtConfigure
receivers
removeEventFilter
sender
senderSignalIndex
setObjectName
setProperty
signalsBlocked
startTimer
thread
timerEvent
tr
staticMetaObject
objectNameChanged
destroyed
PyQt5.QtGui.QPaintDevice
PaintDeviceMetric
colorCount
depth
devicePixelRatio
devicePixelRatioF
devicePixelRatioFScale
heightMM
logicalDpiX
logicalDpiY
paintingActive
physicalDpiX
physicalDpiY
widthMM
PdmDepth
PdmDevicePixelRatio
PdmDevicePixelRatioScaled
PdmDpiX
PdmDpiY
PdmHeight
PdmHeightMM
PdmNumColors
PdmPhysicalDpiX
PdmPhysicalDpiY
PdmWidth
PdmWidthMM
def annotator_3d( image: numpy.ndarray, embedding_path: Union[str, Dict[str, Any], NoneType] = None, segmentation_result: Optional[numpy.ndarray] = None, model_type: str = 'vit_b_lm', tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, return_viewer: bool = False, viewer: Optional[napari.viewer.Viewer] = None, precompute_amg_state: bool = False, checkpoint_path: Optional[str] = None, device: Union[str, torch.device, NoneType] = None, prefer_decoder: bool = True) -> Optional[napari.viewer.Viewer]:
 51def annotator_3d(
 52    image: np.ndarray,
 53    embedding_path: Optional[Union[str, util.ImageEmbeddings]] = None,
 54    segmentation_result: Optional[np.ndarray] = None,
 55    model_type: str = util._DEFAULT_MODEL,
 56    tile_shape: Optional[Tuple[int, int]] = None,
 57    halo: Optional[Tuple[int, int]] = None,
 58    return_viewer: bool = False,
 59    viewer: Optional["napari.viewer.Viewer"] = None,
 60    precompute_amg_state: bool = False,
 61    checkpoint_path: Optional[str] = None,
 62    device: Optional[Union[str, torch.device]] = None,
 63    prefer_decoder: bool = True,
 64) -> Optional["napari.viewer.Viewer"]:
 65    """Start the 3d annotation tool for a given image volume.
 66
 67    Args:
 68        image: The volumetric image data.
 69        embedding_path: Filepath where to save the embeddings
 70            or the precompted image embeddings computed by `precompute_image_embeddings`.
 71        segmentation_result: An initial segmentation to load.
 72            This can be used to correct segmentations with Segment Anything or to save and load progress.
 73            The segmentation will be loaded as the 'committed_objects' layer.
 74        model_type: The Segment Anything model to use. For details on the available models check out
 75            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
 76        tile_shape: Shape of tiles for tiled embedding prediction.
 77            If `None` then the whole image is passed to Segment Anything.
 78        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
 79        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
 80            By default, does not return the napari viewer.
 81        viewer: The viewer to which the Segment Anything functionality should be added.
 82            This enables using a pre-initialized viewer.
 83        precompute_amg_state: Whether to precompute the state for automatic mask generation.
 84            This will take more time when precomputing embeddings, but will then make
 85            automatic mask generation much faster. By default, set to 'False'.
 86        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
 87        device: The computational device to use for the SAM model.
 88            By default, automatically chooses the best available device.
 89        prefer_decoder: Whether to use decoder based instance segmentation if
 90            the model used has an additional decoder for instance segmentation.
 91            By default, set to 'True'.
 92
 93    Returns:
 94        The napari viewer, only returned if `return_viewer=True`.
 95    """
 96
 97    # Initialize the predictor state.
 98    state = AnnotatorState()
 99    state.image_shape = image.shape[:-1] if image.ndim == 4 else image.shape
100    state.initialize_predictor(
101        image, model_type=model_type, save_path=embedding_path,
102        halo=halo, tile_shape=tile_shape, ndim=3, precompute_amg_state=precompute_amg_state,
103        checkpoint_path=checkpoint_path, device=device, prefer_decoder=prefer_decoder,
104        use_cli=True,
105    )
106
107    if viewer is None:
108        viewer = napari.Viewer()
109
110    viewer.add_image(image, name="image")
111    annotator = Annotator3d(viewer, reset_state=False)
112
113    # Trigger layer update of the annotator so that layers have the correct shape.
114    # And initialize the 'committed_objects' with the segmentation result if it was given.
115    annotator._update_image(segmentation_result=segmentation_result)
116
117    # Add the annotator widget to the viewer and sync widgets.
118    viewer.window.add_dock_widget(annotator)
119    _sync_embedding_widget(
120        widget=state.widgets["embeddings"],
121        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
122        save_path=embedding_path,
123        checkpoint_path=checkpoint_path,
124        device=device,
125        tile_shape=tile_shape,
126        halo=halo
127    )
128
129    if return_viewer:
130        return viewer
131
132    napari.run()

Start the 3d annotation tool for a given image volume.

Arguments:
  • image: The volumetric image data.
  • embedding_path: Filepath where to save the embeddings or the precomputed image embeddings computed by precompute_image_embeddings.
  • segmentation_result: An initial segmentation to load. This can be used to correct segmentations with Segment Anything or to save and load progress. The segmentation will be loaded as the 'committed_objects' layer.
  • model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
  • tile_shape: Shape of tiles for tiled embedding prediction. If None then the whole image is passed to Segment Anything.
  • halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
  • return_viewer: Whether to return the napari viewer to further modify it before starting the tool. By default, does not return the napari viewer.
  • viewer: The viewer to which the Segment Anything functionality should be added. This enables using a pre-initialized viewer.
  • precompute_amg_state: Whether to precompute the state for automatic mask generation. This will take more time when precomputing embeddings, but will then make automatic mask generation much faster. By default, set to 'False'.
  • checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
  • device: The computational device to use for the SAM model. By default, automatically chooses the best available device.
  • prefer_decoder: Whether to use decoder based instance segmentation if the model used has an additional decoder for instance segmentation. By default, set to 'True'.
Returns:

The napari viewer, only returned if return_viewer=True.