micro_sam.sam_annotator.annotator_2d

  1from typing import Optional, Tuple, Union
  2
  3import napari
  4import numpy as np
  5
  6import torch
  7
  8from .. import util
  9from . import _widgets as widgets
 10from ._state import AnnotatorState
 11from ._annotator import _AnnotatorBase
 12from .util import _initialize_parser, _sync_embedding_widget
 13
 14
 15class Annotator2d(_AnnotatorBase):
 16    def _get_widgets(self):
 17        autosegment = widgets.AutoSegmentWidget(
 18            self._viewer, with_decoder=AnnotatorState().decoder is not None, volumetric=False
 19        )
 20        return {
 21            "segment": widgets.segment(),
 22            "autosegment": autosegment,
 23            "commit": widgets.commit(),
 24            "clear": widgets.clear(),
 25        }
 26
 27    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
 28        super().__init__(viewer=viewer, ndim=2)
 29
 30
 31def annotator_2d(
 32    image: np.ndarray,
 33    embedding_path: Optional[Union[str, util.ImageEmbeddings]] = None,
 34    segmentation_result: Optional[np.ndarray] = None,
 35    model_type: str = util._DEFAULT_MODEL,
 36    tile_shape: Optional[Tuple[int, int]] = None,
 37    halo: Optional[Tuple[int, int]] = None,
 38    return_viewer: bool = False,
 39    viewer: Optional["napari.viewer.Viewer"] = None,
 40    precompute_amg_state: bool = False,
 41    checkpoint_path: Optional[str] = None,
 42    device: Optional[Union[str, torch.device]] = None,
 43    prefer_decoder: bool = True,
 44) -> Optional["napari.viewer.Viewer"]:
 45    """Start the 2d annotation tool for a given image.
 46
 47    Args:
 48        image: The image data.
 49        embedding_path: Filepath where to save the embeddings
 50            or the precompted image embeddings computed by `precompute_image_embeddings`.
 51        segmentation_result: An initial segmentation to load.
 52            This can be used to correct segmentations with Segment Anything or to save and load progress.
 53            The segmentation will be loaded as the 'committed_objects' layer.
 54        model_type: The Segment Anything model to use. For details on the available models check out
 55            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
 56        tile_shape: Shape of tiles for tiled embedding prediction.
 57            If `None` then the whole image is passed to Segment Anything.
 58        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
 59        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
 60        viewer: The viewer to which the Segment Anything functionality should be added.
 61            This enables using a pre-initialized viewer.
 62        precompute_amg_state: Whether to precompute the state for automatic mask generation.
 63            This will take more time when precomputing embeddings, but will then make
 64            automatic mask generation much faster.
 65        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
 66        device: The computational device to use for the SAM model.
 67        prefer_decoder: Whether to use decoder based instance segmentation if
 68            the model used has an additional decoder for instance segmentation.
 69
 70    Returns:
 71        The napari viewer, only returned if `return_viewer=True`.
 72    """
 73
 74    state = AnnotatorState()
 75    state.image_shape = image.shape[:-1] if image.ndim == 3 else image.shape
 76
 77    state.initialize_predictor(
 78        image, model_type=model_type, save_path=embedding_path,
 79        halo=halo, tile_shape=tile_shape, precompute_amg_state=precompute_amg_state,
 80        ndim=2, checkpoint_path=checkpoint_path, device=device, prefer_decoder=prefer_decoder,
 81        skip_load=False,
 82    )
 83
 84    if viewer is None:
 85        viewer = napari.Viewer()
 86
 87    viewer.add_image(image, name="image")
 88    annotator = Annotator2d(viewer)
 89
 90    # Trigger layer update of the annotator so that layers have the correct shape.
 91    # And initialize the 'committed_objects' with the segmentation result if it was given.
 92    annotator._update_image(segmentation_result=segmentation_result)
 93
 94    # Add the annotator widget to the viewer and sync widgets.
 95    viewer.window.add_dock_widget(annotator)
 96    _sync_embedding_widget(
 97        widget=state.widgets["embeddings"],
 98        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
 99        save_path=embedding_path,
100        checkpoint_path=checkpoint_path,
101        device=device,
102        tile_shape=tile_shape,
103        halo=halo,
104    )
105
106    if return_viewer:
107        return viewer
108
109    napari.run()
110
111
def main():
    """@private"""
    parser = _initialize_parser(description="Run interactive segmentation for an image.")
    args = parser.parse_args()

    image = util.load_image_data(args.input, key=args.key)

    # Load an initial segmentation to correct / continue from, if one was passed.
    segmentation_result = (
        None
        if args.segmentation_result is None
        else util.load_image_data(args.segmentation_result, key=args.segmentation_key)
    )

    annotator_2d(
        image,
        embedding_path=args.embedding_path,
        segmentation_result=segmentation_result,
        model_type=args.model_type,
        tile_shape=args.tile_shape,
        halo=args.halo,
        precompute_amg_state=args.precompute_amg_state,
        checkpoint_path=args.checkpoint,
        device=args.device,
        prefer_decoder=args.prefer_decoder,
    )
class Annotator2d(micro_sam.sam_annotator._annotator._AnnotatorBase):
16class Annotator2d(_AnnotatorBase):
17    def _get_widgets(self):
18        autosegment = widgets.AutoSegmentWidget(
19            self._viewer, with_decoder=AnnotatorState().decoder is not None, volumetric=False
20        )
21        return {
22            "segment": widgets.segment(),
23            "autosegment": autosegment,
24            "commit": widgets.commit(),
25            "clear": widgets.clear(),
26        }
27
28    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
29        super().__init__(viewer=viewer, ndim=2)

Base class for micro_sam annotation plugins.

Implements the logic for the 2d, 3d and tracking annotator. The annotators differ in their data dimensionality and the widgets.

Annotator2d(viewer: napari.viewer.Viewer)
28    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
29        super().__init__(viewer=viewer, ndim=2)

Create the annotator GUI.

Arguments:
  • viewer: The napari viewer.
  • ndim: The number of spatial dimension of the image data (2 or 3).
Inherited Members
PyQt5.QtWidgets.QScrollArea
alignment
ensureVisible
ensureWidgetVisible
event
eventFilter
focusNextPrevChild
resizeEvent
scrollContentsBy
setAlignment
setWidget
setWidgetResizable
sizeHint
takeWidget
viewportSizeHint
widget
widgetResizable
PyQt5.QtWidgets.QAbstractScrollArea
SizeAdjustPolicy
addScrollBarWidget
contextMenuEvent
cornerWidget
dragEnterEvent
dragLeaveEvent
dragMoveEvent
dropEvent
horizontalScrollBar
horizontalScrollBarPolicy
keyPressEvent
maximumViewportSize
minimumSizeHint
mouseDoubleClickEvent
mouseMoveEvent
mousePressEvent
mouseReleaseEvent
paintEvent
scrollBarWidgets
setCornerWidget
setHorizontalScrollBar
setHorizontalScrollBarPolicy
setSizeAdjustPolicy
setVerticalScrollBar
setVerticalScrollBarPolicy
setViewport
setViewportMargins
setupViewport
sizeAdjustPolicy
verticalScrollBar
verticalScrollBarPolicy
viewport
viewportEvent
viewportMargins
wheelEvent
AdjustIgnored
AdjustToContents
AdjustToContentsOnFirstShow
PyQt5.QtWidgets.QFrame
Shadow
Shape
StyleMask
changeEvent
drawFrame
frameRect
frameShadow
frameShape
frameStyle
frameWidth
initStyleOption
lineWidth
midLineWidth
setFrameRect
setFrameShadow
setFrameShape
setFrameStyle
setLineWidth
setMidLineWidth
Box
HLine
NoFrame
Panel
Plain
Raised
Shadow_Mask
Shape_Mask
StyledPanel
Sunken
VLine
WinPanel
PyQt5.QtWidgets.QWidget
RenderFlag
RenderFlags
acceptDrops
accessibleDescription
accessibleName
actionEvent
actions
activateWindow
addAction
addActions
adjustSize
autoFillBackground
backgroundRole
baseSize
childAt
childrenRect
childrenRegion
clearFocus
clearMask
close
closeEvent
contentsMargins
contentsRect
contextMenuPolicy
create
createWindowContainer
cursor
destroy
devType
effectiveWinId
ensurePolished
enterEvent
find
focusInEvent
focusNextChild
focusOutEvent
focusPolicy
focusPreviousChild
focusProxy
focusWidget
font
fontInfo
fontMetrics
foregroundRole
frameGeometry
frameSize
geometry
getContentsMargins
grab
grabGesture
grabKeyboard
grabMouse
grabShortcut
graphicsEffect
graphicsProxyWidget
hasFocus
hasHeightForWidth
hasMouseTracking
hasTabletTracking
height
heightForWidth
hide
hideEvent
initPainter
inputMethodEvent
inputMethodHints
inputMethodQuery
insertAction
insertActions
isActiveWindow
isAncestorOf
isEnabled
isEnabledTo
isFullScreen
isHidden
isLeftToRight
isMaximized
isMinimized
isModal
isRightToLeft
isVisible
isVisibleTo
isWindow
isWindowModified
keyReleaseEvent
keyboardGrabber
layout
layoutDirection
leaveEvent
locale
lower
mapFrom
mapFromGlobal
mapFromParent
mapTo
mapToGlobal
mapToParent
mask
maximumHeight
maximumSize
maximumWidth
metric
minimumHeight
minimumSize
minimumWidth
mouseGrabber
move
moveEvent
nativeEvent
nativeParentWidget
nextInFocusChain
normalGeometry
overrideWindowFlags
overrideWindowState
paintEngine
palette
parentWidget
pos
previousInFocusChain
raise_
rect
releaseKeyboard
releaseMouse
releaseShortcut
removeAction
render
repaint
resize
restoreGeometry
saveGeometry
screen
scroll
setAcceptDrops
setAccessibleDescription
setAccessibleName
setAttribute
setAutoFillBackground
setBackgroundRole
setBaseSize
setContentsMargins
setContextMenuPolicy
setCursor
setDisabled
setEnabled
setFixedHeight
setFixedSize
setFixedWidth
setFocus
setFocusPolicy
setFocusProxy
setFont
setForegroundRole
setGeometry
setGraphicsEffect
setHidden
setInputMethodHints
setLayout
setLayoutDirection
setLocale
setMask
setMaximumHeight
setMaximumSize
setMaximumWidth
setMinimumHeight
setMinimumSize
setMinimumWidth
setMouseTracking
setPalette
setParent
setShortcutAutoRepeat
setShortcutEnabled
setSizeIncrement
setSizePolicy
setStatusTip
setStyle
setStyleSheet
setTabOrder
setTabletTracking
setToolTip
setToolTipDuration
setUpdatesEnabled
setVisible
setWhatsThis
setWindowFilePath
setWindowFlag
setWindowFlags
setWindowIcon
setWindowIconText
setWindowModality
setWindowModified
setWindowOpacity
setWindowRole
setWindowState
setWindowTitle
sharedPainter
show
showEvent
showFullScreen
showMaximized
showMinimized
showNormal
size
sizeIncrement
sizePolicy
stackUnder
statusTip
style
styleSheet
tabletEvent
testAttribute
toolTip
toolTipDuration
underMouse
ungrabGesture
unsetCursor
unsetLayoutDirection
unsetLocale
update
updateGeometry
updateMicroFocus
updatesEnabled
visibleRegion
whatsThis
width
winId
window
windowFilePath
windowFlags
windowHandle
windowIcon
windowIconText
windowModality
windowOpacity
windowRole
windowState
windowTitle
windowType
x
y
DrawChildren
DrawWindowBackground
IgnoreMask
windowIconTextChanged
windowIconChanged
windowTitleChanged
customContextMenuRequested
PyQt5.QtCore.QObject
blockSignals
childEvent
children
connectNotify
customEvent
deleteLater
disconnect
disconnectNotify
dumpObjectInfo
dumpObjectTree
dynamicPropertyNames
findChild
findChildren
inherits
installEventFilter
isSignalConnected
isWidgetType
isWindowType
killTimer
metaObject
moveToThread
objectName
parent
property
pyqtConfigure
receivers
removeEventFilter
sender
senderSignalIndex
setObjectName
setProperty
signalsBlocked
startTimer
thread
timerEvent
tr
staticMetaObject
objectNameChanged
destroyed
PyQt5.QtGui.QPaintDevice
PaintDeviceMetric
colorCount
depth
devicePixelRatio
devicePixelRatioF
devicePixelRatioFScale
heightMM
logicalDpiX
logicalDpiY
paintingActive
physicalDpiX
physicalDpiY
widthMM
PdmDepth
PdmDevicePixelRatio
PdmDevicePixelRatioScaled
PdmDpiX
PdmDpiY
PdmHeight
PdmHeightMM
PdmNumColors
PdmPhysicalDpiX
PdmPhysicalDpiY
PdmWidth
PdmWidthMM
def annotator_2d( image: numpy.ndarray, embedding_path: Union[str, Dict[str, Any], NoneType] = None, segmentation_result: Optional[numpy.ndarray] = None, model_type: str = 'vit_b_lm', tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, return_viewer: bool = False, viewer: Optional[napari.viewer.Viewer] = None, precompute_amg_state: bool = False, checkpoint_path: Optional[str] = None, device: Union[str, torch.device, NoneType] = None, prefer_decoder: bool = True) -> Optional[napari.viewer.Viewer]:
 32def annotator_2d(
 33    image: np.ndarray,
 34    embedding_path: Optional[Union[str, util.ImageEmbeddings]] = None,
 35    segmentation_result: Optional[np.ndarray] = None,
 36    model_type: str = util._DEFAULT_MODEL,
 37    tile_shape: Optional[Tuple[int, int]] = None,
 38    halo: Optional[Tuple[int, int]] = None,
 39    return_viewer: bool = False,
 40    viewer: Optional["napari.viewer.Viewer"] = None,
 41    precompute_amg_state: bool = False,
 42    checkpoint_path: Optional[str] = None,
 43    device: Optional[Union[str, torch.device]] = None,
 44    prefer_decoder: bool = True,
 45) -> Optional["napari.viewer.Viewer"]:
 46    """Start the 2d annotation tool for a given image.
 47
 48    Args:
 49        image: The image data.
 50        embedding_path: Filepath where to save the embeddings
 51            or the precomputed image embeddings computed by `precompute_image_embeddings`.
 52        segmentation_result: An initial segmentation to load.
 53            This can be used to correct segmentations with Segment Anything or to save and load progress.
 54            The segmentation will be loaded as the 'committed_objects' layer.
 55        model_type: The Segment Anything model to use. For details on the available models check out
 56            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
 57        tile_shape: Shape of tiles for tiled embedding prediction.
 58            If `None` then the whole image is passed to Segment Anything.
 59        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
 60        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
 61        viewer: The viewer to which the Segment Anything functionality should be added.
 62            This enables using a pre-initialized viewer.
 63        precompute_amg_state: Whether to precompute the state for automatic mask generation.
 64            This will take more time when precomputing embeddings, but will then make
 65            automatic mask generation much faster.
 66        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
 67        device: The computational device to use for the SAM model.
 68        prefer_decoder: Whether to use decoder based instance segmentation if
 69            the model used has an additional decoder for instance segmentation.
 70
 71    Returns:
 72        The napari viewer, only returned if `return_viewer=True`.
 73    """
 74
 75    state = AnnotatorState()
 76    state.image_shape = image.shape[:-1] if image.ndim == 3 else image.shape
 77
 78    state.initialize_predictor(
 79        image, model_type=model_type, save_path=embedding_path,
 80        halo=halo, tile_shape=tile_shape, precompute_amg_state=precompute_amg_state,
 81        ndim=2, checkpoint_path=checkpoint_path, device=device, prefer_decoder=prefer_decoder,
 82        skip_load=False,
 83    )
 84
 85    if viewer is None:
 86        viewer = napari.Viewer()
 87
 88    viewer.add_image(image, name="image")
 89    annotator = Annotator2d(viewer)
 90
 91    # Trigger layer update of the annotator so that layers have the correct shape.
 92    # And initialize the 'committed_objects' with the segmentation result if it was given.
 93    annotator._update_image(segmentation_result=segmentation_result)
 94
 95    # Add the annotator widget to the viewer and sync widgets.
 96    viewer.window.add_dock_widget(annotator)
 97    _sync_embedding_widget(
 98        widget=state.widgets["embeddings"],
 99        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
100        save_path=embedding_path,
101        checkpoint_path=checkpoint_path,
102        device=device,
103        tile_shape=tile_shape,
104        halo=halo,
105    )
106
107    if return_viewer:
108        return viewer
109
110    napari.run()

Start the 2d annotation tool for a given image.

Arguments:
  • image: The image data.
  • embedding_path: Filepath where to save the embeddings or the precomputed image embeddings computed by precompute_image_embeddings.
  • segmentation_result: An initial segmentation to load. This can be used to correct segmentations with Segment Anything or to save and load progress. The segmentation will be loaded as the 'committed_objects' layer.
  • model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
  • tile_shape: Shape of tiles for tiled embedding prediction. If None then the whole image is passed to Segment Anything.
  • halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
  • return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
  • viewer: The viewer to which the Segment Anything functionality should be added. This enables using a pre-initialized viewer.
  • precompute_amg_state: Whether to precompute the state for automatic mask generation. This will take more time when precomputing embeddings, but will then make automatic mask generation much faster.
  • checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
  • device: The computational device to use for the SAM model.
  • prefer_decoder: Whether to use decoder based instance segmentation if the model used has an additional decoder for instance segmentation.
Returns:

The napari viewer, only returned if return_viewer=True.