micro_sam.sam_annotator.annotator_tracking

  1from typing import Optional, Tuple, Union
  2
  3import napari
  4import numpy as np
  5
  6import torch
  7
  8from magicgui.widgets import ComboBox, Container
  9
 10from .. import util
 11from . import util as vutil
 12from . import _widgets as widgets
 13from ._tooltips import get_tooltip
 14from ._state import AnnotatorState
 15from ._annotator import _AnnotatorBase
 16
 17
 18# Cyan (track) and Magenta (division)
 19STATE_COLOR_CYCLE = ["#00FFFF", "#FF00FF", ]
 20"""@private"""
 21
 22
 23# This solution is a bit hacky, so I won't move it to _widgets.py yet.
 24def create_tracking_menu(points_layer, box_layer, states, track_ids):
 25    """@private"""
 26    state = AnnotatorState()
 27
 28    state_menu = ComboBox(label="track_state", choices=states, tooltip=get_tooltip("annotator_tracking", "track_state"))
 29    track_id_menu = ComboBox(
 30        label="track_id", choices=list(map(str, track_ids)), tooltip=get_tooltip("annotator_tracking", "track_id")
 31    )
 32    tracking_widget = Container(widgets=[state_menu, track_id_menu])
 33
 34    def update_state(event):
 35        new_state = str(points_layer.current_properties["state"][0])
 36        if new_state != state_menu.value:
 37            state_menu.value = new_state
 38
 39    def update_track_id(event):
 40        new_id = str(points_layer.current_properties["track_id"][0])
 41        if new_id != track_id_menu.value:
 42            track_id_menu.value = new_id
 43            state.current_track_id = int(new_id)
 44
 45    # def update_state_boxes(event):
 46    #     new_state = str(box_layer.current_properties["state"][0])
 47    #     if new_state != state_menu.value:
 48    #         state_menu.value = new_state
 49
 50    def update_track_id_boxes(event):
 51        new_id = str(box_layer.current_properties["track_id"][0])
 52        if new_id != track_id_menu.value:
 53            track_id_menu.value = new_id
 54            state.current_track_id = int(new_id)
 55
 56    points_layer.events.current_properties.connect(update_state)
 57    points_layer.events.current_properties.connect(update_track_id)
 58    # box_layer.events.current_properties.connect(update_state_boxes)
 59    box_layer.events.current_properties.connect(update_track_id_boxes)
 60
 61    def state_changed(new_state):
 62        current_properties = points_layer.current_properties
 63        current_properties["state"] = np.array([new_state])
 64        points_layer.current_properties = current_properties
 65        points_layer.refresh_colors()
 66
 67    def track_id_changed(new_track_id):
 68        current_properties = points_layer.current_properties
 69        current_properties["track_id"] = np.array([new_track_id])
 70        # Note: this fails with a key error after committing a lineage with multiple tracks.
 71        # I think this does not cause any further errors, so we just skip this.
 72        try:
 73            points_layer.current_properties = current_properties
 74        except KeyError:
 75            pass
 76        state.current_track_id = int(new_track_id)
 77
 78    # def state_changed_boxes(new_state):
 79    #     current_properties = box_layer.current_properties
 80    #     current_properties["state"] = np.array([new_state])
 81    #     box_layer.current_properties = current_properties
 82    #     box_layer.refresh_colors()
 83
 84    def track_id_changed_boxes(new_track_id):
 85        current_properties = box_layer.current_properties
 86        current_properties["track_id"] = np.array([new_track_id])
 87        box_layer.current_properties = current_properties
 88        state.current_track_id = int(new_track_id)
 89
 90    state_menu.changed.connect(state_changed)
 91    track_id_menu.changed.connect(track_id_changed)
 92    # state_menu.changed.connect(state_changed_boxes)
 93    track_id_menu.changed.connect(track_id_changed_boxes)
 94
 95    state_menu.set_choice("track")
 96    return tracking_widget
 97
 98
 99class AnnotatorTracking(_AnnotatorBase):
100
101    # The tracking annotator needs different settings for the prompt layers
102    # to support the additional tracking state.
103    # That's why we over-ride this function.
104    def _create_layers(self):
105        self._point_labels = ["positive", "negative"]
106        self._track_state_labels = ["track", "division"]
107
108        self._point_prompt_layer = self._viewer.add_points(
109            name="point_prompts",
110            property_choices={
111                "label": self._point_labels,
112                "state": self._track_state_labels,
113                "track_id": ["1"],  # we use string to avoid pandas warning
114            },
115            border_color="label",
116            border_color_cycle=vutil.LABEL_COLOR_CYCLE,
117            symbol="o",
118            face_color="state",
119            face_color_cycle=STATE_COLOR_CYCLE,
120            border_width=0.4,
121            size=12,
122            ndim=self._ndim,
123        )
124        self._point_prompt_layer.border_color_mode = "cycle"
125        self._point_prompt_layer.face_color_mode = "cycle"
126
127        # Using the box layer to set divisions currently doesn't work.
128        # That's why some of the code below is commented out.
129        self._box_prompt_layer = self._viewer.add_shapes(
130            shape_type="rectangle",
131            edge_width=4,
132            ndim=self._ndim,
133            face_color="transparent",
134            name="prompts",
135            edge_color="green",
136            property_choices={"track_id": ["1"]},
137            # property_choces={"track_id": ["1"], "state": self._track_state_labels},
138            # edge_color_cycle=STATE_COLOR_CYCLE,
139        )
140        # self._box_prompt_layer.edge_color_mode = "cycle"
141
142        # Add the label layers for the current object, the automatic segmentation and the committed segmentation.
143        dummy_data = np.zeros(self._shape, dtype="uint32")
144        self._viewer.add_labels(data=dummy_data, name="current_object")
145        self._viewer.add_labels(data=dummy_data, name="auto_segmentation")
146        self._viewer.add_labels(data=dummy_data, name="committed_objects")
147        # Randomize colors so it is easy to see when object committed.
148        self._viewer.layers["committed_objects"].new_colormap()
149
150    def _get_widgets(self):
151        state = AnnotatorState()
152        # Create the tracking state menu.
153        self._tracking_widget = create_tracking_menu(
154            self._point_prompt_layer, self._box_prompt_layer,
155            states=self._track_state_labels, track_ids=list(state.lineage.keys()),
156        )
157        segment_nd = widgets.SegmentNDWidget(self._viewer, tracking=True)
158        autotrack = widgets.AutoTrackWidget(self._viewer, with_decoder=self._with_decoder, volumetric=True)
159        return {
160            "tracking": self._tracking_widget,
161            "segment": widgets.segment_frame(),
162            "segment_nd": segment_nd,
163            "autosegment": autotrack,
164            "commit": widgets.commit_track(),
165            "clear": widgets.clear_track(),
166        }
167
168    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
169        # Initialize the state for tracking.
170        self._init_track_state()
171        self._with_decoder = AnnotatorState().decoder is not None
172        super().__init__(viewer=viewer, ndim=3)
173        # Go to t=0.
174        self._viewer.dims.current_step = (0, 0, 0) + tuple(sh // 2 for sh in self._shape[1:])
175
176    def _init_track_state(self):
177        state = AnnotatorState()
178        state.current_track_id = 1
179        state.lineage = {1: []}
180        state.committed_lineages = []
181
182    def _update_image(self):
183        super()._update_image()
184        self._init_track_state()
185        state = AnnotatorState()
186        if self._with_decoder:
187            state.amg_state = vutil._load_is_state(state.embedding_path)
188        else:
189            state.amg_state = vutil._load_amg_state(state.embedding_path)
190
191
def annotator_tracking(
    image: np.ndarray,
    embedding_path: Optional[str] = None,
    # tracking_result: Optional[str] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    return_viewer: bool = False,
    viewer: Optional["napari.viewer.Viewer"] = None,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
) -> Optional["napari.viewer.Viewer"]:
    """Start the tracking annotation tool for a given timeseries.

    Args:
        image: The image data. The first axis is time. A 4D input is treated as having a
            trailing channel axis, which is excluded from the image shape used by the annotator.
        embedding_path: Filepath for saving the precomputed embeddings.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        device: The computational device to use for the SAM model.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
        Otherwise the napari event loop is started and `None` is returned once it exits.
    """

    # Initialize the predictor state (computes or loads the embeddings for the 3 leading axes).
    state = AnnotatorState()
    state.initialize_predictor(
        image, model_type=model_type, save_path=embedding_path,
        halo=halo, tile_shape=tile_shape, prefer_decoder=True,
        ndim=3, checkpoint_path=checkpoint_path, device=device,
        precompute_amg_state=precompute_amg_state,
    )
    # Drop the trailing channel axis of 4D inputs so that the shape is (time, y, x).
    state.image_shape = image.shape[:-1] if image.ndim == 4 else image.shape

    if viewer is None:
        viewer = napari.Viewer()

    viewer.add_image(image, name="image")
    annotator = AnnotatorTracking(viewer)

    # Trigger layer update of the annotator so that layers have the correct shape.
    annotator._update_image()

    # Add the annotator widget to the viewer and sync widgets.
    viewer.window.add_dock_widget(annotator)
    vutil._sync_embedding_widget(
        widget=state.widgets["embeddings"],
        # With a custom checkpoint the model type is taken from the loaded predictor.
        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
        save_path=embedding_path,
        checkpoint_path=checkpoint_path,
        device=device,
        tile_shape=tile_shape,
        halo=halo,
    )

    if return_viewer:
        return viewer

    napari.run()
263
264
def main():
    """Command line entry point for the tracking annotator.

    @private
    """
    parser = vutil._initialize_parser(
        # The previous description ("... for an image volume") was copied from the 3d annotator.
        description="Run interactive tracking annotation for a timeseries.",
        with_segmentation_result=False,
        with_instance_segmentation=False,
    )

    # TODO: initializing the annotator from a precomputed tracking result
    # (e.g. `--tracking_result` / `--tracking_key` options) is not supported yet,
    # because the lineage would also have to be deserialized.

    args = parser.parse_args()
    image = util.load_image_data(args.input, key=args.key)

    annotator_tracking(
        image, embedding_path=args.embedding_path, model_type=args.model_type,
        tile_shape=args.tile_shape, halo=args.halo,
        checkpoint_path=args.checkpoint, device=args.device,
    )
class AnnotatorTracking(micro_sam.sam_annotator._annotator._AnnotatorBase):
class AnnotatorTracking(_AnnotatorBase):
    """Annotator widget for segmenting and tracking objects in a timeseries.

    Compared to the other annotators, the point prompts additionally carry a track state
    ("track" or "division") and a track id, which are selected via the tracking menu.
    The lineage of the tracks being annotated is kept in the global `AnnotatorState`.

    Args:
        viewer: The napari viewer.
    """

    # The tracking annotator needs different settings for the prompt layers
    # to support the additional tracking state, which is why this method is overridden.
    def _create_layers(self):
        """Create the prompt layers and the label layers used by the tracking annotator."""
        self._point_labels = ["positive", "negative"]
        self._track_state_labels = ["track", "division"]

        # The border color encodes the prompt label (positive / negative),
        # the face color encodes the track state (track / division).
        self._point_prompt_layer = self._viewer.add_points(
            name="point_prompts",
            property_choices={
                "label": self._point_labels,
                "state": self._track_state_labels,
                "track_id": ["1"],  # we use string to avoid pandas warning
            },
            border_color="label",
            border_color_cycle=vutil.LABEL_COLOR_CYCLE,
            symbol="o",
            face_color="state",
            face_color_cycle=STATE_COLOR_CYCLE,
            border_width=0.4,
            size=12,
            ndim=self._ndim,
        )
        self._point_prompt_layer.border_color_mode = "cycle"
        self._point_prompt_layer.face_color_mode = "cycle"

        # Setting divisions via the box layer currently doesn't work, so the boxes only carry
        # a track id (no track state) and are not colored by state.
        self._box_prompt_layer = self._viewer.add_shapes(
            shape_type="rectangle",
            edge_width=4,
            ndim=self._ndim,
            face_color="transparent",
            name="prompts",
            edge_color="green",
            property_choices={"track_id": ["1"]},
        )

        # Add the label layers for the current object, the automatic segmentation and the committed segmentation.
        dummy_data = np.zeros(self._shape, dtype="uint32")
        self._viewer.add_labels(data=dummy_data, name="current_object")
        self._viewer.add_labels(data=dummy_data, name="auto_segmentation")
        self._viewer.add_labels(data=dummy_data, name="committed_objects")
        # Randomize colors so it is easy to see when an object was committed.
        self._viewer.layers["committed_objects"].new_colormap()

    def _get_widgets(self):
        """Create the widgets of the tracking annotator.

        Returns:
            A dict mapping widget names to the tracking menu, segmentation, auto-tracking,
            commit and clear widgets.
        """
        state = AnnotatorState()
        # Create the tracking state menu; its track id choices are the ids of the current lineage.
        self._tracking_widget = create_tracking_menu(
            self._point_prompt_layer, self._box_prompt_layer,
            states=self._track_state_labels, track_ids=list(state.lineage.keys()),
        )
        segment_nd = widgets.SegmentNDWidget(self._viewer, tracking=True)
        autotrack = widgets.AutoTrackWidget(self._viewer, with_decoder=self._with_decoder, volumetric=True)
        return {
            "tracking": self._tracking_widget,
            "segment": widgets.segment_frame(),
            "segment_nd": segment_nd,
            "autosegment": autotrack,
            "commit": widgets.commit_track(),
            "clear": widgets.clear_track(),
        }

    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
        # The tracking state and the decoder flag are set before the base class initialization,
        # presumably because the base constructor calls `_get_widgets`, which reads both.
        self._init_track_state()
        self._with_decoder = AnnotatorState().decoder is not None
        super().__init__(viewer=viewer, ndim=3)
        # Go to t=0 and center the spatial axes.
        # The step has one entry per axis of the (time, y, x) image shape: t=0 followed by the
        # spatial centers. (Previously `(0, 0, 0)` was prepended, giving 5 entries for 3 dims.)
        self._viewer.dims.current_step = (0,) + tuple(sh // 2 for sh in self._shape[1:])

    def _init_track_state(self):
        """Reset the tracking state: a single lineage with track id 1 and no committed lineages."""
        state = AnnotatorState()
        state.current_track_id = 1
        state.lineage = {1: []}
        state.committed_lineages = []

    def _update_image(self):
        """Update the layers for a new image and reset the tracking state.

        Also loads the state for automatic segmentation from the embedding path: via
        `_load_is_state` if a decoder is available, otherwise via `_load_amg_state`.
        """
        super()._update_image()
        self._init_track_state()
        state = AnnotatorState()
        if self._with_decoder:
            state.amg_state = vutil._load_is_state(state.embedding_path)
        else:
            state.amg_state = vutil._load_amg_state(state.embedding_path)

Base class for micro_sam annotation plugins.

Implements the logic for the 2d, 3d and tracking annotators. The annotators differ in their data dimensionality and in their widgets.

AnnotatorTracking(viewer: napari.viewer.Viewer)
169    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
170        # Initialize the state for tracking.
171        self._init_track_state()
172        self._with_decoder = AnnotatorState().decoder is not None
173        super().__init__(viewer=viewer, ndim=3)
174        # Go to t=0.
175        self._viewer.dims.current_step = (0, 0, 0) + tuple(sh // 2 for sh in self._shape[1:])

Create the annotator GUI.

Arguments:
  • viewer: The napari viewer.
  • ndim: The number of spatial dimensions of the image data (2 or 3). AnnotatorTracking passes ndim=3 itself and does not take this argument.
Inherited Members
PyQt5.QtWidgets.QScrollArea
alignment
ensureVisible
ensureWidgetVisible
event
eventFilter
focusNextPrevChild
resizeEvent
scrollContentsBy
setAlignment
setWidget
setWidgetResizable
sizeHint
takeWidget
viewportSizeHint
widget
widgetResizable
PyQt5.QtWidgets.QAbstractScrollArea
SizeAdjustPolicy
addScrollBarWidget
contextMenuEvent
cornerWidget
dragEnterEvent
dragLeaveEvent
dragMoveEvent
dropEvent
horizontalScrollBar
horizontalScrollBarPolicy
keyPressEvent
maximumViewportSize
minimumSizeHint
mouseDoubleClickEvent
mouseMoveEvent
mousePressEvent
mouseReleaseEvent
paintEvent
scrollBarWidgets
setCornerWidget
setHorizontalScrollBar
setHorizontalScrollBarPolicy
setSizeAdjustPolicy
setVerticalScrollBar
setVerticalScrollBarPolicy
setViewport
setViewportMargins
setupViewport
sizeAdjustPolicy
verticalScrollBar
verticalScrollBarPolicy
viewport
viewportEvent
viewportMargins
wheelEvent
AdjustIgnored
AdjustToContents
AdjustToContentsOnFirstShow
PyQt5.QtWidgets.QFrame
Shadow
Shape
StyleMask
changeEvent
drawFrame
frameRect
frameShadow
frameShape
frameStyle
frameWidth
initStyleOption
lineWidth
midLineWidth
setFrameRect
setFrameShadow
setFrameShape
setFrameStyle
setLineWidth
setMidLineWidth
Box
HLine
NoFrame
Panel
Plain
Raised
Shadow_Mask
Shape_Mask
StyledPanel
Sunken
VLine
WinPanel
PyQt5.QtWidgets.QWidget
RenderFlag
RenderFlags
acceptDrops
accessibleDescription
accessibleName
actionEvent
actions
activateWindow
addAction
addActions
adjustSize
autoFillBackground
backgroundRole
baseSize
childAt
childrenRect
childrenRegion
clearFocus
clearMask
close
closeEvent
contentsMargins
contentsRect
contextMenuPolicy
create
createWindowContainer
cursor
destroy
devType
effectiveWinId
ensurePolished
enterEvent
find
focusInEvent
focusNextChild
focusOutEvent
focusPolicy
focusPreviousChild
focusProxy
focusWidget
font
fontInfo
fontMetrics
foregroundRole
frameGeometry
frameSize
geometry
getContentsMargins
grab
grabGesture
grabKeyboard
grabMouse
grabShortcut
graphicsEffect
graphicsProxyWidget
hasFocus
hasHeightForWidth
hasMouseTracking
hasTabletTracking
height
heightForWidth
hide
hideEvent
initPainter
inputMethodEvent
inputMethodHints
inputMethodQuery
insertAction
insertActions
isActiveWindow
isAncestorOf
isEnabled
isEnabledTo
isFullScreen
isHidden
isLeftToRight
isMaximized
isMinimized
isModal
isRightToLeft
isVisible
isVisibleTo
isWindow
isWindowModified
keyReleaseEvent
keyboardGrabber
layout
layoutDirection
leaveEvent
locale
lower
mapFrom
mapFromGlobal
mapFromParent
mapTo
mapToGlobal
mapToParent
mask
maximumHeight
maximumSize
maximumWidth
metric
minimumHeight
minimumSize
minimumWidth
mouseGrabber
move
moveEvent
nativeEvent
nativeParentWidget
nextInFocusChain
normalGeometry
overrideWindowFlags
overrideWindowState
paintEngine
palette
parentWidget
pos
previousInFocusChain
raise_
rect
releaseKeyboard
releaseMouse
releaseShortcut
removeAction
render
repaint
resize
restoreGeometry
saveGeometry
screen
scroll
setAcceptDrops
setAccessibleDescription
setAccessibleName
setAttribute
setAutoFillBackground
setBackgroundRole
setBaseSize
setContentsMargins
setContextMenuPolicy
setCursor
setDisabled
setEnabled
setFixedHeight
setFixedSize
setFixedWidth
setFocus
setFocusPolicy
setFocusProxy
setFont
setForegroundRole
setGeometry
setGraphicsEffect
setHidden
setInputMethodHints
setLayout
setLayoutDirection
setLocale
setMask
setMaximumHeight
setMaximumSize
setMaximumWidth
setMinimumHeight
setMinimumSize
setMinimumWidth
setMouseTracking
setPalette
setParent
setShortcutAutoRepeat
setShortcutEnabled
setSizeIncrement
setSizePolicy
setStatusTip
setStyle
setStyleSheet
setTabOrder
setTabletTracking
setToolTip
setToolTipDuration
setUpdatesEnabled
setVisible
setWhatsThis
setWindowFilePath
setWindowFlag
setWindowFlags
setWindowIcon
setWindowIconText
setWindowModality
setWindowModified
setWindowOpacity
setWindowRole
setWindowState
setWindowTitle
sharedPainter
show
showEvent
showFullScreen
showMaximized
showMinimized
showNormal
size
sizeIncrement
sizePolicy
stackUnder
statusTip
style
styleSheet
tabletEvent
testAttribute
toolTip
toolTipDuration
underMouse
ungrabGesture
unsetCursor
unsetLayoutDirection
unsetLocale
update
updateGeometry
updateMicroFocus
updatesEnabled
visibleRegion
whatsThis
width
winId
window
windowFilePath
windowFlags
windowHandle
windowIcon
windowIconText
windowModality
windowOpacity
windowRole
windowState
windowTitle
windowType
x
y
DrawChildren
DrawWindowBackground
IgnoreMask
windowIconTextChanged
windowIconChanged
windowTitleChanged
customContextMenuRequested
PyQt5.QtCore.QObject
blockSignals
childEvent
children
connectNotify
customEvent
deleteLater
disconnect
disconnectNotify
dumpObjectInfo
dumpObjectTree
dynamicPropertyNames
findChild
findChildren
inherits
installEventFilter
isSignalConnected
isWidgetType
isWindowType
killTimer
metaObject
moveToThread
objectName
parent
property
pyqtConfigure
receivers
removeEventFilter
sender
senderSignalIndex
setObjectName
setProperty
signalsBlocked
startTimer
thread
timerEvent
tr
staticMetaObject
objectNameChanged
destroyed
PyQt5.QtGui.QPaintDevice
PaintDeviceMetric
colorCount
depth
devicePixelRatio
devicePixelRatioF
devicePixelRatioFScale
heightMM
logicalDpiX
logicalDpiY
paintingActive
physicalDpiX
physicalDpiY
widthMM
PdmDepth
PdmDevicePixelRatio
PdmDevicePixelRatioScaled
PdmDpiX
PdmDpiY
PdmHeight
PdmHeightMM
PdmNumColors
PdmPhysicalDpiX
PdmPhysicalDpiY
PdmWidth
PdmWidthMM
def annotator_tracking( image: numpy.ndarray, embedding_path: Optional[str] = None, model_type: str = 'vit_b_lm', tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, return_viewer: bool = False, viewer: Optional[napari.viewer.Viewer] = None, precompute_amg_state: bool = False, checkpoint_path: Optional[str] = None, device: Union[str, torch.device, NoneType] = None) -> Optional[napari.viewer.Viewer]:
def annotator_tracking(
    image: np.ndarray,
    embedding_path: Optional[str] = None,
    # tracking_result: Optional[str] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    return_viewer: bool = False,
    viewer: Optional["napari.viewer.Viewer"] = None,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
) -> Optional["napari.viewer.Viewer"]:
    """Start the tracking annotation tool for a given timeseries.

    Args:
        image: The image data. The first axis is time. A 4D input is treated as having a
            trailing channel axis, which is excluded from the image shape used by the annotator.
        embedding_path: Filepath for saving the precomputed embeddings.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        device: The computational device to use for the SAM model.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
        Otherwise the napari event loop is started and `None` is returned once it exits.
    """

    # Initialize the predictor state (computes or loads the embeddings for the 3 leading axes).
    state = AnnotatorState()
    state.initialize_predictor(
        image, model_type=model_type, save_path=embedding_path,
        halo=halo, tile_shape=tile_shape, prefer_decoder=True,
        ndim=3, checkpoint_path=checkpoint_path, device=device,
        precompute_amg_state=precompute_amg_state,
    )
    # Drop the trailing channel axis of 4D inputs so that the shape is (time, y, x).
    state.image_shape = image.shape[:-1] if image.ndim == 4 else image.shape

    if viewer is None:
        viewer = napari.Viewer()

    viewer.add_image(image, name="image")
    annotator = AnnotatorTracking(viewer)

    # Trigger layer update of the annotator so that layers have the correct shape.
    annotator._update_image()

    # Add the annotator widget to the viewer and sync widgets.
    viewer.window.add_dock_widget(annotator)
    vutil._sync_embedding_widget(
        widget=state.widgets["embeddings"],
        # With a custom checkpoint the model type is taken from the loaded predictor.
        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
        save_path=embedding_path,
        checkpoint_path=checkpoint_path,
        device=device,
        tile_shape=tile_shape,
        halo=halo,
    )

    if return_viewer:
        return viewer

    napari.run()

Start the tracking annotation tool for a given timeseries.

Arguments:
  • image: The image data.
  • embedding_path: Filepath for saving the precomputed embeddings.
  • model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
  • tile_shape: Shape of tiles for tiled embedding prediction. If None then the whole image is passed to Segment Anything.
  • halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
  • return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
  • viewer: The viewer to which the Segment Anything functionality should be added. This enables using a pre-initialized viewer.
  • precompute_amg_state: Whether to precompute the state for automatic mask generation. This will take more time when precomputing embeddings, but will then make automatic mask generation much faster.
  • checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
  • device: The computational device to use for the SAM model.
Returns:

The napari viewer, only returned if return_viewer=True.