micro_sam.sam_annotator.annotator_2d

  1from typing import Optional, Tuple, Union
  2
  3import napari
  4import numpy as np
  5
  6import torch
  7
  8from .. import util
  9from . import _widgets as widgets
 10from ._state import AnnotatorState
 11from ._annotator import _AnnotatorBase
 12from .util import _initialize_parser, _sync_embedding_widget
 13
 14
 15class Annotator2d(_AnnotatorBase):
 16    def _get_widgets(self):
 17        autosegment = widgets.AutoSegmentWidget(
 18            self._viewer, with_decoder=AnnotatorState().decoder is not None, volumetric=False
 19        )
 20        return {
 21            "segment": widgets.segment(),
 22            "autosegment": autosegment,
 23            "commit": widgets.commit(),
 24            "clear": widgets.clear(),
 25        }
 26
 27    def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
 28        super().__init__(viewer=viewer, ndim=2)
 29
 30        # Set the expected annotator class to the state.
 31        state = AnnotatorState()
 32
 33        # Reset the state.
 34        if reset_state:
 35            state.reset_state()
 36
 37        state.annotator = self
 38
 39
 40def annotator_2d(
 41    image: np.ndarray,
 42    embedding_path: Optional[Union[str, util.ImageEmbeddings]] = None,
 43    segmentation_result: Optional[np.ndarray] = None,
 44    model_type: str = util._DEFAULT_MODEL,
 45    tile_shape: Optional[Tuple[int, int]] = None,
 46    halo: Optional[Tuple[int, int]] = None,
 47    return_viewer: bool = False,
 48    viewer: Optional["napari.viewer.Viewer"] = None,
 49    precompute_amg_state: bool = False,
 50    checkpoint_path: Optional[str] = None,
 51    device: Optional[Union[str, torch.device]] = None,
 52    prefer_decoder: bool = True,
 53) -> Optional["napari.viewer.Viewer"]:
 54    """Start the 2d annotation tool for a given image.
 55
 56    Args:
 57        image: The image data.
 58        embedding_path: Filepath where to save the embeddings
 59            or the precompted image embeddings computed by `precompute_image_embeddings`.
 60        segmentation_result: An initial segmentation to load.
 61            This can be used to correct segmentations with Segment Anything or to save and load progress.
 62            The segmentation will be loaded as the 'committed_objects' layer.
 63        model_type: The Segment Anything model to use. For details on the available models check out
 64            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
 65        tile_shape: Shape of tiles for tiled embedding prediction.
 66            If `None` then the whole image is passed to Segment Anything.
 67        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
 68        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
 69            By default, does not return the napari viewer.
 70        viewer: The viewer to which the Segment Anything functionality should be added.
 71            This enables using a pre-initialized viewer.
 72        precompute_amg_state: Whether to precompute the state for automatic mask generation.
 73            This will take more time when precomputing embeddings, but will then make
 74            automatic mask generation much faster. By default, set to 'False'.
 75        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
 76        device: The computational device to use for the SAM model.
 77            By default, automatically chooses the best available device.
 78        prefer_decoder: Whether to use decoder based instance segmentation if
 79            the model used has an additional decoder for instance segmentation.
 80            By default, set to 'True'.
 81
 82    Returns:
 83        The napari viewer, only returned if `return_viewer=True`.
 84    """
 85
 86    state = AnnotatorState()
 87    state.image_shape = image.shape[:-1] if image.ndim == 3 else image.shape
 88
 89    state.initialize_predictor(
 90        image, model_type=model_type, save_path=embedding_path,
 91        halo=halo, tile_shape=tile_shape, precompute_amg_state=precompute_amg_state,
 92        ndim=2, checkpoint_path=checkpoint_path, device=device, prefer_decoder=prefer_decoder,
 93        skip_load=False, use_cli=True,
 94    )
 95
 96    if viewer is None:
 97        viewer = napari.Viewer()
 98
 99    viewer.add_image(image, name="image")
100    annotator = Annotator2d(viewer, reset_state=False)
101
102    # Trigger layer update of the annotator so that layers have the correct shape.
103    # And initialize the 'committed_objects' with the segmentation result if it was given.
104    annotator._update_image(segmentation_result=segmentation_result)
105
106    # Add the annotator widget to the viewer and sync widgets.
107    viewer.window.add_dock_widget(annotator)
108    _sync_embedding_widget(
109        widget=state.widgets["embeddings"],
110        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
111        save_path=embedding_path,
112        checkpoint_path=checkpoint_path,
113        device=device,
114        tile_shape=tile_shape,
115        halo=halo,
116    )
117
118    if return_viewer:
119        return viewer
120
121    napari.run()
122
123
def main():
    """@private"""
    # Parse the shared CLI options for the annotation tools.
    parser = _initialize_parser(description="Run interactive segmentation for an image.")
    args = parser.parse_args()

    image = util.load_image_data(args.input, key=args.key)
    segmentation_result = (
        None
        if args.segmentation_result is None
        else util.load_image_data(args.segmentation_result, key=args.segmentation_key)
    )

    annotator_2d(
        image,
        embedding_path=args.embedding_path,
        segmentation_result=segmentation_result,
        model_type=args.model_type,
        tile_shape=args.tile_shape,
        halo=args.halo,
        precompute_amg_state=args.precompute_amg_state,
        checkpoint_path=args.checkpoint,
        device=args.device,
        prefer_decoder=args.prefer_decoder,
    )
class Annotator2d(micro_sam.sam_annotator._annotator._AnnotatorBase):
16class Annotator2d(_AnnotatorBase):
17    def _get_widgets(self):
18        autosegment = widgets.AutoSegmentWidget(
19            self._viewer, with_decoder=AnnotatorState().decoder is not None, volumetric=False
20        )
21        return {
22            "segment": widgets.segment(),
23            "autosegment": autosegment,
24            "commit": widgets.commit(),
25            "clear": widgets.clear(),
26        }
27
28    def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
29        super().__init__(viewer=viewer, ndim=2)
30
31        # Set the expected annotator class to the state.
32        state = AnnotatorState()
33
34        # Reset the state.
35        if reset_state:
36            state.reset_state()
37
38        state.annotator = self

Base class for micro_sam annotation plugins.

Implements the logic for the 2d, 3d and tracking annotator. The annotators differ in their data dimensionality and the widgets.

Annotator2d(viewer: napari.viewer.Viewer, reset_state: bool = True)
28    def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
29        super().__init__(viewer=viewer, ndim=2)
30
31        # Set the expected annotator class to the state.
32        state = AnnotatorState()
33
34        # Reset the state.
35        if reset_state:
36            state.reset_state()
37
38        state.annotator = self

Create the annotator GUI.

Arguments:
  • viewer: The napari viewer.
  • ndim: The number of spatial dimension of the image data (2 or 3).
Inherited Members
PyQt5.QtWidgets.QScrollArea
alignment
ensureVisible
ensureWidgetVisible
event
eventFilter
focusNextPrevChild
resizeEvent
scrollContentsBy
setAlignment
setWidget
setWidgetResizable
sizeHint
takeWidget
viewportSizeHint
widget
widgetResizable
PyQt5.QtWidgets.QAbstractScrollArea
SizeAdjustPolicy
addScrollBarWidget
contextMenuEvent
cornerWidget
dragEnterEvent
dragLeaveEvent
dragMoveEvent
dropEvent
horizontalScrollBar
horizontalScrollBarPolicy
keyPressEvent
maximumViewportSize
minimumSizeHint
mouseDoubleClickEvent
mouseMoveEvent
mousePressEvent
mouseReleaseEvent
paintEvent
scrollBarWidgets
setCornerWidget
setHorizontalScrollBar
setHorizontalScrollBarPolicy
setSizeAdjustPolicy
setVerticalScrollBar
setVerticalScrollBarPolicy
setViewport
setViewportMargins
setupViewport
sizeAdjustPolicy
verticalScrollBar
verticalScrollBarPolicy
viewport
viewportEvent
viewportMargins
wheelEvent
AdjustIgnored
AdjustToContents
AdjustToContentsOnFirstShow
PyQt5.QtWidgets.QFrame
Shadow
Shape
StyleMask
changeEvent
drawFrame
frameRect
frameShadow
frameShape
frameStyle
frameWidth
initStyleOption
lineWidth
midLineWidth
setFrameRect
setFrameShadow
setFrameShape
setFrameStyle
setLineWidth
setMidLineWidth
Box
HLine
NoFrame
Panel
Plain
Raised
Shadow_Mask
Shape_Mask
StyledPanel
Sunken
VLine
WinPanel
PyQt5.QtWidgets.QWidget
RenderFlag
RenderFlags
acceptDrops
accessibleDescription
accessibleName
actionEvent
actions
activateWindow
addAction
addActions
adjustSize
autoFillBackground
backgroundRole
baseSize
childAt
childrenRect
childrenRegion
clearFocus
clearMask
close
closeEvent
contentsMargins
contentsRect
contextMenuPolicy
create
createWindowContainer
cursor
destroy
devType
effectiveWinId
ensurePolished
enterEvent
find
focusInEvent
focusNextChild
focusOutEvent
focusPolicy
focusPreviousChild
focusProxy
focusWidget
font
fontInfo
fontMetrics
foregroundRole
frameGeometry
frameSize
geometry
getContentsMargins
grab
grabGesture
grabKeyboard
grabMouse
grabShortcut
graphicsEffect
graphicsProxyWidget
hasFocus
hasHeightForWidth
hasMouseTracking
hasTabletTracking
height
heightForWidth
hide
hideEvent
initPainter
inputMethodEvent
inputMethodHints
inputMethodQuery
insertAction
insertActions
isActiveWindow
isAncestorOf
isEnabled
isEnabledTo
isFullScreen
isHidden
isLeftToRight
isMaximized
isMinimized
isModal
isRightToLeft
isVisible
isVisibleTo
isWindow
isWindowModified
keyReleaseEvent
keyboardGrabber
layout
layoutDirection
leaveEvent
locale
lower
mapFrom
mapFromGlobal
mapFromParent
mapTo
mapToGlobal
mapToParent
mask
maximumHeight
maximumSize
maximumWidth
metric
minimumHeight
minimumSize
minimumWidth
mouseGrabber
move
moveEvent
nativeEvent
nativeParentWidget
nextInFocusChain
normalGeometry
overrideWindowFlags
overrideWindowState
paintEngine
palette
parentWidget
pos
previousInFocusChain
raise_
rect
releaseKeyboard
releaseMouse
releaseShortcut
removeAction
render
repaint
resize
restoreGeometry
saveGeometry
screen
scroll
setAcceptDrops
setAccessibleDescription
setAccessibleName
setAttribute
setAutoFillBackground
setBackgroundRole
setBaseSize
setContentsMargins
setContextMenuPolicy
setCursor
setDisabled
setEnabled
setFixedHeight
setFixedSize
setFixedWidth
setFocus
setFocusPolicy
setFocusProxy
setFont
setForegroundRole
setGeometry
setGraphicsEffect
setHidden
setInputMethodHints
setLayout
setLayoutDirection
setLocale
setMask
setMaximumHeight
setMaximumSize
setMaximumWidth
setMinimumHeight
setMinimumSize
setMinimumWidth
setMouseTracking
setPalette
setParent
setShortcutAutoRepeat
setShortcutEnabled
setSizeIncrement
setSizePolicy
setStatusTip
setStyle
setStyleSheet
setTabOrder
setTabletTracking
setToolTip
setToolTipDuration
setUpdatesEnabled
setVisible
setWhatsThis
setWindowFilePath
setWindowFlag
setWindowFlags
setWindowIcon
setWindowIconText
setWindowModality
setWindowModified
setWindowOpacity
setWindowRole
setWindowState
setWindowTitle
sharedPainter
show
showEvent
showFullScreen
showMaximized
showMinimized
showNormal
size
sizeIncrement
sizePolicy
stackUnder
statusTip
style
styleSheet
tabletEvent
testAttribute
toolTip
toolTipDuration
underMouse
ungrabGesture
unsetCursor
unsetLayoutDirection
unsetLocale
update
updateGeometry
updateMicroFocus
updatesEnabled
visibleRegion
whatsThis
width
winId
window
windowFilePath
windowFlags
windowHandle
windowIcon
windowIconText
windowModality
windowOpacity
windowRole
windowState
windowTitle
windowType
x
y
DrawChildren
DrawWindowBackground
IgnoreMask
windowIconTextChanged
windowIconChanged
windowTitleChanged
customContextMenuRequested
PyQt5.QtCore.QObject
blockSignals
childEvent
children
connectNotify
customEvent
deleteLater
disconnect
disconnectNotify
dumpObjectInfo
dumpObjectTree
dynamicPropertyNames
findChild
findChildren
inherits
installEventFilter
isSignalConnected
isWidgetType
isWindowType
killTimer
metaObject
moveToThread
objectName
parent
property
pyqtConfigure
receivers
removeEventFilter
sender
senderSignalIndex
setObjectName
setProperty
signalsBlocked
startTimer
thread
timerEvent
tr
staticMetaObject
objectNameChanged
destroyed
PyQt5.QtGui.QPaintDevice
PaintDeviceMetric
colorCount
depth
devicePixelRatio
devicePixelRatioF
devicePixelRatioFScale
heightMM
logicalDpiX
logicalDpiY
paintingActive
physicalDpiX
physicalDpiY
widthMM
PdmDepth
PdmDevicePixelRatio
PdmDevicePixelRatioScaled
PdmDpiX
PdmDpiY
PdmHeight
PdmHeightMM
PdmNumColors
PdmPhysicalDpiX
PdmPhysicalDpiY
PdmWidth
PdmWidthMM
def annotator_2d( image: numpy.ndarray, embedding_path: Union[str, Dict[str, Any], NoneType] = None, segmentation_result: Optional[numpy.ndarray] = None, model_type: str = 'vit_b_lm', tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, return_viewer: bool = False, viewer: Optional[napari.viewer.Viewer] = None, precompute_amg_state: bool = False, checkpoint_path: Optional[str] = None, device: Union[str, torch.device, NoneType] = None, prefer_decoder: bool = True) -> Optional[napari.viewer.Viewer]:
 41def annotator_2d(
 42    image: np.ndarray,
 43    embedding_path: Optional[Union[str, util.ImageEmbeddings]] = None,
 44    segmentation_result: Optional[np.ndarray] = None,
 45    model_type: str = util._DEFAULT_MODEL,
 46    tile_shape: Optional[Tuple[int, int]] = None,
 47    halo: Optional[Tuple[int, int]] = None,
 48    return_viewer: bool = False,
 49    viewer: Optional["napari.viewer.Viewer"] = None,
 50    precompute_amg_state: bool = False,
 51    checkpoint_path: Optional[str] = None,
 52    device: Optional[Union[str, torch.device]] = None,
 53    prefer_decoder: bool = True,
 54) -> Optional["napari.viewer.Viewer"]:
 55    """Start the 2d annotation tool for a given image.
 56
 57    Args:
 58        image: The image data.
 59        embedding_path: Filepath where to save the embeddings
 60            or the precomputed image embeddings computed by `precompute_image_embeddings`.
 61        segmentation_result: An initial segmentation to load.
 62            This can be used to correct segmentations with Segment Anything or to save and load progress.
 63            The segmentation will be loaded as the 'committed_objects' layer.
 64        model_type: The Segment Anything model to use. For details on the available models check out
 65            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
 66        tile_shape: Shape of tiles for tiled embedding prediction.
 67            If `None` then the whole image is passed to Segment Anything.
 68        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
 69        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
 70            By default, does not return the napari viewer.
 71        viewer: The viewer to which the Segment Anything functionality should be added.
 72            This enables using a pre-initialized viewer.
 73        precompute_amg_state: Whether to precompute the state for automatic mask generation.
 74            This will take more time when precomputing embeddings, but will then make
 75            automatic mask generation much faster. By default, set to 'False'.
 76        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
 77        device: The computational device to use for the SAM model.
 78            By default, automatically chooses the best available device.
 79        prefer_decoder: Whether to use decoder based instance segmentation if
 80            the model used has an additional decoder for instance segmentation.
 81            By default, set to 'True'.
 82
 83    Returns:
 84        The napari viewer, only returned if `return_viewer=True`.
 85    """
 86
 87    state = AnnotatorState()
 88    state.image_shape = image.shape[:-1] if image.ndim == 3 else image.shape
 89
 90    state.initialize_predictor(
 91        image, model_type=model_type, save_path=embedding_path,
 92        halo=halo, tile_shape=tile_shape, precompute_amg_state=precompute_amg_state,
 93        ndim=2, checkpoint_path=checkpoint_path, device=device, prefer_decoder=prefer_decoder,
 94        skip_load=False, use_cli=True,
 95    )
 96
 97    if viewer is None:
 98        viewer = napari.Viewer()
 99
100    viewer.add_image(image, name="image")
101    annotator = Annotator2d(viewer, reset_state=False)
102
103    # Trigger layer update of the annotator so that layers have the correct shape.
104    # And initialize the 'committed_objects' with the segmentation result if it was given.
105    annotator._update_image(segmentation_result=segmentation_result)
106
107    # Add the annotator widget to the viewer and sync widgets.
108    viewer.window.add_dock_widget(annotator)
109    _sync_embedding_widget(
110        widget=state.widgets["embeddings"],
111        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
112        save_path=embedding_path,
113        checkpoint_path=checkpoint_path,
114        device=device,
115        tile_shape=tile_shape,
116        halo=halo,
117    )
118
119    if return_viewer:
120        return viewer
121
122    napari.run()

Start the 2d annotation tool for a given image.

Arguments:
  • image: The image data.
  • embedding_path: Filepath where to save the embeddings or the precomputed image embeddings computed by precompute_image_embeddings.
  • segmentation_result: An initial segmentation to load. This can be used to correct segmentations with Segment Anything or to save and load progress. The segmentation will be loaded as the 'committed_objects' layer.
  • model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
  • tile_shape: Shape of tiles for tiled embedding prediction. If None then the whole image is passed to Segment Anything.
  • halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
  • return_viewer: Whether to return the napari viewer to further modify it before starting the tool. By default, does not return the napari viewer.
  • viewer: The viewer to which the Segment Anything functionality should be added. This enables using a pre-initialized viewer.
  • precompute_amg_state: Whether to precompute the state for automatic mask generation. This will take more time when precomputing embeddings, but will then make automatic mask generation much faster. By default, set to 'False'.
  • checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
  • device: The computational device to use for the SAM model. By default, automatically chooses the best available device.
  • prefer_decoder: Whether to use decoder based instance segmentation if the model used has an additional decoder for instance segmentation. By default, set to 'True'.
Returns:

The napari viewer, only returned if return_viewer=True.