micro_sam.sam_annotator.annotator_3d

  1from typing import Optional, Tuple, Union
  2
  3import napari
  4import numpy as np
  5
  6import torch
  7
  8from ._annotator import _AnnotatorBase
  9from ._state import AnnotatorState
 10from . import _widgets as widgets
 11from .util import _initialize_parser, _sync_embedding_widget, _load_amg_state, _load_is_state
 12from .. import util
 13
 14
 15class Annotator3d(_AnnotatorBase):
 16    def _get_widgets(self):
 17        autosegment = widgets.AutoSegmentWidget(self._viewer, with_decoder=self._with_decoder, volumetric=True)
 18        segment_nd = widgets.SegmentNDWidget(self._viewer, tracking=False)
 19        return {
 20            "segment": widgets.segment_slice(),
 21            "segment_nd": segment_nd,
 22            "autosegment": autosegment,
 23            "commit": widgets.commit(),
 24            "clear": widgets.clear_volume(),
 25        }
 26
 27    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
 28        self._with_decoder = AnnotatorState().decoder is not None
 29        super().__init__(viewer=viewer, ndim=3)
 30
 31    def _update_image(self, segmentation_result=None):
 32        super()._update_image(segmentation_result=segmentation_result)
 33        # Load the amg state from the embedding path.
 34        state = AnnotatorState()
 35        if self._with_decoder:
 36            state.amg_state = _load_is_state(state.embedding_path)
 37        else:
 38            state.amg_state = _load_amg_state(state.embedding_path)
 39
 40
 41def annotator_3d(
 42    image: np.ndarray,
 43    embedding_path: Optional[str] = None,
 44    segmentation_result: Optional[np.ndarray] = None,
 45    model_type: str = util._DEFAULT_MODEL,
 46    tile_shape: Optional[Tuple[int, int]] = None,
 47    halo: Optional[Tuple[int, int]] = None,
 48    return_viewer: bool = False,
 49    viewer: Optional["napari.viewer.Viewer"] = None,
 50    precompute_amg_state: bool = False,
 51    checkpoint_path: Optional[str] = None,
 52    device: Optional[Union[str, torch.device]] = None,
 53    prefer_decoder: bool = True,
 54) -> Optional["napari.viewer.Viewer"]:
 55    """Start the 3d annotation tool for a given image volume.
 56
 57    Args:
 58        image: The volumetric image data.
 59        embedding_path: Filepath for saving the precomputed embeddings.
 60        segmentation_result: An initial segmentation to load.
 61            This can be used to correct segmentations with Segment Anything or to save and load progress.
 62            The segmentation will be loaded as the 'committed_objects' layer.
 63        model_type: The Segment Anything model to use. For details on the available models check out
 64            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
 65        tile_shape: Shape of tiles for tiled embedding prediction.
 66            If `None` then the whole image is passed to Segment Anything.
 67        halo: Shape of the overlap between tiles, which is needed to segment objects on tile boarders.
 68        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
 69        viewer: The viewer to which the SegmentAnything functionality should be added.
 70            This enables using a pre-initialized viewer.
 71        precompute_amg_state: Whether to precompute the state for automatic mask generation.
 72            This will take more time when precomputing embeddings, but will then make
 73            automatic mask generation much faster.
 74        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
 75        device: The computational device to use for the SAM model.
 76        prefer_decoder: Whether to use decoder based instance segmentation if
 77            the model used has an additional decoder for instance segmentation.
 78
 79    Returns:
 80        The napari viewer, only returned if `return_viewer=True`.
 81    """
 82
 83    # Initialize the predictor state.
 84    state = AnnotatorState()
 85    state.image_shape = image.shape[:-1] if image.ndim == 4 else image.shape
 86    state.initialize_predictor(
 87        image, model_type=model_type, save_path=embedding_path,
 88        halo=halo, tile_shape=tile_shape, ndim=3, precompute_amg_state=precompute_amg_state,
 89        checkpoint_path=checkpoint_path, device=device, prefer_decoder=prefer_decoder,
 90    )
 91
 92    if viewer is None:
 93        viewer = napari.Viewer()
 94
 95    viewer.add_image(image, name="image")
 96    annotator = Annotator3d(viewer)
 97
 98    # Trigger layer update of the annotator so that layers have the correct shape.
 99    # And initialize the 'committed_objects' with the segmentation result if it was given.
100    annotator._update_image(segmentation_result=segmentation_result)
101
102    # Add the annotator widget to the viewer and sync widgets.
103    viewer.window.add_dock_widget(annotator)
104    _sync_embedding_widget(
105        state.widgets["embeddings"], model_type,
106        save_path=embedding_path, checkpoint_path=checkpoint_path,
107        device=device, tile_shape=tile_shape, halo=halo
108    )
109
110    if return_viewer:
111        return viewer
112
113    napari.run()
114
115
def main():
    """@private"""
    parser = _initialize_parser(description="Run interactive segmentation for an image volume.")
    args = parser.parse_args()

    image = util.load_image_data(args.input, key=args.key)
    segmentation_result = (
        None if args.segmentation_result is None
        else util.load_image_data(args.segmentation_result, key=args.segmentation_key)
    )

    # Launch the 3d annotator with all settings taken from the command line.
    annotator_3d(
        image,
        embedding_path=args.embedding_path,
        segmentation_result=segmentation_result,
        model_type=args.model_type,
        tile_shape=args.tile_shape,
        halo=args.halo,
        checkpoint_path=args.checkpoint,
        device=args.device,
        precompute_amg_state=args.precompute_amg_state,
        prefer_decoder=args.prefer_decoder,
    )
class Annotator3d(micro_sam.sam_annotator._annotator._AnnotatorBase):
16class Annotator3d(_AnnotatorBase):
17    def _get_widgets(self):
18        autosegment = widgets.AutoSegmentWidget(self._viewer, with_decoder=self._with_decoder, volumetric=True)
19        segment_nd = widgets.SegmentNDWidget(self._viewer, tracking=False)
20        return {
21            "segment": widgets.segment_slice(),
22            "segment_nd": segment_nd,
23            "autosegment": autosegment,
24            "commit": widgets.commit(),
25            "clear": widgets.clear_volume(),
26        }
27
28    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
29        self._with_decoder = AnnotatorState().decoder is not None
30        super().__init__(viewer=viewer, ndim=3)
31
32    def _update_image(self, segmentation_result=None):
33        super()._update_image(segmentation_result=segmentation_result)
34        # Load the amg state from the embedding path.
35        state = AnnotatorState()
36        if self._with_decoder:
37            state.amg_state = _load_is_state(state.embedding_path)
38        else:
39            state.amg_state = _load_amg_state(state.embedding_path)

Base class for micro_sam annotation plugins.

Implements the logic for the 2d, 3d and tracking annotator. The annotators differ in their data dimensionality and the widgets.

Annotator3d(viewer: napari.viewer.Viewer)
28    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
29        self._with_decoder = AnnotatorState().decoder is not None
30        super().__init__(viewer=viewer, ndim=3)

Create the annotator GUI.

Arguments:
  • viewer: The napari viewer.
  • ndim: The number of spatial dimension of the image data (2 or 3).
Inherited Members
PyQt5.QtWidgets.QScrollArea
alignment
ensureVisible
ensureWidgetVisible
event
eventFilter
focusNextPrevChild
resizeEvent
scrollContentsBy
setAlignment
setWidget
setWidgetResizable
sizeHint
takeWidget
viewportSizeHint
widget
widgetResizable
PyQt5.QtWidgets.QAbstractScrollArea
SizeAdjustPolicy
addScrollBarWidget
contextMenuEvent
cornerWidget
dragEnterEvent
dragLeaveEvent
dragMoveEvent
dropEvent
horizontalScrollBar
horizontalScrollBarPolicy
keyPressEvent
maximumViewportSize
minimumSizeHint
mouseDoubleClickEvent
mouseMoveEvent
mousePressEvent
mouseReleaseEvent
paintEvent
scrollBarWidgets
setCornerWidget
setHorizontalScrollBar
setHorizontalScrollBarPolicy
setSizeAdjustPolicy
setVerticalScrollBar
setVerticalScrollBarPolicy
setViewport
setViewportMargins
setupViewport
sizeAdjustPolicy
verticalScrollBar
verticalScrollBarPolicy
viewport
viewportEvent
viewportMargins
wheelEvent
AdjustIgnored
AdjustToContents
AdjustToContentsOnFirstShow
PyQt5.QtWidgets.QFrame
Shadow
Shape
StyleMask
changeEvent
drawFrame
frameRect
frameShadow
frameShape
frameStyle
frameWidth
initStyleOption
lineWidth
midLineWidth
setFrameRect
setFrameShadow
setFrameShape
setFrameStyle
setLineWidth
setMidLineWidth
Box
HLine
NoFrame
Panel
Plain
Raised
Shadow_Mask
Shape_Mask
StyledPanel
Sunken
VLine
WinPanel
PyQt5.QtWidgets.QWidget
RenderFlag
RenderFlags
acceptDrops
accessibleDescription
accessibleName
actionEvent
actions
activateWindow
addAction
addActions
adjustSize
autoFillBackground
backgroundRole
baseSize
childAt
childrenRect
childrenRegion
clearFocus
clearMask
close
closeEvent
contentsMargins
contentsRect
contextMenuPolicy
create
createWindowContainer
cursor
destroy
devType
effectiveWinId
ensurePolished
enterEvent
find
focusInEvent
focusNextChild
focusOutEvent
focusPolicy
focusPreviousChild
focusProxy
focusWidget
font
fontInfo
fontMetrics
foregroundRole
frameGeometry
frameSize
geometry
getContentsMargins
grab
grabGesture
grabKeyboard
grabMouse
grabShortcut
graphicsEffect
graphicsProxyWidget
hasFocus
hasHeightForWidth
hasMouseTracking
hasTabletTracking
height
heightForWidth
hide
hideEvent
initPainter
inputMethodEvent
inputMethodHints
inputMethodQuery
insertAction
insertActions
isActiveWindow
isAncestorOf
isEnabled
isEnabledTo
isFullScreen
isHidden
isLeftToRight
isMaximized
isMinimized
isModal
isRightToLeft
isVisible
isVisibleTo
isWindow
isWindowModified
keyReleaseEvent
keyboardGrabber
layout
layoutDirection
leaveEvent
locale
lower
mapFrom
mapFromGlobal
mapFromParent
mapTo
mapToGlobal
mapToParent
mask
maximumHeight
maximumSize
maximumWidth
metric
minimumHeight
minimumSize
minimumWidth
mouseGrabber
move
moveEvent
nativeEvent
nativeParentWidget
nextInFocusChain
normalGeometry
overrideWindowFlags
overrideWindowState
paintEngine
palette
parentWidget
pos
previousInFocusChain
raise_
rect
releaseKeyboard
releaseMouse
releaseShortcut
removeAction
render
repaint
resize
restoreGeometry
saveGeometry
screen
scroll
setAcceptDrops
setAccessibleDescription
setAccessibleName
setAttribute
setAutoFillBackground
setBackgroundRole
setBaseSize
setContentsMargins
setContextMenuPolicy
setCursor
setDisabled
setEnabled
setFixedHeight
setFixedSize
setFixedWidth
setFocus
setFocusPolicy
setFocusProxy
setFont
setForegroundRole
setGeometry
setGraphicsEffect
setHidden
setInputMethodHints
setLayout
setLayoutDirection
setLocale
setMask
setMaximumHeight
setMaximumSize
setMaximumWidth
setMinimumHeight
setMinimumSize
setMinimumWidth
setMouseTracking
setPalette
setParent
setShortcutAutoRepeat
setShortcutEnabled
setSizeIncrement
setSizePolicy
setStatusTip
setStyle
setStyleSheet
setTabOrder
setTabletTracking
setToolTip
setToolTipDuration
setUpdatesEnabled
setVisible
setWhatsThis
setWindowFilePath
setWindowFlag
setWindowFlags
setWindowIcon
setWindowIconText
setWindowModality
setWindowModified
setWindowOpacity
setWindowRole
setWindowState
setWindowTitle
sharedPainter
show
showEvent
showFullScreen
showMaximized
showMinimized
showNormal
size
sizeIncrement
sizePolicy
stackUnder
statusTip
style
styleSheet
tabletEvent
testAttribute
toolTip
toolTipDuration
underMouse
ungrabGesture
unsetCursor
unsetLayoutDirection
unsetLocale
update
updateGeometry
updateMicroFocus
updatesEnabled
visibleRegion
whatsThis
width
winId
window
windowFilePath
windowFlags
windowHandle
windowIcon
windowIconText
windowModality
windowOpacity
windowRole
windowState
windowTitle
windowType
x
y
DrawChildren
DrawWindowBackground
IgnoreMask
windowIconTextChanged
windowIconChanged
windowTitleChanged
customContextMenuRequested
PyQt5.QtCore.QObject
blockSignals
childEvent
children
connectNotify
customEvent
deleteLater
disconnect
disconnectNotify
dumpObjectInfo
dumpObjectTree
dynamicPropertyNames
findChild
findChildren
inherits
installEventFilter
isSignalConnected
isWidgetType
isWindowType
killTimer
metaObject
moveToThread
objectName
parent
property
pyqtConfigure
receivers
removeEventFilter
sender
senderSignalIndex
setObjectName
setProperty
signalsBlocked
startTimer
thread
timerEvent
tr
staticMetaObject
objectNameChanged
destroyed
PyQt5.QtGui.QPaintDevice
PaintDeviceMetric
colorCount
depth
devicePixelRatio
devicePixelRatioF
devicePixelRatioFScale
heightMM
logicalDpiX
logicalDpiY
paintingActive
physicalDpiX
physicalDpiY
widthMM
PdmDepth
PdmDevicePixelRatio
PdmDevicePixelRatioScaled
PdmDpiX
PdmDpiY
PdmHeight
PdmHeightMM
PdmNumColors
PdmPhysicalDpiX
PdmPhysicalDpiY
PdmWidth
PdmWidthMM
def annotator_3d( image: numpy.ndarray, embedding_path: Optional[str] = None, segmentation_result: Optional[numpy.ndarray] = None, model_type: str = 'vit_l', tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, return_viewer: bool = False, viewer: Optional[napari.viewer.Viewer] = None, precompute_amg_state: bool = False, checkpoint_path: Optional[str] = None, device: Union[str, torch.device, NoneType] = None, prefer_decoder: bool = True) -> Optional[napari.viewer.Viewer]:
 42def annotator_3d(
 43    image: np.ndarray,
 44    embedding_path: Optional[str] = None,
 45    segmentation_result: Optional[np.ndarray] = None,
 46    model_type: str = util._DEFAULT_MODEL,
 47    tile_shape: Optional[Tuple[int, int]] = None,
 48    halo: Optional[Tuple[int, int]] = None,
 49    return_viewer: bool = False,
 50    viewer: Optional["napari.viewer.Viewer"] = None,
 51    precompute_amg_state: bool = False,
 52    checkpoint_path: Optional[str] = None,
 53    device: Optional[Union[str, torch.device]] = None,
 54    prefer_decoder: bool = True,
 55) -> Optional["napari.viewer.Viewer"]:
 56    """Start the 3d annotation tool for a given image volume.
 57
 58    Args:
 59        image: The volumetric image data.
 60        embedding_path: Filepath for saving the precomputed embeddings.
 61        segmentation_result: An initial segmentation to load.
 62            This can be used to correct segmentations with Segment Anything or to save and load progress.
 63            The segmentation will be loaded as the 'committed_objects' layer.
 64        model_type: The Segment Anything model to use. For details on the available models check out
 65            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
 66        tile_shape: Shape of tiles for tiled embedding prediction.
 67            If `None` then the whole image is passed to Segment Anything.
 68        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
 69        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
 70        viewer: The viewer to which the SegmentAnything functionality should be added.
 71            This enables using a pre-initialized viewer.
 72        precompute_amg_state: Whether to precompute the state for automatic mask generation.
 73            This will take more time when precomputing embeddings, but will then make
 74            automatic mask generation much faster.
 75        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
 76        device: The computational device to use for the SAM model.
 77        prefer_decoder: Whether to use decoder based instance segmentation if
 78            the model used has an additional decoder for instance segmentation.
 79
 80    Returns:
 81        The napari viewer, only returned if `return_viewer=True`.
 82    """
 83
 84    # Initialize the predictor state.
 85    state = AnnotatorState()
 86    state.image_shape = image.shape[:-1] if image.ndim == 4 else image.shape
 87    state.initialize_predictor(
 88        image, model_type=model_type, save_path=embedding_path,
 89        halo=halo, tile_shape=tile_shape, ndim=3, precompute_amg_state=precompute_amg_state,
 90        checkpoint_path=checkpoint_path, device=device, prefer_decoder=prefer_decoder,
 91    )
 92
 93    if viewer is None:
 94        viewer = napari.Viewer()
 95
 96    viewer.add_image(image, name="image")
 97    annotator = Annotator3d(viewer)
 98
 99    # Trigger layer update of the annotator so that layers have the correct shape.
100    # And initialize the 'committed_objects' with the segmentation result if it was given.
101    annotator._update_image(segmentation_result=segmentation_result)
102
103    # Add the annotator widget to the viewer and sync widgets.
104    viewer.window.add_dock_widget(annotator)
105    _sync_embedding_widget(
106        state.widgets["embeddings"], model_type,
107        save_path=embedding_path, checkpoint_path=checkpoint_path,
108        device=device, tile_shape=tile_shape, halo=halo
109    )
110
111    if return_viewer:
112        return viewer
113
114    napari.run()

Start the 3d annotation tool for a given image volume.

Arguments:
  • image: The volumetric image data.
  • embedding_path: Filepath for saving the precomputed embeddings.
  • segmentation_result: An initial segmentation to load. This can be used to correct segmentations with Segment Anything or to save and load progress. The segmentation will be loaded as the 'committed_objects' layer.
  • model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
  • tile_shape: Shape of tiles for tiled embedding prediction. If None then the whole image is passed to Segment Anything.
  • halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
  • return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
  • viewer: The viewer to which the SegmentAnything functionality should be added. This enables using a pre-initialized viewer.
  • precompute_amg_state: Whether to precompute the state for automatic mask generation. This will take more time when precomputing embeddings, but will then make automatic mask generation much faster.
  • checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
  • device: The computational device to use for the SAM model.
  • prefer_decoder: Whether to use decoder based instance segmentation if the model used has an additional decoder for instance segmentation.
Returns:

The napari viewer, only returned if return_viewer=True.