micro_sam.sam_annotator.annotator_tracking

  1from typing import Optional, Tuple, Union
  2
  3import napari
  4import numpy as np
  5
  6import torch
  7
  8from magicgui.widgets import ComboBox, Container
  9
 10from .. import util
 11from . import util as vutil
 12from . import _widgets as widgets
 13from ._tooltips import get_tooltip
 14from ._state import AnnotatorState
 15from ._annotator import _AnnotatorBase
 16
 17
 18# Cyan (track) and Magenta (division)
 19STATE_COLOR_CYCLE = ["#00FFFF", "#FF00FF", ]
 20"""@private"""
 21
 22
 23# This solution is a bit hacky, so I won't move it to _widgets.py yet.
 24def create_tracking_menu(points_layer, box_layer, states, track_ids):
 25    """@private"""
 26    state = AnnotatorState()
 27
 28    state_menu = ComboBox(label="track_state", choices=states, tooltip=get_tooltip("annotator_tracking", "track_state"))
 29    track_id_menu = ComboBox(
 30        label="track_id", choices=list(map(str, track_ids)), tooltip=get_tooltip("annotator_tracking", "track_id")
 31    )
 32    tracking_widget = Container(widgets=[state_menu, track_id_menu])
 33
 34    def update_state(event):
 35        new_state = str(points_layer.current_properties["state"][0])
 36        if new_state != state_menu.value:
 37            state_menu.value = new_state
 38
 39    def update_track_id(event):
 40        new_id = str(points_layer.current_properties["track_id"][0])
 41        if new_id != track_id_menu.value:
 42            track_id_menu.value = new_id
 43            state.current_track_id = int(new_id)
 44
 45    # def update_state_boxes(event):
 46    #     new_state = str(box_layer.current_properties["state"][0])
 47    #     if new_state != state_menu.value:
 48    #         state_menu.value = new_state
 49
 50    def update_track_id_boxes(event):
 51        new_id = str(box_layer.current_properties["track_id"][0])
 52        if new_id != track_id_menu.value:
 53            track_id_menu.value = new_id
 54            state.current_track_id = int(new_id)
 55
 56    points_layer.events.current_properties.connect(update_state)
 57    points_layer.events.current_properties.connect(update_track_id)
 58    # box_layer.events.current_properties.connect(update_state_boxes)
 59    box_layer.events.current_properties.connect(update_track_id_boxes)
 60
 61    def state_changed(new_state):
 62        current_properties = points_layer.current_properties
 63        current_properties["state"] = np.array([new_state])
 64        points_layer.current_properties = current_properties
 65        points_layer.refresh_colors()
 66
 67    def track_id_changed(new_track_id):
 68        current_properties = points_layer.current_properties
 69        current_properties["track_id"] = np.array([new_track_id])
 70        # Note: this fails with a key error after committing a lineage with multiple tracks.
 71        # I think this does not cause any further errors, so we just skip this.
 72        try:
 73            points_layer.current_properties = current_properties
 74        except KeyError:
 75            pass
 76        state.current_track_id = int(new_track_id)
 77
 78    # def state_changed_boxes(new_state):
 79    #     current_properties = box_layer.current_properties
 80    #     current_properties["state"] = np.array([new_state])
 81    #     box_layer.current_properties = current_properties
 82    #     box_layer.refresh_colors()
 83
 84    def track_id_changed_boxes(new_track_id):
 85        current_properties = box_layer.current_properties
 86        current_properties["track_id"] = np.array([new_track_id])
 87        box_layer.current_properties = current_properties
 88        state.current_track_id = int(new_track_id)
 89
 90    state_menu.changed.connect(state_changed)
 91    track_id_menu.changed.connect(track_id_changed)
 92    # state_menu.changed.connect(state_changed_boxes)
 93    track_id_menu.changed.connect(track_id_changed_boxes)
 94
 95    state_menu.set_choice("track")
 96    return tracking_widget
 97
 98
 99class AnnotatorTracking(_AnnotatorBase):
100
101    # The tracking annotator needs different settings for the prompt layers
102    # to support the additional tracking state.
103    # That's why we over-ride this function.
104    def _create_layers(self):
105        self._point_labels = ["positive", "negative"]
106        self._track_state_labels = ["track", "division"]
107
108        self._point_prompt_layer = self._viewer.add_points(
109            name="point_prompts",
110            property_choices={
111                "label": self._point_labels,
112                "state": self._track_state_labels,
113                "track_id": ["1"],  # we use string to avoid pandas warning
114            },
115            border_color="label",
116            border_color_cycle=vutil.LABEL_COLOR_CYCLE,
117            symbol="o",
118            face_color="state",
119            face_color_cycle=STATE_COLOR_CYCLE,
120            border_width=0.4,
121            size=12,
122            ndim=self._ndim,
123        )
124        self._point_prompt_layer.border_color_mode = "cycle"
125        self._point_prompt_layer.face_color_mode = "cycle"
126
127        # Using the box layer to set divisions currently doesn't work.
128        # That's why some of the code below is commented out.
129        self._box_prompt_layer = self._viewer.add_shapes(
130            shape_type="rectangle",
131            edge_width=4,
132            ndim=self._ndim,
133            face_color="transparent",
134            name="prompts",
135            edge_color="green",
136            property_choices={"track_id": ["1"]},
137            # property_choices={"track_id": ["1"], "state": self._track_state_labels},
138            # edge_color_cycle=STATE_COLOR_CYCLE,
139        )
140        # self._box_prompt_layer.edge_color_mode = "cycle"
141
142        # Add the label layers for the current object, the automatic segmentation and the committed segmentation.
143        dummy_data = np.zeros(self._shape, dtype="uint32")
144        self._viewer.add_labels(data=dummy_data, name="current_object")
145        self._viewer.add_labels(data=dummy_data, name="auto_segmentation")
146        self._viewer.add_labels(data=dummy_data, name="committed_objects")
147        # Randomize colors so it is easy to see when object committed.
148        self._viewer.layers["committed_objects"].new_colormap()
149
150    def _get_widgets(self):
151        state = AnnotatorState()
152        # Create the tracking state menu.
153        self._tracking_widget = create_tracking_menu(
154            self._point_prompt_layer, self._box_prompt_layer,
155            states=self._track_state_labels, track_ids=list(state.lineage.keys()),
156        )
157        segment_nd = widgets.SegmentNDWidget(self._viewer, tracking=True)
158        return {
159            "tracking": self._tracking_widget,
160            "segment": widgets.segment_frame(),
161            "segment_nd": segment_nd,
162            "commit": widgets.commit_track(),
163            "clear": widgets.clear_track(),
164        }
165
166    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
167        # Initialize the state for tracking.
168        self._init_track_state()
169        super().__init__(viewer=viewer, ndim=3)
170        # Go to t=0.
171        self._viewer.dims.current_step = (0, 0, 0) + tuple(sh // 2 for sh in self._shape[1:])
172
173    def _init_track_state(self):
174        state = AnnotatorState()
175        state.current_track_id = 1
176        state.lineage = {1: []}
177        state.committed_lineages = []
178
179    def _update_image(self):
180        super()._update_image()
181        self._init_track_state()
182
183
184def annotator_tracking(
185    image: np.ndarray,
186    embedding_path: Optional[str] = None,
187    # tracking_result: Optional[str] = None,
188    model_type: str = util._DEFAULT_MODEL,
189    tile_shape: Optional[Tuple[int, int]] = None,
190    halo: Optional[Tuple[int, int]] = None,
191    return_viewer: bool = False,
192    viewer: Optional["napari.viewer.Viewer"] = None,
193    checkpoint_path: Optional[str] = None,
194    device: Optional[Union[str, torch.device]] = None,
195) -> Optional["napari.viewer.Viewer"]:
196    """Start the tracking annotation tool fora given timeseries.
197
198    Args:
199        image: The image data.
200        embedding_path: Filepath for saving the precomputed embeddings.
201        model_type: The Segment Anything model to use. For details on the available models check out
202            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
203        tile_shape: Shape of tiles for tiled embedding prediction.
204            If `None` then the whole image is passed to Segment Anything.
205        halo: Shape of the overlap between tiles, which is needed to segment objects on tile boarders.
206        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
207        viewer: The viewer to which the SegmentAnything functionality should be added.
208            This enables using a pre-initialized viewer.
209        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
210        device: The computational device to use for the SAM model.
211
212    Returns:
213        The napari viewer, only returned if `return_viewer=True`.
214    """
215
216    # TODO update this to match the new annotator design
217    # Initialize the predictor state.
218    state = AnnotatorState()
219    state.initialize_predictor(
220        image, model_type=model_type, save_path=embedding_path,
221        halo=halo, tile_shape=tile_shape, prefer_decoder=False,
222        ndim=3, checkpoint_path=checkpoint_path, device=device,
223    )
224    state.image_shape = image.shape[:-1] if image.ndim == 4 else image.shape
225
226    if viewer is None:
227        viewer = napari.Viewer()
228
229    viewer.add_image(image, name="image")
230    annotator = AnnotatorTracking(viewer)
231
232    # Trigger layer update of the annotator so that layers have the correct shape.
233    annotator._update_image()
234
235    # Add the annotator widget to the viewer and sync widgets.
236    viewer.window.add_dock_widget(annotator)
237    vutil._sync_embedding_widget(
238        state.widgets["embeddings"], model_type,
239        save_path=embedding_path, checkpoint_path=checkpoint_path,
240        device=device, tile_shape=tile_shape, halo=halo
241    )
242
243    if return_viewer:
244        return viewer
245
246    napari.run()
247
248
249def main():
250    """@private"""
251    parser = vutil._initialize_parser(
252        description="Run interactive segmentation for an image volume.",
253        with_segmentation_result=False,
254        with_instance_segmentation=False,
255    )
256
257    # Tracking result is not yet supported, we need to also deserialize the lineage.
258    # parser.add_argument(
259    #     "-t", "--tracking_result",
260    #     help="Optional filepath to a precomputed tracking result. If passed this will be used to initialize the "
261    #     "'committed_tracks' layer. This can be useful if you want to correct an existing tracking result or if you "
262    #     "have saved intermediate results from the annotator and want to continue. "
263    #     "Supports the same file formats as 'input'."
264    # )
265    # parser.add_argument(
266    #     "-tk", "--tracking_key",
267    #     help="The key for opening the tracking result. Same rules as for 'key' apply."
268    # )
269
270    args = parser.parse_args()
271    image = util.load_image_data(args.input, key=args.key)
272
273    annotator_tracking(
274        image, embedding_path=args.embedding_path, model_type=args.model_type,
275        tile_shape=args.tile_shape, halo=args.halo,
276        checkpoint_path=args.checkpoint, device=args.device,
277    )
class AnnotatorTracking(micro_sam.sam_annotator._annotator._AnnotatorBase):

Inherits from _AnnotatorBase, the base class for micro_sam annotation plugins.

The base class implements the shared logic for the 2d, 3d and tracking annotators, which differ in their data dimensionality and in the widgets they provide.

AnnotatorTracking(viewer: napari.viewer.Viewer)

Create the annotator GUI.

Arguments:
  • viewer: The napari viewer.
  • ndim: The number of spatial dimensions of the image data (2 or 3); the tracking annotator always uses ndim=3.
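
In practice this class is constructed by the annotator_tracking function documented below. The following is a condensed sketch of that wiring for a pre-initialized napari viewer, assuming a small in-memory timeseries: the random input array and the vit_b model choice are placeholders, and the embedding-widget synchronization that annotator_tracking additionally performs is omitted here.

import napari
import numpy as np

from micro_sam.sam_annotator._state import AnnotatorState
from micro_sam.sam_annotator.annotator_tracking import AnnotatorTracking

# Placeholder timeseries with axes (t, y, x).
timeseries = np.random.randint(0, 255, (8, 256, 256)).astype("uint8")

# Initialize the shared predictor state, mirroring what annotator_tracking() does.
state = AnnotatorState()
state.initialize_predictor(
    timeseries, model_type="vit_b", save_path=None,
    halo=None, tile_shape=None, prefer_decoder=False,
    ndim=3, checkpoint_path=None, device=None,
)
state.image_shape = timeseries.shape

viewer = napari.Viewer()
viewer.add_image(timeseries, name="image")

annotator = AnnotatorTracking(viewer)
annotator._update_image()  # Resize the annotator layers to the image shape.
viewer.window.add_dock_widget(annotator)

napari.run()

Interactive tracking then proceeds through the point_prompts and prompts layers created in _create_layers, with the track_state and track_id menus from create_tracking_menu controlling whether a prompt marks a regular track or a division.
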
Inherited Members: the standard Qt widget API inherited from PyQt5.QtWidgets.QScrollArea and its base classes (QAbstractScrollArea, QFrame, QWidget, QObject, QPaintDevice).
def annotator_tracking( image: numpy.ndarray, embedding_path: Optional[str] = None, model_type: str = 'vit_l', tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, return_viewer: bool = False, viewer: Optional[napari.viewer.Viewer] = None, checkpoint_path: Optional[str] = None, device: Union[str, torch.device, NoneType] = None) -> Optional[napari.viewer.Viewer]:

Start the tracking annotation tool for a given timeseries.

Arguments:
  • image: The image data.
  • embedding_path: Filepath for saving the precomputed embeddings.
  • model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
  • tile_shape: Shape of tiles for tiled embedding prediction. If None then the whole image is passed to Segment Anything.
  • halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
  • return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
  • viewer: The viewer to which the Segment Anything functionality should be added. This enables using a pre-initialized viewer.
  • checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
  • device: The computational device to use for the SAM model.
Returns:

The napari viewer, only returned if return_viewer=True.
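
A minimal usage sketch, assuming a timeseries that micro_sam.util.load_image_data can read; the file paths below are placeholders.

from micro_sam import util
from micro_sam.sam_annotator.annotator_tracking import annotator_tracking

# Placeholder paths: point these at your own data and embedding cache.
timeseries = util.load_image_data("timeseries.tif")
annotator_tracking(timeseries, embedding_path="embeddings.zarr")

With return_viewer=True the function returns the napari viewer instead of entering napari.run(), so additional layers or widgets can be added before starting the event loop manually.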