micro_sam.sam_annotator.annotator_tracking

  1from typing import Optional, Tuple, Union
  2
  3import napari
  4import numpy as np
  5
  6import torch
  7
  8from magicgui.widgets import ComboBox, Container
  9
 10from ._annotator import _AnnotatorBase
 11from ._state import AnnotatorState
 12from . import util as vutil
 13from ._tooltips import get_tooltip
 14from . import _widgets as widgets
 15from .. import util
 16
 17# Cyan (track) and Magenta (division)
 18STATE_COLOR_CYCLE = ["#00FFFF", "#FF00FF", ]
 19"""@private"""
 20
 21
 22# This solution is a bit hacky, so I won't move it to _widgets.py yet.
 23def create_tracking_menu(points_layer, box_layer, states, track_ids):
 24    """@private"""
 25    state = AnnotatorState()
 26
 27    state_menu = ComboBox(label="track_state", choices=states, tooltip=get_tooltip("annotator_tracking", "track_state"))
 28    track_id_menu = ComboBox(label="track_id", choices=list(map(str, track_ids)),
 29                             tooltip=get_tooltip("annotator_tracking", "track_id"))
 30    tracking_widget = Container(widgets=[state_menu, track_id_menu])
 31
 32    def update_state(event):
 33        new_state = str(points_layer.current_properties["state"][0])
 34        if new_state != state_menu.value:
 35            state_menu.value = new_state
 36
 37    def update_track_id(event):
 38        new_id = str(points_layer.current_properties["track_id"][0])
 39        if new_id != track_id_menu.value:
 40            track_id_menu.value = new_id
 41            state.current_track_id = int(new_id)
 42
 43    # def update_state_boxes(event):
 44    #     new_state = str(box_layer.current_properties["state"][0])
 45    #     if new_state != state_menu.value:
 46    #         state_menu.value = new_state
 47
 48    def update_track_id_boxes(event):
 49        new_id = str(box_layer.current_properties["track_id"][0])
 50        if new_id != track_id_menu.value:
 51            track_id_menu.value = new_id
 52            state.current_track_id = int(new_id)
 53
 54    points_layer.events.current_properties.connect(update_state)
 55    points_layer.events.current_properties.connect(update_track_id)
 56    # box_layer.events.current_properties.connect(update_state_boxes)
 57    box_layer.events.current_properties.connect(update_track_id_boxes)
 58
 59    def state_changed(new_state):
 60        current_properties = points_layer.current_properties
 61        current_properties["state"] = np.array([new_state])
 62        points_layer.current_properties = current_properties
 63        points_layer.refresh_colors()
 64
 65    def track_id_changed(new_track_id):
 66        current_properties = points_layer.current_properties
 67        current_properties["track_id"] = np.array([new_track_id])
 68        # Note: this fails with a key error after committing a lineage with multiple tracks.
 69        # I think this does not cause any further errors, so we just skip this.
 70        try:
 71            points_layer.current_properties = current_properties
 72        except KeyError:
 73            pass
 74        state.current_track_id = int(new_track_id)
 75
 76    # def state_changed_boxes(new_state):
 77    #     current_properties = box_layer.current_properties
 78    #     current_properties["state"] = np.array([new_state])
 79    #     box_layer.current_properties = current_properties
 80    #     box_layer.refresh_colors()
 81
 82    def track_id_changed_boxes(new_track_id):
 83        current_properties = box_layer.current_properties
 84        current_properties["track_id"] = np.array([new_track_id])
 85        box_layer.current_properties = current_properties
 86        state.current_track_id = int(new_track_id)
 87
 88    state_menu.changed.connect(state_changed)
 89    track_id_menu.changed.connect(track_id_changed)
 90    # state_menu.changed.connect(state_changed_boxes)
 91    track_id_menu.changed.connect(track_id_changed_boxes)
 92
 93    state_menu.set_choice("track")
 94    return tracking_widget
 95
 96
 97class AnnotatorTracking(_AnnotatorBase):
 98
 99    # The tracking annotator needs different settings for the prompt layers
100    # to support the additional tracking state.
101    # That's why we over-ride this function.
102    def _create_layers(self):
103        self._point_labels = ["positive", "negative"]
104        self._track_state_labels = ["track", "division"]
105
106        self._point_prompt_layer = self._viewer.add_points(
107            name="point_prompts",
108            property_choices={
109                "label": self._point_labels,
110                "state": self._track_state_labels,
111                "track_id": ["1"],  # we use string to avoid pandas warning
112            },
113            border_color="label",
114            border_color_cycle=vutil.LABEL_COLOR_CYCLE,
115            symbol="o",
116            face_color="state",
117            face_color_cycle=STATE_COLOR_CYCLE,
118            border_width=0.4,
119            size=12,
120            ndim=self._ndim,
121        )
122        self._point_prompt_layer.border_color_mode = "cycle"
123        self._point_prompt_layer.face_color_mode = "cycle"
124
125        # Using the box layer to set divisions currently doesn't work.
126        # That's why some of the code below is commented out.
127        self._box_prompt_layer = self._viewer.add_shapes(
128            shape_type="rectangle",
129            edge_width=4,
130            ndim=self._ndim,
131            face_color="transparent",
132            name="prompts",
133            edge_color="green",
134            property_choices={"track_id": ["1"]},
135            # property_choces={"track_id": ["1"], "state": self._track_state_labels},
136            # edge_color_cycle=STATE_COLOR_CYCLE,
137        )
138        # self._box_prompt_layer.edge_color_mode = "cycle"
139
140        # Add the label layers for the current object, the automatic segmentation and the committed segmentation.
141        dummy_data = np.zeros(self._shape, dtype="uint32")
142        self._viewer.add_labels(data=dummy_data, name="current_object")
143        self._viewer.add_labels(data=dummy_data, name="auto_segmentation")
144        self._viewer.add_labels(data=dummy_data, name="committed_objects")
145        # Randomize colors so it is easy to see when object committed.
146        self._viewer.layers["committed_objects"].new_colormap()
147
148    def _get_widgets(self):
149        state = AnnotatorState()
150        # Create the tracking state menu.
151        self._tracking_widget = create_tracking_menu(
152            self._point_prompt_layer, self._box_prompt_layer,
153            states=self._track_state_labels, track_ids=list(state.lineage.keys()),
154        )
155        segment_nd = widgets.SegmentNDWidget(self._viewer, tracking=True)
156        return {
157            "tracking": self._tracking_widget,
158            "segment": widgets.segment_frame(),
159            "segment_nd": segment_nd,
160            "commit": widgets.commit_track(),
161            "clear": widgets.clear_track(),
162        }
163
164    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
165        # Initialize the state for tracking.
166        self._init_track_state()
167        super().__init__(viewer=viewer, ndim=3)
168        # Go to t=0.
169        self._viewer.dims.current_step = (0, 0, 0) + tuple(sh // 2 for sh in self._shape[1:])
170
171    def _init_track_state(self):
172        state = AnnotatorState()
173        state.current_track_id = 1
174        state.lineage = {1: []}
175        state.committed_lineages = []
176
177    def _update_image(self):
178        super()._update_image()
179        self._init_track_state()
180
181
def annotator_tracking(
    image: np.ndarray,
    embedding_path: Optional[str] = None,
    # tracking_result: Optional[str] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    return_viewer: bool = False,
    viewer: Optional["napari.viewer.Viewer"] = None,
    checkpoint_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
) -> Optional["napari.viewer.Viewer"]:
    """Start the tracking annotation tool for a given timeseries.

    Args:
        image: The image data. If it has 4 dimensions, the last axis is excluded from the
            image shape used by the annotator (presumably a channel axis).
        embedding_path: Filepath for saving the precomputed embeddings.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        viewer: The viewer to which the SegmentAnything functionality should be added.
            This enables using a pre-initialized viewer.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        device: The computational device to use for the SAM model.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
        Otherwise the napari event loop is started and `None` is returned once it exits.
    """

    # TODO update this to match the new annotator design
    # Initialize the predictor state.
    state = AnnotatorState()
    state.initialize_predictor(
        image, model_type=model_type, save_path=embedding_path,
        halo=halo, tile_shape=tile_shape, prefer_decoder=False,
        ndim=3, checkpoint_path=checkpoint_path, device=device,
    )
    # A 4d input has its last axis dropped from the shape the annotator works with.
    state.image_shape = image.shape[:-1] if image.ndim == 4 else image.shape

    if viewer is None:
        viewer = napari.Viewer()

    viewer.add_image(image, name="image")
    annotator = AnnotatorTracking(viewer)

    # Trigger layer update of the annotator so that layers have the correct shape.
    annotator._update_image()

    # Add the annotator widget to the viewer and sync widgets.
    viewer.window.add_dock_widget(annotator)
    vutil._sync_embedding_widget(
        state.widgets["embeddings"], model_type,
        save_path=embedding_path, checkpoint_path=checkpoint_path,
        device=device, tile_shape=tile_shape, halo=halo
    )

    if return_viewer:
        return viewer

    napari.run()
245
246
def main():
    """@private"""
    parser = vutil._initialize_parser(
        description="Run interactive segmentation for an image volume.",
        with_segmentation_result=False,
        with_instance_segmentation=False,
    )
    # NOTE: loading a precomputed tracking result (a "-t/--tracking_result" path plus a
    # "-tk/--tracking_key" key, to initialize the committed tracks) is not supported yet,
    # because the lineage would have to be deserialized as well.

    cli_args = parser.parse_args()
    image = util.load_image_data(cli_args.input, key=cli_args.key)

    annotator_options = dict(
        embedding_path=cli_args.embedding_path,
        model_type=cli_args.model_type,
        tile_shape=cli_args.tile_shape,
        halo=cli_args.halo,
        checkpoint_path=cli_args.checkpoint,
        device=cli_args.device,
    )
    annotator_tracking(image, **annotator_options)
class AnnotatorTracking(micro_sam.sam_annotator._annotator._AnnotatorBase):
 98class AnnotatorTracking(_AnnotatorBase):
 99
100    # The tracking annotator needs different settings for the prompt layers
101    # to support the additional tracking state.
102    # That's why we over-ride this function.
103    def _create_layers(self):
104        self._point_labels = ["positive", "negative"]
105        self._track_state_labels = ["track", "division"]
106
107        self._point_prompt_layer = self._viewer.add_points(
108            name="point_prompts",
109            property_choices={
110                "label": self._point_labels,
111                "state": self._track_state_labels,
112                "track_id": ["1"],  # we use string to avoid pandas warning
113            },
114            border_color="label",
115            border_color_cycle=vutil.LABEL_COLOR_CYCLE,
116            symbol="o",
117            face_color="state",
118            face_color_cycle=STATE_COLOR_CYCLE,
119            border_width=0.4,
120            size=12,
121            ndim=self._ndim,
122        )
123        self._point_prompt_layer.border_color_mode = "cycle"
124        self._point_prompt_layer.face_color_mode = "cycle"
125
126        # Using the box layer to set divisions currently doesn't work.
127        # That's why some of the code below is commented out.
128        self._box_prompt_layer = self._viewer.add_shapes(
129            shape_type="rectangle",
130            edge_width=4,
131            ndim=self._ndim,
132            face_color="transparent",
133            name="prompts",
134            edge_color="green",
135            property_choices={"track_id": ["1"]},
136            # property_choces={"track_id": ["1"], "state": self._track_state_labels},
137            # edge_color_cycle=STATE_COLOR_CYCLE,
138        )
139        # self._box_prompt_layer.edge_color_mode = "cycle"
140
141        # Add the label layers for the current object, the automatic segmentation and the committed segmentation.
142        dummy_data = np.zeros(self._shape, dtype="uint32")
143        self._viewer.add_labels(data=dummy_data, name="current_object")
144        self._viewer.add_labels(data=dummy_data, name="auto_segmentation")
145        self._viewer.add_labels(data=dummy_data, name="committed_objects")
146        # Randomize colors so it is easy to see when object committed.
147        self._viewer.layers["committed_objects"].new_colormap()
148
149    def _get_widgets(self):
150        state = AnnotatorState()
151        # Create the tracking state menu.
152        self._tracking_widget = create_tracking_menu(
153            self._point_prompt_layer, self._box_prompt_layer,
154            states=self._track_state_labels, track_ids=list(state.lineage.keys()),
155        )
156        segment_nd = widgets.SegmentNDWidget(self._viewer, tracking=True)
157        return {
158            "tracking": self._tracking_widget,
159            "segment": widgets.segment_frame(),
160            "segment_nd": segment_nd,
161            "commit": widgets.commit_track(),
162            "clear": widgets.clear_track(),
163        }
164
165    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
166        # Initialize the state for tracking.
167        self._init_track_state()
168        super().__init__(viewer=viewer, ndim=3)
169        # Go to t=0.
170        self._viewer.dims.current_step = (0, 0, 0) + tuple(sh // 2 for sh in self._shape[1:])
171
172    def _init_track_state(self):
173        state = AnnotatorState()
174        state.current_track_id = 1
175        state.lineage = {1: []}
176        state.committed_lineages = []
177
178    def _update_image(self):
179        super()._update_image()
180        self._init_track_state()

Base class for micro_sam annotation plugins.

Implements the logic for the 2d, 3d and tracking annotators. The annotators differ in their data dimensionality and in their widgets.

AnnotatorTracking(viewer: napari.viewer.Viewer)
165    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
166        # Initialize the state for tracking.
167        self._init_track_state()
168        super().__init__(viewer=viewer, ndim=3)
169        # Go to t=0.
170        self._viewer.dims.current_step = (0, 0, 0) + tuple(sh // 2 for sh in self._shape[1:])

Create the annotator GUI.

Arguments:
  • viewer: The napari viewer.
  • ndim: The number of spatial dimensions of the image data (2 or 3).
Inherited Members
PyQt5.QtWidgets.QScrollArea
alignment
ensureVisible
ensureWidgetVisible
event
eventFilter
focusNextPrevChild
resizeEvent
scrollContentsBy
setAlignment
setWidget
setWidgetResizable
sizeHint
takeWidget
viewportSizeHint
widget
widgetResizable
PyQt5.QtWidgets.QAbstractScrollArea
SizeAdjustPolicy
addScrollBarWidget
contextMenuEvent
cornerWidget
dragEnterEvent
dragLeaveEvent
dragMoveEvent
dropEvent
horizontalScrollBar
horizontalScrollBarPolicy
keyPressEvent
maximumViewportSize
minimumSizeHint
mouseDoubleClickEvent
mouseMoveEvent
mousePressEvent
mouseReleaseEvent
paintEvent
scrollBarWidgets
setCornerWidget
setHorizontalScrollBar
setHorizontalScrollBarPolicy
setSizeAdjustPolicy
setVerticalScrollBar
setVerticalScrollBarPolicy
setViewport
setViewportMargins
setupViewport
sizeAdjustPolicy
verticalScrollBar
verticalScrollBarPolicy
viewport
viewportEvent
viewportMargins
wheelEvent
AdjustIgnored
AdjustToContents
AdjustToContentsOnFirstShow
PyQt5.QtWidgets.QFrame
Shadow
Shape
StyleMask
changeEvent
drawFrame
frameRect
frameShadow
frameShape
frameStyle
frameWidth
initStyleOption
lineWidth
midLineWidth
setFrameRect
setFrameShadow
setFrameShape
setFrameStyle
setLineWidth
setMidLineWidth
Box
HLine
NoFrame
Panel
Plain
Raised
Shadow_Mask
Shape_Mask
StyledPanel
Sunken
VLine
WinPanel
PyQt5.QtWidgets.QWidget
RenderFlag
RenderFlags
acceptDrops
accessibleDescription
accessibleName
actionEvent
actions
activateWindow
addAction
addActions
adjustSize
autoFillBackground
backgroundRole
baseSize
childAt
childrenRect
childrenRegion
clearFocus
clearMask
close
closeEvent
contentsMargins
contentsRect
contextMenuPolicy
create
createWindowContainer
cursor
destroy
devType
effectiveWinId
ensurePolished
enterEvent
find
focusInEvent
focusNextChild
focusOutEvent
focusPolicy
focusPreviousChild
focusProxy
focusWidget
font
fontInfo
fontMetrics
foregroundRole
frameGeometry
frameSize
geometry
getContentsMargins
grab
grabGesture
grabKeyboard
grabMouse
grabShortcut
graphicsEffect
graphicsProxyWidget
hasFocus
hasHeightForWidth
hasMouseTracking
hasTabletTracking
height
heightForWidth
hide
hideEvent
initPainter
inputMethodEvent
inputMethodHints
inputMethodQuery
insertAction
insertActions
isActiveWindow
isAncestorOf
isEnabled
isEnabledTo
isFullScreen
isHidden
isLeftToRight
isMaximized
isMinimized
isModal
isRightToLeft
isVisible
isVisibleTo
isWindow
isWindowModified
keyReleaseEvent
keyboardGrabber
layout
layoutDirection
leaveEvent
locale
lower
mapFrom
mapFromGlobal
mapFromParent
mapTo
mapToGlobal
mapToParent
mask
maximumHeight
maximumSize
maximumWidth
metric
minimumHeight
minimumSize
minimumWidth
mouseGrabber
move
moveEvent
nativeEvent
nativeParentWidget
nextInFocusChain
normalGeometry
overrideWindowFlags
overrideWindowState
paintEngine
palette
parentWidget
pos
previousInFocusChain
raise_
rect
releaseKeyboard
releaseMouse
releaseShortcut
removeAction
render
repaint
resize
restoreGeometry
saveGeometry
screen
scroll
setAcceptDrops
setAccessibleDescription
setAccessibleName
setAttribute
setAutoFillBackground
setBackgroundRole
setBaseSize
setContentsMargins
setContextMenuPolicy
setCursor
setDisabled
setEnabled
setFixedHeight
setFixedSize
setFixedWidth
setFocus
setFocusPolicy
setFocusProxy
setFont
setForegroundRole
setGeometry
setGraphicsEffect
setHidden
setInputMethodHints
setLayout
setLayoutDirection
setLocale
setMask
setMaximumHeight
setMaximumSize
setMaximumWidth
setMinimumHeight
setMinimumSize
setMinimumWidth
setMouseTracking
setPalette
setParent
setShortcutAutoRepeat
setShortcutEnabled
setSizeIncrement
setSizePolicy
setStatusTip
setStyle
setStyleSheet
setTabOrder
setTabletTracking
setToolTip
setToolTipDuration
setUpdatesEnabled
setVisible
setWhatsThis
setWindowFilePath
setWindowFlag
setWindowFlags
setWindowIcon
setWindowIconText
setWindowModality
setWindowModified
setWindowOpacity
setWindowRole
setWindowState
setWindowTitle
sharedPainter
show
showEvent
showFullScreen
showMaximized
showMinimized
showNormal
size
sizeIncrement
sizePolicy
stackUnder
statusTip
style
styleSheet
tabletEvent
testAttribute
toolTip
toolTipDuration
underMouse
ungrabGesture
unsetCursor
unsetLayoutDirection
unsetLocale
update
updateGeometry
updateMicroFocus
updatesEnabled
visibleRegion
whatsThis
width
winId
window
windowFilePath
windowFlags
windowHandle
windowIcon
windowIconText
windowModality
windowOpacity
windowRole
windowState
windowTitle
windowType
x
y
DrawChildren
DrawWindowBackground
IgnoreMask
windowIconTextChanged
windowIconChanged
windowTitleChanged
customContextMenuRequested
PyQt5.QtCore.QObject
blockSignals
childEvent
children
connectNotify
customEvent
deleteLater
disconnect
disconnectNotify
dumpObjectInfo
dumpObjectTree
dynamicPropertyNames
findChild
findChildren
inherits
installEventFilter
isSignalConnected
isWidgetType
isWindowType
killTimer
metaObject
moveToThread
objectName
parent
property
pyqtConfigure
receivers
removeEventFilter
sender
senderSignalIndex
setObjectName
setProperty
signalsBlocked
startTimer
thread
timerEvent
tr
staticMetaObject
objectNameChanged
destroyed
PyQt5.QtGui.QPaintDevice
PaintDeviceMetric
colorCount
depth
devicePixelRatio
devicePixelRatioF
devicePixelRatioFScale
heightMM
logicalDpiX
logicalDpiY
paintingActive
physicalDpiX
physicalDpiY
widthMM
PdmDepth
PdmDevicePixelRatio
PdmDevicePixelRatioScaled
PdmDpiX
PdmDpiY
PdmHeight
PdmHeightMM
PdmNumColors
PdmPhysicalDpiX
PdmPhysicalDpiY
PdmWidth
PdmWidthMM
def annotator_tracking( image: numpy.ndarray, embedding_path: Optional[str] = None, model_type: str = 'vit_l', tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, return_viewer: bool = False, viewer: Optional[napari.viewer.Viewer] = None, checkpoint_path: Optional[str] = None, device: Union[str, torch.device, NoneType] = None) -> Optional[napari.viewer.Viewer]:
def annotator_tracking(
    image: np.ndarray,
    embedding_path: Optional[str] = None,
    # tracking_result: Optional[str] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    return_viewer: bool = False,
    viewer: Optional["napari.viewer.Viewer"] = None,
    checkpoint_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
) -> Optional["napari.viewer.Viewer"]:
    """Start the tracking annotation tool for a given timeseries.

    Args:
        image: The image data. If it has 4 dimensions, the last axis is excluded from the
            image shape used by the annotator (presumably a channel axis).
        embedding_path: Filepath for saving the precomputed embeddings.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        viewer: The viewer to which the SegmentAnything functionality should be added.
            This enables using a pre-initialized viewer.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        device: The computational device to use for the SAM model.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
        Otherwise the napari event loop is started and `None` is returned once it exits.
    """

    # TODO update this to match the new annotator design
    # Initialize the predictor state.
    state = AnnotatorState()
    state.initialize_predictor(
        image, model_type=model_type, save_path=embedding_path,
        halo=halo, tile_shape=tile_shape, prefer_decoder=False,
        ndim=3, checkpoint_path=checkpoint_path, device=device,
    )
    # A 4d input has its last axis dropped from the shape the annotator works with.
    state.image_shape = image.shape[:-1] if image.ndim == 4 else image.shape

    if viewer is None:
        viewer = napari.Viewer()

    viewer.add_image(image, name="image")
    annotator = AnnotatorTracking(viewer)

    # Trigger layer update of the annotator so that layers have the correct shape.
    annotator._update_image()

    # Add the annotator widget to the viewer and sync widgets.
    viewer.window.add_dock_widget(annotator)
    vutil._sync_embedding_widget(
        state.widgets["embeddings"], model_type,
        save_path=embedding_path, checkpoint_path=checkpoint_path,
        device=device, tile_shape=tile_shape, halo=halo
    )

    if return_viewer:
        return viewer

    napari.run()

Start the tracking annotation tool for a given timeseries.

Arguments:
  • image: The image data.
  • embedding_path: Filepath for saving the precomputed embeddings.
  • model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
  • tile_shape: Shape of tiles for tiled embedding prediction. If None then the whole image is passed to Segment Anything.
  • halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
  • return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
  • viewer: The viewer to which the SegmentAnything functionality should be added. This enables using a pre-initialized viewer.
  • checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
  • device: The computational device to use for the SAM model.
Returns:

The napari viewer, only returned if return_viewer=True.