micro_sam.sam_annotator.annotator_tracking

  1from typing import Optional, Tuple, Union, List
  2
  3import numpy as np
  4
  5import torch
  6
  7import napari
  8from magicgui.widgets import ComboBox, Container
  9
 10from .. import util
 11from . import util as vutil
 12from . import _widgets as widgets
 13from ._tooltips import get_tooltip
 14from ._state import AnnotatorState
 15from ._annotator import _AnnotatorBase
 16
 17
 18# Cyan (track) and Magenta (division)
 19STATE_COLOR_CYCLE = ["#00FFFF", "#FF00FF", ]
 20"""@private"""
 21
 22
 23# This solution is a bit hacky, so I won't move it to _widgets.py yet.
 24def create_tracking_menu(points_layer, box_layer, states, track_ids, tracking_widget=None):
 25    """@private"""
 26    state = AnnotatorState()
 27
 28    def _get_widget_menu(container, label):
 29        for w in container:
 30            if isinstance(w, ComboBox) and w.label == label:
 31                return w
 32        raise ValueError(f"ComboBox with label '{label}' not found.")
 33
 34    if tracking_widget is None:
 35        state_menu = ComboBox(
 36            label="track_state", choices=states, tooltip=get_tooltip("annotator_tracking", "track_state")
 37        )
 38        track_id_menu = ComboBox(
 39            label="track_id", choices=list(map(str, track_ids)), tooltip=get_tooltip("annotator_tracking", "track_id")
 40        )
 41        tracking_widget = Container(widgets=[state_menu, track_id_menu])
 42    else:
 43        state_menu = _get_widget_menu(tracking_widget, "track_state")
 44        track_id_menu = _get_widget_menu(tracking_widget, "track_id")
 45
 46    def update_state(event):
 47        if "state" in points_layer.current_properties:
 48            new_state = str(points_layer.current_properties["state"][0])
 49            if new_state != state_menu.value:
 50                state_menu.value = new_state
 51
 52    def update_track_id(event):
 53        if "track_id" in points_layer.current_properties:
 54            new_id = str(points_layer.current_properties["track_id"][0])
 55            if new_id != track_id_menu.value:
 56                track_id_menu.value = new_id
 57                state.current_track_id = int(new_id)
 58
 59    # def update_state_boxes(event):
 60    #     new_state = str(box_layer.current_properties["state"][0])
 61    #     if new_state != state_menu.value:
 62    #         state_menu.value = new_state
 63
 64    def update_track_id_boxes(event):
 65        if "track_id" in box_layer.current_properties:
 66            new_id = str(box_layer.current_properties["track_id"][0])
 67            if new_id != track_id_menu.value:
 68                track_id_menu.value = new_id
 69                state.current_track_id = int(new_id)
 70
 71    points_layer.events.current_properties.connect(update_state)
 72    points_layer.events.current_properties.connect(update_track_id)
 73    # box_layer.events.current_properties.connect(update_state_boxes)
 74    box_layer.events.current_properties.connect(update_track_id_boxes)
 75
 76    def state_changed(new_state):
 77        current_properties = points_layer.current_properties
 78        current_properties["state"] = np.array([new_state])
 79        points_layer.current_properties = current_properties
 80        points_layer.refresh_colors()
 81
 82    def track_id_changed(new_track_id):
 83        current_properties = points_layer.current_properties
 84        current_properties["track_id"] = np.array([new_track_id])
 85        # Note: this fails with a key error after committing a lineage with multiple tracks.
 86        # I think this does not cause any further errors, so we just skip this.
 87        try:
 88            points_layer.current_properties = current_properties
 89        except KeyError:
 90            pass
 91        state.current_track_id = int(new_track_id)
 92
 93    # def state_changed_boxes(new_state):
 94    #     current_properties = box_layer.current_properties
 95    #     current_properties["state"] = np.array([new_state])
 96    #     box_layer.current_properties = current_properties
 97    #     box_layer.refresh_colors()
 98
 99    def track_id_changed_boxes(new_track_id):
100        current_properties = box_layer.current_properties
101        current_properties["track_id"] = np.array([new_track_id])
102        box_layer.current_properties = current_properties
103        state.current_track_id = int(new_track_id)
104
105    state_menu.changed.connect(state_changed)
106    track_id_menu.changed.connect(track_id_changed)
107    # state_menu.changed.connect(state_changed_boxes)
108    track_id_menu.changed.connect(track_id_changed_boxes)
109
110    state_menu.set_choice("track")
111    return tracking_widget
112
113
class AnnotatorTracking(_AnnotatorBase):
    """Annotator for tracking objects in a timeseries.

    In addition to the base annotator functionality, it provides a menu for the tracking state
    ('track' or 'division') and the track id, and prompt layers that carry these properties.
    """

    # The tracking annotator needs different settings for the prompt layers
    # to support the additional tracking state.
    # That's why we override this function.
    def _require_layers(self, layer_choices: Optional[List[str]] = None):
        """Ensure that the label and prompt layers needed for tracking exist in the viewer.

        Layers that already exist are reused; missing layers are created.

        Args:
            layer_choices: Names of layers that the caller expects to exist (checked at the 'commit' call button).
                For a missing layer from this list, a validation window is shown before it is re-created.
        """
        # Check whether the image is initialized already. And use the image shape and scale for the layers.
        state = AnnotatorState()
        shape = self._shape if state.image_shape is None else state.image_shape
        image_scale = state.image_scale

        # Add the label layers for the current object, the automatic segmentation and the committed segmentation.
        self._add_labels_layer_if_missing("current_object", shape, image_scale, layer_choices)
        self._add_labels_layer_if_missing("auto_segmentation", shape, image_scale, layer_choices)
        # Randomize colors so it is easy to see when an object was committed.
        self._add_labels_layer_if_missing(
            "committed_objects", shape, image_scale, layer_choices, randomize_colors=True
        )

        # Add the point prompts layer.
        self._point_labels = ["positive", "negative"]
        self._track_state_labels = ["track", "division"]
        _point_prompt_property_choices = {
            "label": self._point_labels,
            "state": self._track_state_labels,
            "track_id": ["1"],  # we use string to avoid pandas warning
        }

        # NOTE: An existing 'point_prompts' layer is reused as-is, even if its property choices
        # differ from the ones defined above.
        if "point_prompts" not in self._viewer.layers:
            self._point_prompt_layer = self._viewer.add_points(
                name="point_prompts",
                property_choices=_point_prompt_property_choices,
                border_color="label",
                border_color_cycle=vutil.LABEL_COLOR_CYCLE,
                symbol="o",
                face_color="state",
                face_color_cycle=STATE_COLOR_CYCLE,
                border_width=0.4,
                size=12,
                ndim=self._ndim,
            )
            self._point_prompt_layer.border_color_mode = "cycle"
            self._point_prompt_layer.face_color_mode = "cycle"
            _new_point_layer = True
        else:
            self._point_prompt_layer = self._viewer.layers["point_prompts"]
            _new_point_layer = False

        # Add the box prompts layer.
        # Using the box layer to set divisions currently doesn't work,
        # so boxes only carry a track id and no tracking state.
        # NOTE: An existing 'prompts' layer is reused as-is, even if its property choices differ.
        if "prompts" not in self._viewer.layers:
            self._box_prompt_layer = self._viewer.add_shapes(
                shape_type="rectangle",
                edge_width=4,
                ndim=self._ndim,
                face_color="transparent",
                name="prompts",
                edge_color="green",
                property_choices={"track_id": ["1"]},
            )
            _new_box_layer = True
        else:
            self._box_prompt_layer = self._viewer.layers["prompts"]
            _new_box_layer = False

        # Trigger a new connection for the tracking state menu only when a new layer is (re)created.
        if _new_point_layer or _new_box_layer:
            self._tracking_widget = create_tracking_menu(
                points_layer=self._point_prompt_layer,
                box_layer=self._box_prompt_layer,
                states=self._track_state_labels,
                track_ids=list(state.lineage.keys()),
                tracking_widget=state.widgets.get("tracking"),
            )
            state.widgets["tracking"] = self._tracking_widget

    def _add_labels_layer_if_missing(self, name, shape, scale, layer_choices, randomize_colors=False):
        """Add an empty labels layer called `name`, unless the viewer already has a layer with this name.

        Every new layer gets its own zero-initialized array. If several layers were created from one
        shared array, they would alias the same buffer and painting into one would modify the others.

        Args:
            name: The layer name.
            shape: The shape of the (empty) label data.
            scale: The layer scale; it is only set if not None.
            layer_choices: Names of layers that are expected to exist; see `_require_layers`.
            randomize_colors: Whether to call `new_colormap()` on the new layer.
        """
        if name in self._viewer.layers:
            return
        if layer_choices and name in layer_choices:  # Check at 'commit' call button.
            widgets._validation_window_for_missing_layer(name)
        self._viewer.add_labels(data=np.zeros(shape, dtype="uint32"), name=name)
        layer = self._viewer.layers[name]
        if randomize_colors:
            layer.new_colormap()
        if scale is not None:
            layer.scale = scale

    def _get_widgets(self):
        """Create the widgets of the tracking annotator.

        Returns:
            A dictionary mapping widget names to the widgets.
        """
        state = AnnotatorState()
        self._require_layers()

        # Create the tracking state menu.
        # NOTE: It may exist already (e.g. created by `_require_layers`); in that case it is reused.
        if state.widgets.get("tracking") is None:
            self._tracking_widget = create_tracking_menu(
                points_layer=self._point_prompt_layer,
                box_layer=self._box_prompt_layer,
                states=self._track_state_labels,
                track_ids=list(state.lineage.keys()),
            )
        else:
            self._tracking_widget = state.widgets.get("tracking")

        segment_nd = widgets.SegmentNDWidget(self._viewer, tracking=True)
        autotrack = widgets.AutoTrackWidget(self._viewer, with_decoder=self._with_decoder, volumetric=True)
        return {
            "tracking": self._tracking_widget,
            "segment": widgets.segment_frame(),
            "segment_nd": segment_nd,
            "autosegment": autotrack,
            "commit": widgets.commit_track(),
            "clear": widgets.clear_track(),
        }

    def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
        """Create the tracking annotator GUI.

        Args:
            viewer: The napari viewer.
            reset_state: Whether to reset the annotator state after the GUI has been created.
                `annotator_tracking` passes False after it has initialized the predictor itself.
        """
        # Initialize the tracking state first, since `_require_layers` and `_get_widgets` read the lineage.
        self._init_track_state()
        self._with_decoder = AnnotatorState().decoder is not None
        super().__init__(viewer=viewer, ndim=3)

        # Go to t=0.
        # NOTE(review): this step tuple has 3 + (len(self._shape) - 1) entries, i.e. more than the number
        # of dims; presumably napari ignores the surplus entries - confirm.
        half_sizes = tuple(size // 2 for size in self._shape[1:])
        self._viewer.dims.current_step = (0, 0, 0) + half_sizes

        state = AnnotatorState()
        if reset_state:
            state.reset_state()

        # Register this instance as the active annotator.
        state.annotator = self

    def _init_track_state(self):
        """Start a fresh tracking state: track id 1 with an empty lineage and no committed lineages."""
        state = AnnotatorState()
        state.current_track_id = 1
        state.lineage = {1: []}
        state.committed_lineages = []

    def _update_image(self):
        """Update the annotator for a new image and reset the tracking state.

        Afterwards, the automatic segmentation state is loaded from the embedding path, using
        `_load_is_state` if a decoder is available and `_load_amg_state` otherwise.
        """
        super()._update_image()
        self._init_track_state()
        state = AnnotatorState()
        if self._with_decoder:
            state.amg_state = vutil._load_is_state(state.embedding_path)
        else:
            state.amg_state = vutil._load_amg_state(state.embedding_path)
286
287
def annotator_tracking(
    image: np.ndarray,
    embedding_path: Optional[str] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    return_viewer: bool = False,
    viewer: Optional["napari.viewer.Viewer"] = None,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    decoder_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
) -> Optional["napari.viewer.Viewer"]:
    """Start the tracking annotation tool for a given timeseries.

    Args:
        image: The image data.
        embedding_path: Filepath for saving the precomputed embeddings.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
            By default, does not return the napari viewer.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster. By default, set to 'False'.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        decoder_path: Path to a custom decoder checkpoint from which to load the `micro-sam` decoder.
        device: The computational device to use for the SAM model.
            By default, automatically chooses the best available device.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
    """
    # NOTE: Initializing from a precomputed tracking result is not supported yet.

    # Initialize the predictor state.
    state = AnnotatorState()
    state.initialize_predictor(
        image, model_type=model_type, save_path=embedding_path,
        halo=halo, tile_shape=tile_shape, prefer_decoder=True,
        ndim=3, checkpoint_path=checkpoint_path, device=device,
        precompute_amg_state=precompute_amg_state, use_cli=True,
        decoder_path=decoder_path,
    )
    # For 4d input the last axis is excluded from the image shape (presumably a channel axis).
    state.image_shape = image.shape[:-1] if image.ndim == 4 else image.shape

    if viewer is None:
        viewer = napari.Viewer()

    viewer.add_image(image, name="image")
    annotator = AnnotatorTracking(viewer, reset_state=False)

    # Trigger layer update of the annotator so that layers have the correct shape.
    annotator._update_image()

    # Add the annotator widget to the viewer and sync widgets.
    viewer.window.add_dock_widget(annotator)
    vutil._sync_embedding_widget(
        widget=state.widgets["embeddings"],
        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
        save_path=embedding_path,
        checkpoint_path=checkpoint_path,
        device=device,
        tile_shape=tile_shape,
        halo=halo,
    )

    if return_viewer:
        return viewer

    napari.run()
    return None
364
365
def main():
    """Command line entry point for the tracking annotator.

    @private
    """
    parser = vutil._initialize_parser(
        description="Run interactive segmentation for an image volume.",
        with_segmentation_result=False,
        with_instance_segmentation=False,
    )
    # NOTE: Passing a precomputed tracking result (e.g. via '--tracking_result' / '--tracking_key')
    # is not supported yet, because the lineage would also have to be deserialized.
    args = parser.parse_args()

    annotator_tracking(
        util.load_image_data(args.input, key=args.key),
        embedding_path=args.embedding_path,
        model_type=args.model_type,
        tile_shape=args.tile_shape,
        halo=args.halo,
        checkpoint_path=args.checkpoint,
        decoder_path=args.decoder_path,
        device=args.device,
    )
class AnnotatorTracking(micro_sam.sam_annotator._annotator._AnnotatorBase):
class AnnotatorTracking(_AnnotatorBase):
    """Annotator for tracking objects in a timeseries.

    In addition to the base annotator functionality, it provides a menu for the tracking state
    ('track' or 'division') and the track id, and prompt layers that carry these properties.
    """

    # The tracking annotator needs different settings for the prompt layers
    # to support the additional tracking state.
    # That's why we override this function.
    def _require_layers(self, layer_choices: Optional[List[str]] = None):
        """Ensure that the label and prompt layers needed for tracking exist in the viewer.

        Layers that already exist are reused; missing layers are created.

        Args:
            layer_choices: Names of layers that the caller expects to exist (checked at the 'commit' call button).
                For a missing layer from this list, a validation window is shown before it is re-created.
        """
        # Check whether the image is initialized already. And use the image shape and scale for the layers.
        state = AnnotatorState()
        shape = self._shape if state.image_shape is None else state.image_shape
        image_scale = state.image_scale

        # Add the label layers for the current object, the automatic segmentation and the committed segmentation.
        self._add_labels_layer_if_missing("current_object", shape, image_scale, layer_choices)
        self._add_labels_layer_if_missing("auto_segmentation", shape, image_scale, layer_choices)
        # Randomize colors so it is easy to see when an object was committed.
        self._add_labels_layer_if_missing(
            "committed_objects", shape, image_scale, layer_choices, randomize_colors=True
        )

        # Add the point prompts layer.
        self._point_labels = ["positive", "negative"]
        self._track_state_labels = ["track", "division"]
        _point_prompt_property_choices = {
            "label": self._point_labels,
            "state": self._track_state_labels,
            "track_id": ["1"],  # we use string to avoid pandas warning
        }

        # NOTE: An existing 'point_prompts' layer is reused as-is, even if its property choices
        # differ from the ones defined above.
        if "point_prompts" not in self._viewer.layers:
            self._point_prompt_layer = self._viewer.add_points(
                name="point_prompts",
                property_choices=_point_prompt_property_choices,
                border_color="label",
                border_color_cycle=vutil.LABEL_COLOR_CYCLE,
                symbol="o",
                face_color="state",
                face_color_cycle=STATE_COLOR_CYCLE,
                border_width=0.4,
                size=12,
                ndim=self._ndim,
            )
            self._point_prompt_layer.border_color_mode = "cycle"
            self._point_prompt_layer.face_color_mode = "cycle"
            _new_point_layer = True
        else:
            self._point_prompt_layer = self._viewer.layers["point_prompts"]
            _new_point_layer = False

        # Add the box prompts layer.
        # Using the box layer to set divisions currently doesn't work,
        # so boxes only carry a track id and no tracking state.
        # NOTE: An existing 'prompts' layer is reused as-is, even if its property choices differ.
        if "prompts" not in self._viewer.layers:
            self._box_prompt_layer = self._viewer.add_shapes(
                shape_type="rectangle",
                edge_width=4,
                ndim=self._ndim,
                face_color="transparent",
                name="prompts",
                edge_color="green",
                property_choices={"track_id": ["1"]},
            )
            _new_box_layer = True
        else:
            self._box_prompt_layer = self._viewer.layers["prompts"]
            _new_box_layer = False

        # Trigger a new connection for the tracking state menu only when a new layer is (re)created.
        if _new_point_layer or _new_box_layer:
            self._tracking_widget = create_tracking_menu(
                points_layer=self._point_prompt_layer,
                box_layer=self._box_prompt_layer,
                states=self._track_state_labels,
                track_ids=list(state.lineage.keys()),
                tracking_widget=state.widgets.get("tracking"),
            )
            state.widgets["tracking"] = self._tracking_widget

    def _add_labels_layer_if_missing(self, name, shape, scale, layer_choices, randomize_colors=False):
        """Add an empty labels layer called `name`, unless the viewer already has a layer with this name.

        Every new layer gets its own zero-initialized array. If several layers were created from one
        shared array, they would alias the same buffer and painting into one would modify the others.

        Args:
            name: The layer name.
            shape: The shape of the (empty) label data.
            scale: The layer scale; it is only set if not None.
            layer_choices: Names of layers that are expected to exist; see `_require_layers`.
            randomize_colors: Whether to call `new_colormap()` on the new layer.
        """
        if name in self._viewer.layers:
            return
        if layer_choices and name in layer_choices:  # Check at 'commit' call button.
            widgets._validation_window_for_missing_layer(name)
        self._viewer.add_labels(data=np.zeros(shape, dtype="uint32"), name=name)
        layer = self._viewer.layers[name]
        if randomize_colors:
            layer.new_colormap()
        if scale is not None:
            layer.scale = scale

    def _get_widgets(self):
        """Create the widgets of the tracking annotator.

        Returns:
            A dictionary mapping widget names to the widgets.
        """
        state = AnnotatorState()
        self._require_layers()

        # Create the tracking state menu.
        # NOTE: It may exist already (e.g. created by `_require_layers`); in that case it is reused.
        if state.widgets.get("tracking") is None:
            self._tracking_widget = create_tracking_menu(
                points_layer=self._point_prompt_layer,
                box_layer=self._box_prompt_layer,
                states=self._track_state_labels,
                track_ids=list(state.lineage.keys()),
            )
        else:
            self._tracking_widget = state.widgets.get("tracking")

        segment_nd = widgets.SegmentNDWidget(self._viewer, tracking=True)
        autotrack = widgets.AutoTrackWidget(self._viewer, with_decoder=self._with_decoder, volumetric=True)
        return {
            "tracking": self._tracking_widget,
            "segment": widgets.segment_frame(),
            "segment_nd": segment_nd,
            "autosegment": autotrack,
            "commit": widgets.commit_track(),
            "clear": widgets.clear_track(),
        }

    def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
        """Create the tracking annotator GUI.

        Args:
            viewer: The napari viewer.
            reset_state: Whether to reset the annotator state after the GUI has been created.
                `annotator_tracking` passes False after it has initialized the predictor itself.
        """
        # Initialize the tracking state first, since `_require_layers` and `_get_widgets` read the lineage.
        self._init_track_state()
        self._with_decoder = AnnotatorState().decoder is not None
        super().__init__(viewer=viewer, ndim=3)

        # Go to t=0.
        # NOTE(review): this step tuple has 3 + (len(self._shape) - 1) entries, i.e. more than the number
        # of dims; presumably napari ignores the surplus entries - confirm.
        half_sizes = tuple(size // 2 for size in self._shape[1:])
        self._viewer.dims.current_step = (0, 0, 0) + half_sizes

        state = AnnotatorState()
        if reset_state:
            state.reset_state()

        # Register this instance as the active annotator.
        state.annotator = self

    def _init_track_state(self):
        """Start a fresh tracking state: track id 1 with an empty lineage and no committed lineages."""
        state = AnnotatorState()
        state.current_track_id = 1
        state.lineage = {1: []}
        state.committed_lineages = []

    def _update_image(self):
        """Update the annotator for a new image and reset the tracking state.

        Afterwards, the automatic segmentation state is loaded from the embedding path, using
        `_load_is_state` if a decoder is available and `_load_amg_state` otherwise.
        """
        super()._update_image()
        self._init_track_state()
        state = AnnotatorState()
        if self._with_decoder:
            state.amg_state = vutil._load_is_state(state.embedding_path)
        else:
            state.amg_state = vutil._load_amg_state(state.embedding_path)

Base class for micro_sam annotation plugins.

Implements the logic for the 2d, 3d and tracking annotators. The annotators differ in their data dimensionality and their widgets.

AnnotatorTracking(viewer: napari.viewer.Viewer, reset_state: bool = True)
def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
    """Create the tracking annotator GUI.

    Args:
        viewer: The napari viewer.
        reset_state: Whether to reset the annotator state after the GUI has been created.
            `annotator_tracking` passes False after it has initialized the predictor itself.
    """
    # Initialize the tracking state first, since `_require_layers` and `_get_widgets` read the lineage.
    self._init_track_state()
    self._with_decoder = AnnotatorState().decoder is not None
    super().__init__(viewer=viewer, ndim=3)

    # Go to t=0.
    # NOTE(review): this step tuple has 3 + (len(self._shape) - 1) entries, i.e. more than the number
    # of dims; presumably napari ignores the surplus entries - confirm.
    half_sizes = tuple(size // 2 for size in self._shape[1:])
    self._viewer.dims.current_step = (0, 0, 0) + half_sizes

    state = AnnotatorState()
    if reset_state:
        state.reset_state()

    # Register this instance as the active annotator.
    state.annotator = self

Create the annotator GUI.

Arguments:
  • viewer: The napari viewer.
  • ndim: The number of spatial dimensions of the image data (2 or 3).
def annotator_tracking( image: numpy.ndarray, embedding_path: Optional[str] = None, model_type: str = 'vit_b_lm', tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, return_viewer: bool = False, viewer: Optional[napari.viewer.Viewer] = None, precompute_amg_state: bool = False, checkpoint_path: Optional[str] = None, decoder_path: Optional[str] = None, device: Union[str, torch.device, NoneType] = None) -> Optional[napari.viewer.Viewer]:
def annotator_tracking(
    image: np.ndarray,
    embedding_path: Optional[str] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    return_viewer: bool = False,
    viewer: Optional["napari.viewer.Viewer"] = None,
    precompute_amg_state: bool = False,
    checkpoint_path: Optional[str] = None,
    decoder_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
) -> Optional["napari.viewer.Viewer"]:
    """Start the tracking annotation tool for a given timeseries.

    Args:
        image: The image data.
        embedding_path: Filepath for saving the precomputed embeddings.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
            By default, does not return the napari viewer.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        precompute_amg_state: Whether to precompute the state for automatic mask generation.
            This will take more time when precomputing embeddings, but will then make
            automatic mask generation much faster. By default, set to 'False'.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        decoder_path: Path to a custom decoder checkpoint from which to load the `micro-sam` decoder.
        device: The computational device to use for the SAM model.
            By default, automatically chooses the best available device.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
    """
    # NOTE: Initializing from a precomputed tracking result is not supported yet.

    # Initialize the predictor state.
    state = AnnotatorState()
    state.initialize_predictor(
        image, model_type=model_type, save_path=embedding_path,
        halo=halo, tile_shape=tile_shape, prefer_decoder=True,
        ndim=3, checkpoint_path=checkpoint_path, device=device,
        precompute_amg_state=precompute_amg_state, use_cli=True,
        decoder_path=decoder_path,
    )
    # For 4d input the last axis is excluded from the image shape (presumably a channel axis).
    state.image_shape = image.shape[:-1] if image.ndim == 4 else image.shape

    if viewer is None:
        viewer = napari.Viewer()

    viewer.add_image(image, name="image")
    annotator = AnnotatorTracking(viewer, reset_state=False)

    # Trigger layer update of the annotator so that layers have the correct shape.
    annotator._update_image()

    # Add the annotator widget to the viewer and sync widgets.
    viewer.window.add_dock_widget(annotator)
    vutil._sync_embedding_widget(
        widget=state.widgets["embeddings"],
        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
        save_path=embedding_path,
        checkpoint_path=checkpoint_path,
        device=device,
        tile_shape=tile_shape,
        halo=halo,
    )

    if return_viewer:
        return viewer

    napari.run()
    return None

Start the tracking annotation tool for a given timeseries.

Arguments:
  • image: The image data.
  • embedding_path: Filepath for saving the precomputed embeddings.
  • model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
  • tile_shape: Shape of tiles for tiled embedding prediction. If None then the whole image is passed to Segment Anything.
  • halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
  • return_viewer: Whether to return the napari viewer to further modify it before starting the tool. By default, does not return the napari viewer.
  • viewer: The viewer to which the Segment Anything functionality should be added. This enables using a pre-initialized viewer.
  • precompute_amg_state: Whether to precompute the state for automatic mask generation. This will take more time when precomputing embeddings, but will then make automatic mask generation much faster. By default, set to 'False'.
  • checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
  • decoder_path: Path to a custom decoder checkpoint from which to load the `micro-sam` decoder.
  • device: The computational device to use for the SAM model. By default, automatically chooses the best available device.
Returns:

The napari viewer, only returned if return_viewer=True.