micro_sam.sam_annotator._widgets

Implements the widgets used in the annotation plugins.

   1"""Implements the widgets used in the annotation plugins.
   2"""
   3
   4import os
   5import gc
   6import multiprocessing as mp
   7import pickle
   8from pathlib import Path
   9from typing import Optional
  10
  11import h5py
  12import json
  13import zarr
  14import napari
  15import numpy as np
  16
  17try:
  18    import z5py
  19except ImportError:
  20    z5py = None
  21
  22from bioimage_cpp.utils import segmentation_overlap
  23
  24import elf.parallel
  25
  26from qtpy import QtWidgets
  27from qtpy.QtCore import QObject, Signal
  28from superqt import QCollapsible
  29from napari.utils.notifications import show_info
  30from magicgui import magic_factory
  31from magicgui.widgets import ComboBox, Container, create_widget
  32# We have disabled the thread workers for now because they result in a
  33# massive slowdown in napari >= 0.5.
  34# See also https://forum.image.sc/t/napari-thread-worker-leads-to-massive-slowdown/103786
  35# from napari.qt.threading import thread_worker
  36from napari.utils import progress
  37
  38from . import util as vutil
  39from ._tooltips import get_tooltip
  40from ._state import AnnotatorState
  41from .. import instance_segmentation, util
  42from ..multi_dimensional_segmentation import (
  43    segment_mask_in_volume, merge_instance_segmentation_3d, track_across_frames, PROJECTION_MODES, get_napari_track_data
  44)
  45
  46
  47#
  48# Convenience functionality for creating QT UI and manipulating the napari viewer.
  49#
  50
  51
  52def _select_layer(viewer, layer_name):
  53    viewer.layers.selection.select_only(viewer.layers[layer_name])
  54
  55
  56# Create a collapsible around the widget
  57def _make_collapsible(widget, title):
  58    parent_widget = QtWidgets.QWidget()
  59    parent_widget.setLayout(QtWidgets.QVBoxLayout())
  60    collapsible = QCollapsible(title, parent_widget)
  61    collapsible.addWidget(widget)
  62    parent_widget.layout().addWidget(collapsible)
  63    return parent_widget
  64
  65
  66# Base class for a widget with convenience functionality for adding parameters.
  67class _WidgetBase(QtWidgets.QWidget):
  68    def __init__(self, parent=None):
  69        super().__init__(parent)
  70        self.setLayout(QtWidgets.QVBoxLayout())
  71
  72    def _add_boolean_param(self, name, value, title=None, tooltip=None):
  73        checkbox = QtWidgets.QCheckBox(name if title is None else title)
  74        checkbox.setChecked(value)
  75        checkbox.stateChanged.connect(lambda val: setattr(self, name, val))
  76        if tooltip:
  77            checkbox.setToolTip(tooltip)
  78        return checkbox
  79
  80    def _add_string_param(self, name, value, title=None, placeholder=None, layout=None, tooltip=None):
  81        if layout is None:
  82            layout = QtWidgets.QHBoxLayout()
  83        label = QtWidgets.QLabel(title or name)
  84        if tooltip:
  85            label.setToolTip(tooltip)
  86        layout.addWidget(label)
  87        param = QtWidgets.QLineEdit()
  88        param.setText(value)
  89        if placeholder is not None:
  90            param.setPlaceholderText(placeholder)
  91        param.textChanged.connect(lambda val: setattr(self, name, val))
  92        if tooltip:
  93            param.setToolTip(tooltip)
  94        layout.addWidget(param)
  95        return param, layout
  96
  97    def _add_float_param(self, name, value, title=None, min_val=0.0, max_val=1.0, decimals=2,
  98                         step=0.01, layout=None, tooltip=None):
  99        if layout is None:
 100            layout = QtWidgets.QHBoxLayout()
 101        label = QtWidgets.QLabel(title or name)
 102        if tooltip:
 103            label.setToolTip(tooltip)
 104        layout.addWidget(label)
 105        param = QtWidgets.QDoubleSpinBox()
 106        param.setRange(min_val, max_val)
 107        param.setDecimals(decimals)
 108        param.setValue(value)
 109        param.setSingleStep(step)
 110        param.valueChanged.connect(lambda val: setattr(self, name, val))
 111        if tooltip:
 112            param.setToolTip(tooltip)
 113        layout.addWidget(param)
 114        return param, layout
 115
 116    def _add_int_param(self, name, value, min_val, max_val, title=None, step=1, layout=None, tooltip=None):
 117        if layout is None:
 118            layout = QtWidgets.QHBoxLayout()
 119        label = QtWidgets.QLabel(title or name)
 120        if tooltip:
 121            label.setToolTip(tooltip)
 122        layout.addWidget(label)
 123        param = QtWidgets.QSpinBox()
 124        param.setRange(min_val, max_val)
 125        param.setValue(value)
 126        param.setSingleStep(step)
 127        param.valueChanged.connect(lambda val: setattr(self, name, val))
 128        if tooltip:
 129            param.setToolTip(tooltip)
 130        layout.addWidget(param)
 131        return param, layout
 132
 133    def _add_choice_param(self, name, value, options, title=None, layout=None, update=None, tooltip=None):
 134        if layout is None:
 135            layout = QtWidgets.QHBoxLayout()
 136        label = QtWidgets.QLabel(title or name)
 137        if tooltip:
 138            label.setToolTip(tooltip)
 139        layout.addWidget(label)
 140
 141        # Create the dropdown menu via QComboBox, set the available values.
 142        dropdown = QtWidgets.QComboBox()
 143        dropdown.addItems(options)
 144        if update is None:
 145            dropdown.currentIndexChanged.connect(lambda index: setattr(self, name, options[index]))
 146        else:
 147            dropdown.currentIndexChanged.connect(update)
 148
 149        # Set the correct value for the value.
 150        dropdown.setCurrentIndex(dropdown.findText(value))
 151
 152        if tooltip:
 153            dropdown.setToolTip(tooltip)
 154
 155        layout.addWidget(dropdown)
 156        return dropdown, layout
 157
 158    def _add_shape_param(self, names, values, min_val, max_val, step=1, title=None, tooltip=None):
 159        layout = QtWidgets.QHBoxLayout()
 160
 161        x_layout = QtWidgets.QVBoxLayout()
 162        x_param, _ = self._add_int_param(
 163            names[0], values[0], min_val=min_val, max_val=max_val, layout=x_layout, step=step,
 164            title=title[0] if title is not None else title, tooltip=tooltip
 165        )
 166        layout.addLayout(x_layout)
 167
 168        y_layout = QtWidgets.QVBoxLayout()
 169        y_param, _ = self._add_int_param(
 170            names[1], values[1], min_val=min_val, max_val=max_val, layout=y_layout, step=step,
 171            title=title[1] if title is not None else title, tooltip=tooltip
 172        )
 173        layout.addLayout(y_layout)
 174
 175        return x_param, y_param, layout
 176
 177    def _add_path_param(self, name, value, select_type, title=None, placeholder=None, tooltip=None):
 178        assert select_type in ("directory", "file", "both")
 179
 180        layout = QtWidgets.QHBoxLayout()
 181        label = QtWidgets.QLabel(title or name)
 182        if tooltip:
 183            label.setToolTip(tooltip)
 184        layout.addWidget(label)
 185
 186        path_textbox = QtWidgets.QLineEdit()
 187        path_textbox.setText(str(value))
 188        if placeholder is not None:
 189            path_textbox.setPlaceholderText(placeholder)
 190        path_textbox.textChanged.connect(lambda val: setattr(self, name, val))
 191        if tooltip:
 192            path_textbox.setToolTip(tooltip)
 193
 194        layout.addWidget(path_textbox)
 195
 196        def add_path_button(select_type, tooltip=None):
 197            # Adjust button text.
 198            button_text = f"Select {select_type.capitalize()}"
 199            path_button = QtWidgets.QPushButton(button_text)
 200
 201            # Call appropriate function based on select_type.
 202            path_button.clicked.connect(lambda: getattr(self, f"_get_{select_type}_path")(name, path_textbox))
 203            if tooltip:
 204                path_button.setToolTip(tooltip)
 205            layout.addWidget(path_button)
 206
 207        if select_type == "both":
 208            add_path_button("file")
 209            add_path_button("directory")
 210
 211        else:
 212            add_path_button(select_type)
 213
 214        return path_textbox, layout
 215
 216    def _get_directory_path(self, name, textbox, tooltip=None):
 217        directory = QtWidgets.QFileDialog.getExistingDirectory(
 218            self, "Select Directory", "", QtWidgets.QFileDialog.ShowDirsOnly
 219        )
 220        if tooltip:
 221            directory.setToolTip(tooltip)
 222        if directory and Path(directory).is_dir():
 223            textbox.setText(str(directory))
 224        else:
 225            # Handle the case where the selected path is not a directory
 226            print("Invalid directory selected. Please try again.")
 227
 228    def _get_file_path(self, name, textbox, tooltip=None):
 229        file_path, _ = QtWidgets.QFileDialog.getOpenFileName(
 230            self, "Select File", "", "All Files (*)"
 231        )
 232        if tooltip:
 233            file_path.setToolTip(tooltip)
 234        if file_path and Path(file_path).is_file():
 235            textbox.setText(str(file_path))
 236        else:
 237            # Handle the case where the selected path is not a file
 238            print("Invalid file selected. Please try again.")
 239
 240    def _get_model_size_options(self):
 241        # We store the actual model names mapped to UI labels.
 242        self.model_size_mapping = {}
 243        if self.model_family == "Natural Images (SAM)":
 244            self.model_size_options = list(self._model_size_map .values())
 245            self.model_size_mapping = {self._model_size_map[k]: f"vit_{k}" for k in self._model_size_map.keys()}
 246        else:
 247            model_suffix = self.supported_dropdown_maps[self.model_family]
 248            self.model_size_options = []
 249
 250            for option in self.model_options:
 251                if option.endswith(model_suffix):
 252                    # Extract model size character on-the-fly.
 253                    key = next((k for k in self._model_size_map .keys() if f"vit_{k}" in option), None)
 254                    if key:
 255                        size_label = self._model_size_map[key]
 256                        self.model_size_options.append(size_label)
 257                        self.model_size_mapping[size_label] = option  # Store the actual model name.
 258
 259        # We ensure an assorted order of model sizes ('tiny' to 'huge')
 260        self.model_size_options.sort(key=lambda x: ["tiny", "base", "large", "huge"].index(x))
 261
 262    def _update_model_type(self):
 263        # Get currently selected model size (before clearing dropdown)
 264        current_selection = self.model_size_dropdown.currentText()
 265        self._get_model_size_options()  # Update model size options dynamically
 266
 267        # NOTE: We need to prevent recursive updates for this step temporarily.
 268        self.model_size_dropdown.blockSignals(True)
 269
 270        # Let's clear and recreate the dropdown.
 271        self.model_size_dropdown.clear()
 272        self.model_size_dropdown.addItems(self.model_size_options)
 273
 274        # We restore the previous selection, if still valid.
 275        if current_selection in self.model_size_options:
 276            self.model_size = current_selection
 277        else:
 278            if self.model_size_options:  # Default to the first available model size
 279                self.model_size = self.model_size_options[0]
 280
 281        # Let's map the selection to the correct model type (eg. "tiny" -> "vit_t")
 282        size_key = next(
 283            (k for k, v in self._model_size_map.items() if v == self.model_size), "b"
 284        )
 285        self.model_type = f"vit_{size_key}" + self.supported_dropdown_maps[self.model_family]
 286
 287        self.model_size_dropdown.setCurrentText(self.model_size)  # Apply the selected text to the dropdown
 288
 289        # We force a refresh for UI here.
 290        self.model_size_dropdown.update()
 291
 292        # NOTE: And finally, we should re-enable signals again.
 293        self.model_size_dropdown.blockSignals(False)
 294
 295    def _create_model_section(self, default_model: str = util._DEFAULT_MODEL, create_layout: bool = True):
 296
 297        # Create a list of support dropdown values and correspond them to suffixes.
 298        self.supported_dropdown_maps = {
 299            "Natural Images (SAM)": "",
 300            "Light Microscopy": "_lm",
 301            "Electron Microscopy": "_em_organelles",
 302            "Medical Imaging": "_medical_imaging",
 303            "Histopathology": "_histopathology",
 304        }
 305
 306        # NOTE: The available options for all are either 'tiny', 'base', 'large' or 'huge'.
 307        self._model_size_map = {"t": "tiny", "b": "base", "l": "large", "h": "huge"}
 308
 309        self._default_model_choice = default_model
 310        # Let's set the literally default model choice depending on 'micro-sam'.
 311        self.model_family = {v: k for k, v in self.supported_dropdown_maps.items()}[self._default_model_choice[5:]]
 312
 313        kwargs = {}
 314        if create_layout:
 315            layout = QtWidgets.QVBoxLayout()
 316            kwargs["layout"] = layout
 317
 318        # NOTE: We stick to the base variant for each model family.
 319        # i.e. 'Natural Images (SAM)', 'Light Microscopy', 'Electron Microscopy', 'Medical_Imaging', 'Histopathology'.
 320        self.model_family_dropdown, layout = self._add_choice_param(
 321            "model_family", self.model_family, list(self.supported_dropdown_maps.keys()),
 322            title="Model:", tooltip=get_tooltip("embedding", "model_family"), **kwargs,
 323        )
 324        self.model_family_dropdown.currentTextChanged.connect(self._update_model_type)
 325        return layout
 326
 327    def _create_model_size_section(self):
 328
 329        # Create UI for the model size.
 330        # This would combine with the chosen 'self.model_family' and depend on 'self._default_model_choice'.
 331        self.model_size = self._model_size_map[self._default_model_choice[4]]
 332
 333        # Get all model options.
 334        self.model_options = list(util.models().urls.keys())
 335        # Filter out the decoders from the model list.
 336        self.model_options = [model for model in self.model_options if not model.endswith("decoder")]
 337
 338        # Now, we get the available sizes per model family.
 339        self._get_model_size_options()
 340
 341        self.model_size_dropdown, layout = self._add_choice_param(
 342            "model_size", self.model_size, self.model_size_options,
 343            title="model size:", tooltip=get_tooltip("embedding", "model_size"),
 344        )
 345        self.model_size_dropdown.currentTextChanged.connect(self._update_model_type)
 346        return layout
 347
 348    def _validate_model_type_and_custom_weights(self):
 349        # Let's get all model combination stuff into the desired `model_type` structure.
 350        self.model_type = "vit_" + self.model_size[0] + self.supported_dropdown_maps[self.model_family]
 351
 352        # For 'custom_weights', we remove the displayed text on top of the drop-down menu.
 353        if self.custom_weights:
 354            # NOTE: We prevent recursive updates for this step temporarily.
 355            self.model_family_dropdown.blockSignals(True)
 356            self.model_family_dropdown.setCurrentIndex(-1)  # This removes the displayed text.
 357            self.model_family_dropdown.update()
 358            # NOTE: And re-enable signals again.
 359            self.model_family_dropdown.blockSignals(False)
 360
 361
 362# Custom signals for managing progress updates.
class PBarSignals(QObject):
    """Qt signals for controlling a progress bar.

    The signals are connected to a progress bar in `_create_pbar_for_threadworker`.
    """
    pbar_total = Signal(int)  # Set the total of the progress bar.
    pbar_update = Signal(int)  # Passed to the 'update' method of the progress bar.
    pbar_description = Signal(str)  # Set the description of the progress bar.
    pbar_stop = Signal()  # Close the progress bar.
    pbar_reset = Signal()  # Reset the progress bar.
 369
 370
 371class InfoDialog(QtWidgets.QDialog):
 372    def __init__(self, title, message):
 373        super().__init__()
 374        self.setWindowTitle(title)
 375
 376        layout = QtWidgets.QVBoxLayout()
 377        layout.addWidget(QtWidgets.QLabel(message))
 378
 379        # Add buttons
 380        button_box = QtWidgets.QHBoxLayout()  # Use QHBoxLayout for buttons side-by-side
 381        accept_button = QtWidgets.QPushButton("OK")
 382        accept_button.clicked.connect(lambda: self.button_clicked(accept_button))  # Connect to clicked signal
 383        button_box.addWidget(accept_button)
 384
 385        cancel_button = QtWidgets.QPushButton("Cancel")
 386        cancel_button.clicked.connect(lambda: self.button_clicked(cancel_button))  # Connect to clicked signal
 387        button_box.addWidget(cancel_button)
 388
 389        layout.addLayout(button_box)
 390        self.setLayout(layout)
 391
 392    def button_clicked(self, button):
 393        if button.text() == "OK":
 394            self.accept()  # Accept the dialog
 395        else:
 396            self.reject()  # Reject the dialog (Cancel)
 397
 398
 399# Set up the progress bar. We handle this via custom signals that are passed as callbacks to the
 400# function that does the actual work. We need callbacks for initializing the progress bar,
 401# updating it and for stopping the progress bar.
 402def _create_pbar_for_threadworker():
 403    pbar = progress()
 404    pbar_signals = PBarSignals()
 405    pbar_signals.pbar_total.connect(lambda total: setattr(pbar, "total", total))
 406    pbar_signals.pbar_update.connect(lambda update: pbar.update(update))
 407    pbar_signals.pbar_description.connect(lambda description: pbar.set_description(description))
 408    pbar_signals.pbar_stop.connect(lambda: pbar.close())
 409    pbar_signals.pbar_reset.connect(lambda: pbar.reset())
 410    return pbar, pbar_signals
 411
 412
 413def _reset_tracking_state(viewer):
 414    """Reset the tracking state.
 415
 416    This helper function is needed by the widgets clear_track and by commit_track.
 417    """
 418    state = AnnotatorState()
 419
 420    # Reset the lineage and track id.
 421    state.current_track_id = 1
 422    state.lineage = {1: []}
 423
 424    # Reset the layer properties.
 425    viewer.layers["point_prompts"].property_choices["track_id"] = ["1"]
 426    viewer.layers["prompts"].property_choices["track_id"] = ["1"]
 427
 428    # Reset the choices in the track_id menu.
 429    state.widgets["tracking"][1].value = "1"
 430    state.widgets["tracking"][1].choices = ["1"]
 431
 432
 433#
 434# Widgets implemented with magicgui.
 435#
 436
 437
# NOTE(review): the docstring is left unchanged because magicgui may read it
# to build the widget — confirm before editing its text.
@magic_factory(call_button="Clear Annotations [Shift + C]")
def clear(viewer: "napari.viewer.Viewer") -> None:
    """Widget for clearing the current annotations.

    Args:
        viewer: The napari viewer.
    """
    # Delegate the clearing to the annotator utilities.
    vutil.clear_annotations(viewer)

    # Perform garbage collection.
    # Presumably this releases memory held by the cleared annotations — TODO confirm.
    gc.collect()
 449
 450
 451@magic_factory(call_button="Clear Annotations [Shift + C]")
 452def clear_volume(viewer: "napari.viewer.Viewer", all_slices: bool = True) -> None:
 453    """Widget for clearing the current annotations in 3D.
 454
 455    Args:
 456        viewer: The napari viewer.
 457        all_slices: Choose whether to clear the annotations for all or only the current slice.
 458    """
 459    if all_slices:
 460        vutil.clear_annotations(viewer)
 461    else:
 462        i = int(viewer.dims.point[0])
 463        vutil.clear_annotations_slice(viewer, i=i)
 464
 465    # Perform garbage collection.
 466    gc.collect()
 467
 468
 469@magic_factory(call_button="Clear Annotations [Shift + C]")
 470def clear_track(viewer: "napari.viewer.Viewer", all_frames: bool = True) -> None:
 471    """Widget for clearing all tracking annotations and state.
 472
 473    Args:
 474        viewer: The napari viewer.
 475        all_frames: Choose whether to clear the annotations for all or only the current frame.
 476    """
 477    if all_frames:
 478        _reset_tracking_state(viewer)
 479        vutil.clear_annotations(viewer)
 480    else:
 481        i = int(viewer.dims.point[0])
 482        vutil.clear_annotations_slice(viewer, i=i)
 483
 484    # Perform garbage collection.
 485    gc.collect()
 486
 487
 488def _mask_matched_objects(seg, prev_seg, preservation_threshold):
 489    prev_ids = np.unique(prev_seg)
 490    ovlp = segmentation_overlap(prev_seg, seg)
 491
 492    mask_ids, prev_mask_ids = [], []
 493    for prev_id in prev_ids:
 494        ovlp_table = ovlp.overlaps_for_label_a(prev_id)
 495        seg_ids, overlaps = ovlp_table["label"], ovlp_table["count"]
 496        if seg_ids[0] != 0 and overlaps[0] >= preservation_threshold:
 497            mask_ids.append(seg_ids[0])
 498            prev_mask_ids.append(prev_id)
 499
 500    preserve_mask = np.logical_or(np.isin(seg, mask_ids), np.isin(prev_seg, prev_mask_ids))
 501    return preserve_mask
 502
 503
 504def _commit_impl(viewer, layer, preserve_mode, preservation_threshold):
 505    state = AnnotatorState()
 506
 507    # Check whether all layers exist as expected or create new ones automatically.
 508    state.annotator._require_layers(layer_choices=[layer, "committed_objects"])
 509
 510    # Check if we have a z_range. If yes, use it to set a bounding box.
 511    if state.z_range is None:
 512        bb = np.s_[:]
 513    else:
 514        z_min, z_max = state.z_range
 515        bb = np.s_[z_min:(z_max+1)]
 516
 517    # Cast the dtype of the segmentation we work with correctly.
 518    # Otherwise we run into type conversion errors later.
 519    dtype = viewer.layers["committed_objects"].data.dtype
 520    seg = viewer.layers[layer].data[bb].astype(dtype)
 521    shape = seg.shape
 522
 523    # We parallelize these operations because they take quite long for large volumes.
 524
 525    # Compute the max id in the commited objects.
 526    # id_offset = int(viewer.layers["committed_objects"].data.max())
 527    full_shape = viewer.layers["committed_objects"].data.shape
 528    id_offset = int(
 529        elf.parallel.max(viewer.layers["committed_objects"].data, block_shape=util.get_block_shape(full_shape))
 530    )
 531
 532    # Compute the mask for the current object.
 533    # mask = seg != 0
 534    mask = np.zeros(seg.shape, dtype="bool")
 535    mask = elf.parallel.apply_operation(
 536        seg, 0, np.not_equal, out=mask, block_shape=util.get_block_shape(shape)
 537    )
 538    if preserve_mode != "none":
 539        prev_seg = viewer.layers["committed_objects"].data[bb]
 540        # The mode 'pixels' corresponds to a naive implementation where only committed pixels are preserved.
 541        preserve_mask = prev_seg != 0
 542        # If the preserve mask is empty we don't need to do anything else here, because we don't have prev objects.
 543        if preserve_mask.sum() != 0:
 544            # In the mode 'objects' we preserve committed objects instead, by comparing the overlaps
 545            # of already committed and newly committed objects.
 546            if preserve_mode == "objects":
 547                preserve_mask = _mask_matched_objects(seg, prev_seg, preservation_threshold)
 548            mask[preserve_mask] = 0
 549
 550    # Write the current object to committed objects.
 551    seg[mask] += id_offset
 552    viewer.layers["committed_objects"].data[bb][mask] = seg[mask]
 553    viewer.layers["committed_objects"].refresh()
 554
 555    return id_offset, seg, mask, bb
 556
 557
 558def _get_auto_segmentation_options(state, object_ids):
 559    widget = state.widgets["autosegment"]
 560
 561    segmentation_options = {"object_ids": [int(object_id) for object_id in object_ids]}
 562    if widget.with_decoder:
 563        segmentation_options["boundary_distance_thresh"] = widget.boundary_distance_thresh
 564        segmentation_options["center_distance_thresh"] = widget.center_distance_thresh
 565    else:
 566        segmentation_options["pred_iou_thresh"] = widget.pred_iou_thresh
 567        segmentation_options["stability_score_thresh"] = widget.stability_score_thresh
 568        segmentation_options["box_nms_thresh"] = widget.box_nms_thresh
 569
 570    segmentation_options["min_object_size"] = widget.min_object_size
 571    if widget.volumetric:
 572        segmentation_options["apply_to_volume"] = widget.apply_to_volume
 573        segmentation_options["gap_closing"] = widget.gap_closing
 574        segmentation_options["min_extent"] = widget.min_extent
 575
 576    return segmentation_options
 577
 578
 579def _get_promptable_segmentation_options(state, object_ids):
 580    segmentation_options = {"object_ids": [int(object_id) for object_id in object_ids]}
 581    is_tracking = False
 582    if "segment_nd" in state.widgets:
 583        widget = state.widgets["segment_nd"]
 584        segmentation_options["projection"] = widget.projection
 585        segmentation_options["iou_threshold"] = widget.iou_threshold
 586        segmentation_options["box_extension"] = widget.box_extension
 587        if widget.tracking:
 588            segmentation_options["motion_smoothing"] = widget.motion_smoothing
 589            is_tracking = True
 590    return segmentation_options, is_tracking
 591
 592
 593def _commit_to_file(path, viewer, layer, seg, mask, bb, extra_attrs=None):
 594
 595    if z5py is None:
 596        raise RuntimeError(
 597            "Committing annotations to file requires z5py, which is only available via conda. "
 598            "Install it with 'conda install -c conda-forge z5py'."
 599        )
 600
 601    # NOTE: zarr-python is quite inefficient and writes empty blocks.
 602    # So we have to use z5py here.
 603
 604    # Deal with issues z5py has with empty folders and require the json.
 605    if os.path.exists(path):
 606        required_json = os.path.join(path, ".zgroup")
 607        if not os.path.exists(required_json):
 608            with open(required_json, "w") as f:
 609                json.dump({"zarr_format": 2}, f)
 610
 611    f = z5py.ZarrFile(path, "a")
 612    state = AnnotatorState()
 613
 614    def _save_signature(f, data_signature):
 615        embeds = state.widgets["embeddings"]
 616        tile_shape, halo = _process_tiling_inputs(embeds.tile_x, embeds.tile_y, embeds.halo_x, embeds.halo_y)
 617        signature = util._get_embedding_signature(
 618            input_=None,  # We don't need this because we pass the data signature.
 619            predictor=state.predictor,
 620            tile_shape=tile_shape,
 621            halo=halo,
 622            data_signature=data_signature,
 623        )
 624        for key, val in signature.items():
 625            f.attrs[key] = val
 626
 627    # If the data signature is saved in the file already,
 628    # then we check if saved data signature and data signature of our image agree.
 629    # If not, this file was used for committing objects from another file.
 630    if "data_signature" in f.attrs:
 631        saved_signature = f.attrs["data_signature"]
 632        current_signature = state.data_signature
 633        if saved_signature != current_signature:  # Signatures disagree.
 634            msg = f"The commit_path {path} was already used for saving annotations for different image data:\n"
 635            msg += f"The data signatures are different: {saved_signature} != {current_signature}.\n"
 636            msg += "Press 'Ok' to remove the data already stored in that file and continue annotation.\n"
 637            msg += "Otherwise please select a different file path."
 638            skip_clear = _generate_message("info", msg)
 639            if skip_clear:
 640                return
 641            else:
 642                f = z5py.ZarrFile(path, "w")
 643                _save_signature(f, current_signature)
 644    # Otherwise (data signature not saved yet), write the current signature.
 645    else:
 646        _save_signature(f, state.data_signature)
 647
 648    # Write the segmentation.
 649    full_shape = viewer.layers["committed_objects"].data.shape
 650    block_shape = util.get_block_shape(full_shape)
 651    ds = f.require_dataset(
 652        "committed_objects", shape=full_shape, chunks=block_shape, compression="gzip", dtype=seg.dtype
 653    )
 654    ds.n_threads = mp.cpu_count()
 655    data = ds[bb]
 656    data[mask] = seg[mask]
 657    ds[bb] = data
 658
 659    # Write additional information to attrs.
 660    if extra_attrs is not None:
 661        f.attrs.update(extra_attrs)
 662
 663    # Get the commit history and the objects that are being commited.
 664    commit_history = f.attrs.get("commit_history", [])
 665    object_ids = np.unique(seg[mask])
 666
 667    # We committed an automatic segmentation.
 668    if layer == "auto_segmentation":
 669        # Save the settings of the segmentation widget.
 670        segmentation_options = _get_auto_segmentation_options(state, object_ids)
 671        commit_history.append({"auto_segmentation": segmentation_options})
 672
 673        # Write the commit history.
 674        f.attrs["commit_history"] = commit_history
 675
 676        # If we run commit from the automatic segmentation we don't have
 677        # any prompts and so don't need to commit anything else.
 678        return
 679
 680    segmentation_options, is_tracking = _get_promptable_segmentation_options(state, object_ids)
 681    commit_history.append({"current_object": segmentation_options})
 682
 683    def write_prompts(object_id, prompts, point_prompts, point_labels, track_state=None):
 684        g = f.create_group(f"prompts/{object_id}")
 685        if prompts is not None and len(prompts) > 0:
 686            data = np.array(prompts)
 687            g.create_dataset("prompts", data=data, shape=data.shape, chunks=data.shape)
 688        if point_prompts is not None and len(point_prompts) > 0:
 689            g.create_dataset("point_prompts", data=point_prompts, shape=data.shape, chunks=point_prompts.shape)
 690            ds = g.create_dataset("point_labels", data=point_labels, shape=data.shape, chunks=point_labels.shape)
 691            if track_state is not None:
 692                ds.attrs["track_state"] = track_state.tolist()
 693
 694    # Get the prompts from the layers.
 695    prompts = viewer.layers["prompts"].data
 696    point_layer = viewer.layers["point_prompts"]
 697    point_prompts = point_layer.data
 698    point_labels = point_layer.properties["label"]
 699    if len(point_prompts) > 0:
 700        point_labels = np.array([1 if label == "positive" else 0 for label in point_labels])
 701        assert len(point_prompts) == len(point_labels), \
 702            f"Number of point prompts and labels disagree: {len(point_prompts)} != {len(point_labels)}"
 703
 704    # Commit the prompts for all the objects in the commit.
 705    if len(object_ids) == 1:  # We only have a single object.
 706        write_prompts(object_ids[0], prompts, point_prompts, point_labels)
 707
 708    elif is_tracking:  # We have multiple objects from tracking a lineage with divisions.
 709        track_ids_points = np.array(point_layer.properties["track_id"])
 710        track_ids_prompts = np.array(viewer.layers["prompts"].properties["track_id"])
 711
 712        unique_track_ids = np.unique(track_ids_points)
 713        assert len(unique_track_ids) == len(object_ids)
 714        track_state = np.array(point_layer.properties["state"])
 715        for track_id, object_id in zip(unique_track_ids, object_ids):
 716            this_prompts = None if len(prompts) == 0 else prompts[track_ids_prompts == track_id]
 717            point_mask = track_ids_points == track_id
 718            this_points, this_labels, this_track_state = \
 719                point_prompts[point_mask], point_labels[point_mask], track_state[point_mask]
 720            write_prompts(object_id, this_prompts, this_points, this_labels, track_state=this_track_state)
 721
 722    else:  # We have multiple objects, which are the result from batched interactive segmentation.
 723        # Note: we can't match exact object ids to their prompts, for batched segmentation.
 724        # We first write the objects from box prompts, then from point prompts.
 725        n_prompts, n_points = len(prompts), len(point_prompts)
 726        assert n_prompts + n_points == len(object_ids), \
 727            f"Number of prompts and objects disagree: {n_prompts} + {n_points} != {len(object_ids)}"
 728        for i, object_id in enumerate(object_ids):
 729            if i < n_prompts:
 730                this_prompts, this_points, this_labels = prompts[i:i+1], None, None
 731            else:
 732                j = i - n_prompts
 733                this_prompts, this_points, this_labels = None, point_prompts[j:j+1], point_labels[j:j+1]
 734            write_prompts(object_id, this_prompts, this_points, this_labels)
 735
 736    # Write the commit history.
 737    f.attrs["commit_history"] = commit_history
 738
 739
 740@magic_factory(
 741    call_button="Commit [C]",
 742    layer={"choices": ["current_object", "auto_segmentation"], "tooltip": get_tooltip("commit", "layer")},
 743    preserve_mode={"choices": ["objects", "pixels", "none"], "tooltip": get_tooltip("commit", "preserve_mode")},
 744    commit_path={"mode": "d", "tooltip": get_tooltip("commit", "commit_path")},
 745)
 746def commit(
 747    viewer: "napari.viewer.Viewer",
 748    layer: str = "current_object",
 749    preserve_mode: str = "objects",
 750    preservation_threshold: float = 0.75,
 751    commit_path: Optional[Path] = None,
 752) -> None:
 753    """Widget for committing the segmented objects from automatic or interactive segmentation.
 754
 755    Args:
 756        viewer: The napari viewer.
 757        layer: Select the layer to commit. Can be either 'current_object' to commit interacitve segmentation results.
 758            Or 'auto_segmentation' to commit automatic segmentation results.
 759        preserve_mode: The mode for preserving already committed objects, in order to prevent over-writing
 760            them by a new commit. Supports the modes 'objects', which preserves on the object level and is the default,
 761            'pixels', which preserves on the pixel-level, or 'none', which does not preserve commited objects.
 762        preservation_threshold: The overlap threshold for preserving objects. This is only used if
 763            preservation_mode is set to 'objects'.
 764        commit_path: Select a file path where the committed results and prompts will be saved.
 765            This feature is still experimental.
 766    """
 767    # Commit the segmentation layer.
 768    _, seg, mask, bb = _commit_impl(viewer, layer, preserve_mode, preservation_threshold)
 769
 770    if commit_path is not None:
 771        _commit_to_file(commit_path, viewer, layer, seg, mask, bb)
 772
 773    if layer == "current_object":
 774        vutil.clear_annotations(viewer)
 775    else:
 776        viewer.layers["auto_segmentation"].data = np.zeros(
 777            viewer.layers["auto_segmentation"].data.shape, dtype="uint32"
 778        )
 779        viewer.layers["auto_segmentation"].refresh()
 780        _select_layer(viewer, "committed_objects")
 781
 782    # Perform garbage collection
 783    gc.collect()
 784
 785
 786@magic_factory(
 787    call_button="Commit [C]",
 788    layer={"choices": ["current_object", "auto_segmentation"]},
 789    preserve_mode={"choices": ["objects", "pixels", "none"]},
 790    commit_path={"mode": "d"},  # choose a directory
 791)
 792def commit_track(
 793    viewer: "napari.viewer.Viewer",
 794    layer: str = "current_object",
 795    preserve_mode: str = "objects",
 796    preservation_threshold: float = 0.75,
 797    commit_path: Optional[Path] = None,
 798) -> None:
 799    """Widget for committing the objects from interactive tracking.
 800
 801    Args:
 802        viewer: The napari viewer.
 803        layer: Select the layer to commit. Can be either 'current_object' to commit interacitve segmentation results.
 804            Or 'auto_segmentation' to commit automatic segmentation results.
 805        preserve_mode: The mode for preserving already committed objects, in order to prevent over-writing
 806            them by a new commit. Supports the modes 'objects', which preserves on the object level and is the default,
 807            'pixels', which preserves on the pixel-level, or 'none', which does not preserve commited objects.
 808        preservation_threshold: The overlap threshold for preserving objects. This is only used if
 809            preservation_mode is set to 'objects'.
 810        commit_path: Select a file path where the committed results and prompts will be saved.
 811            This feature is still experimental.
 812    """
 813    # Commit the segmentation layer.
 814    id_offset, seg, mask, bb = _commit_impl(viewer, layer, preserve_mode, preservation_threshold)
 815
 816    # Update the lineages.
 817    state = AnnotatorState()
 818    lineage = state.lineage
 819
 820    if isinstance(lineage, list):  # This is a list of lineages from auto-tracking.
 821        assert id_offset == 0
 822        assert len(state.committed_lineages) == 0
 823        state.committed_lineages.extend(lineage)
 824    else:  # This is a single lineage from interactive tracking.
 825        updated_lineage = {
 826            parent + id_offset: [child + id_offset for child in children] for parent, children in state.lineage.items()
 827        }
 828        state.committed_lineages.append(updated_lineage)
 829
 830    if commit_path is not None:
 831        _commit_to_file(
 832            commit_path, viewer, layer, seg, mask, bb,
 833            extra_attrs={"committed_lineages": state.committed_lineages}
 834        )
 835
 836    if layer == "current_object":
 837        vutil.clear_annotations(viewer)
 838
 839    # Create / update the tracking layer.
 840    layer_name = "tracks"
 841    segmentation = viewer.layers["committed_objects"].data
 842    track_data, parent_graph = get_napari_track_data(segmentation, state.committed_lineages)
 843    if layer_name in viewer.layers:
 844        layer = viewer.layers[layer_name]
 845        layer.data = track_data
 846        layer.graph = parent_graph
 847    else:
 848        viewer.add_tracks(track_data, name=layer_name, graph=parent_graph)
 849
 850    # Reset the tracking state.
 851    _reset_tracking_state(viewer)
 852
 853    # Perform garbage collection.
 854    gc.collect()
 855
 856
 857def create_prompt_menu(points_layer, labels, menu_name="prompt", label_name="label"):
 858    """Create the menu for toggling point prompt labels."""
 859    label_menu = ComboBox(label=menu_name, choices=labels, tooltip=get_tooltip("prompt_menu", "labels"))
 860    label_widget = Container(widgets=[label_menu])
 861
 862    def update_label_menu(event):
 863        new_label = str(points_layer.current_properties[label_name][0])
 864        if new_label != label_menu.value:
 865            label_menu.value = new_label
 866
 867    points_layer.events.current_properties.connect(update_label_menu)
 868
 869    def label_changed(new_label):
 870        current_properties = points_layer.current_properties
 871        current_properties[label_name] = np.array([new_label])
 872        points_layer.current_properties = current_properties
 873        points_layer.refresh_colors()
 874
 875    label_menu.changed.connect(label_changed)
 876
 877    return label_widget
 878
 879
 880@magic_factory(
 881    call_button="Update settings",
 882    cache_directory={"mode": "d"},  # choose a directory
 883)
 884def settings_widget(cache_directory: Optional[Path] = util.get_cache_directory()) -> None:
 885    """Widget to update global micro_sam settings.
 886
 887    Args:
 888        cache_directory: Select the path for the micro_sam cache directory. `$HOME/.cache/micro_sam`.
 889    """
 890    os.environ["MICROSAM_CACHEDIR"] = str(cache_directory)
 891    print(f"micro-sam cache directory set to: {cache_directory}")
 892
 893
 894def _generate_message(message_type: str, message: str) -> bool:
 895    """
 896    Displays a message dialog based on the provided message type.
 897
 898    Args:
 899        message_type: The type of message to display. Valid options are:
 900            - "error": Displays a critical error message with an "Ok" button.
 901            - "info": Displays an informational message in a separate dialog box.
 902                 The user can dismiss it by either clicking "Ok" or closing the dialog.
 903        message: The message content to be displayed in the dialog.
 904
 905    Returns:
 906        A flag indicating whether the user aborted the operation based on the
 907        message type. This flag is only set for "info" messages where the user
 908        can choose to cancel (rejected).
 909
 910    Raises:
 911        ValueError: If an invalid message type is provided.
 912    """
 913    # Set button text and behavior based on message type
 914    if message_type == "error":
 915        QtWidgets.QMessageBox.critical(None, "Error", message, QtWidgets.QMessageBox.Ok)
 916        abort = True
 917        return abort
 918    elif message_type == "info":
 919        info_dialog = InfoDialog(title="Validation Message", message=message)
 920        result = info_dialog.exec_()
 921        if result == QtWidgets.QDialog.Rejected:  # Check for cancel
 922            abort = True  # Set flag directly in calling function
 923            return abort
 924    else:
 925        raise ValueError(f"Invalid message type {message_type}")
 926
 927
 928def _validate_embeddings(viewer: "napari.viewer.Viewer"):
 929    state = AnnotatorState()
 930    if state.image_embeddings is None:
 931        msg = "Image embeddings are not yet computed. Press 'Compute Embeddings' to compute them for your image."
 932        return _generate_message("error", msg)
 933    else:
 934        return False
 935
 936    # This code is for checking the data signature of the current image layer and the data signature
 937    # of the embeddings. However, the code has some disadvantages, for example assuming the position of the
 938    # image layer and also having to compute the data signature every time.
 939    # That's why we are not using this for now, but may want to revisit this in the future. See:
 940    # https://github.com/computational-cell-analytics/micro-sam/issues/504
 941
 942    # embeddings_save_path = state.embedding_path
 943    # embedding_data_signature = None
 944    # image = None
 945    # if isinstance(viewer.layers[0], napari.layers.Image):  # Assuming the image layer is at index 0
 946    #     image = viewer.layers[0]
 947    # else:
 948    #     # Handle the case where the first layer isn't an Image layer
 949    #     raise ValueError("Expected an Image layer in viewer.layers")
 950    # img_signature = util._compute_data_signature(image.data)
 951    # if embeddings_save_path is not None:
 952    #     # Check for existing embeddings
 953    #     if os.listdir(embeddings_save_path):
 954    #         try:
 955    #             with zarr.open(embeddings_save_path, "a") as f:
 956    #                 # If data_signature exists, compare and return validation message
 957    #                 if "data_signature" in f.attrs:
 958    #                     embedding_data_signature = f.attrs["data_signature"]
 959    #         except RuntimeError as e:
 960    #             val_results = {
 961    #                 "message_type": "error",
 962    #                 "message": f"Failed to load image embeddings: {e}"
 963    #             }
 964    #     else:
 965    #         val_results = {"message_type": "info", "message": "No existing embeddings found at the specified path."}
 966    # else:  # load from state object
 967    #     embedding_data_signature = state.data_signature
 968    # # compare image data signature with embedding data signature
 969    # if img_signature != embedding_data_signature:
 970    #     val_results = {
 971    #         "message_type": "error",
 972    #         "message": f"The embeddings don't match with the image: {img_signature} {embedding_data_signature}"
 973    #     }
 974    # else:
 975    #     val_results = None
 976    # if val_results:
 977    #     return _generate_message(val_results["message_type"], val_results["message"])
 978    # else:
 979    #     return False
 980
 981
 982def _validation_window_for_missing_layer(layer_choice):
 983    if layer_choice == "committed_objects":
 984        msg = "The 'committed_objects' layer to commit masks is missing. Please try to commit again."
 985    else:
 986        msg = f"The '{layer_choice}' layer to commit is missing. Please re-annotate and try again."
 987
 988    return _generate_message(message_type="error", message=msg)
 989
 990
 991def _validate_layers(viewer: "napari.viewer.Viewer", automatic_segmentation: bool = False) -> bool:
 992    # Check whether all layers exist as expected or create new ones automatically.
 993    state = AnnotatorState()
 994    state.annotator._require_layers()
 995
 996    if not automatic_segmentation:
 997        # Check prompts layer.
 998        if len(viewer.layers["prompts"].data) == 0 and len(viewer.layers["point_prompts"].data) == 0:
 999            msg = "No prompts were given. Please provide prompts to run interactive segmentation."
1000            return _generate_message("error", msg)
1001        else:
1002            return False
1003
1004
@magic_factory(call_button="Segment Object [S]")
def segment(viewer: "napari.viewer.Viewer", batched: bool = False) -> None:
    """Segment object(s) for the current prompts.

    Args:
        viewer: The napari viewer.
        batched: Choose if you want to segment multiple objects with point prompts.
    """
    # Both validations return a truthy value if the segmentation should not be run.
    if _validate_embeddings(viewer) or _validate_layers(viewer):
        return None

    object_layer = viewer.layers["current_object"]
    shape = object_layer.data.shape

    # Get the current box and point prompts.
    boxes, masks = vutil.shape_layer_to_prompts(viewer.layers["prompts"], shape)
    points, labels = vutil.point_layer_to_prompts(viewer.layers["point_prompts"], with_stop_annotation=False)

    state = AnnotatorState()
    seg = vutil.prompt_segmentation(
        state.predictor, points, labels, boxes, masks, shape,
        image_embeddings=state.image_embeddings,
        multiple_box_prompts=True,
        batched=batched,
        previous_segmentation=object_layer.data,
    )

    # No prompts were given or the prompts were invalid, so we skip the segmentation.
    if seg is None:
        print("You either haven't provided any prompts or invalid prompts. The segmentation will be skipped.")
        return

    object_layer.data = seg
    object_layer.refresh()
1038
1039
@magic_factory(call_button="Segment Slice [S]")
def segment_slice(viewer: "napari.viewer.Viewer") -> None:
    """Segment the object in the current slice for the current prompts.

    Args:
        viewer: The napari viewer.
    """
    # Both validations return a truthy value if the segmentation should not be run.
    if _validate_embeddings(viewer) or _validate_layers(viewer):
        return None

    object_layer = viewer.layers["current_object"]
    point_layer = viewer.layers["point_prompts"]

    # The shape of a single slice.
    shape = object_layer.data.shape[1:]

    # Map the current viewer position to data coordinates to find the slice index.
    z = int(point_layer.world_to_data(viewer.dims.point)[0])

    point_prompts = vutil.point_layer_to_prompts(point_layer, z)
    # This is a stop prompt, we do nothing.
    if not point_prompts:
        return
    points, labels = point_prompts

    boxes, masks = vutil.shape_layer_to_prompts(viewer.layers["prompts"], shape, i=z)

    state = AnnotatorState()
    seg = vutil.prompt_segmentation(
        state.predictor, points, labels, boxes, masks, shape,
        multiple_box_prompts=False, image_embeddings=state.image_embeddings, i=z,
    )

    # No prompts were given or the prompts were invalid, so we skip the segmentation.
    if seg is None:
        print("You either haven't provided any prompts or invalid prompts. The segmentation will be skipped.")
        return

    object_layer.data[z] = seg
    object_layer.refresh()
1079
1080
@magic_factory(call_button="Segment Frame [S]")
def segment_frame(viewer: "napari.viewer.Viewer") -> None:
    """Segment the object of the current track in the current frame for the current prompts.

    Args:
        viewer: The napari viewer.
    """
    # Both validations return a truthy value if the segmentation should not be run.
    if _validate_embeddings(viewer) or _validate_layers(viewer):
        return None

    state = AnnotatorState()
    track_id = state.current_track_id

    # The shape of a single frame and the index of the current frame.
    shape = state.image_shape[1:]
    t = int(viewer.dims.point[0])

    point_prompts = vutil.point_layer_to_prompts(viewer.layers["point_prompts"], i=t, track_id=track_id)
    # This is a stop prompt, we do nothing.
    if not point_prompts:
        return
    points, labels = point_prompts

    boxes, masks = vutil.shape_layer_to_prompts(viewer.layers["prompts"], shape, i=t, track_id=track_id)

    seg = vutil.prompt_segmentation(
        state.predictor, points, labels, boxes, masks, shape,
        multiple_box_prompts=False, image_embeddings=state.image_embeddings, i=t
    )

    # No prompts were given or the prompts were invalid, so we skip the segmentation.
    if seg is None:
        print("You either haven't provided any prompts or invalid prompts. The segmentation will be skipped.")
        return

    object_layer = viewer.layers["current_object"]

    # Clear the old segmentation for this track id in the current frame.
    previous_mask = object_layer.data[t] == track_id
    object_layer.data[t][previous_mask] = 0

    # Write the new segmentation with the id of the current track.
    segmented_mask = seg.squeeze() == 1
    object_layer.data[t][segmented_mask] = track_id
    object_layer.refresh()
1123
1124
1125#
1126# Functionality and widget to compute the image embeddings.
1127#
1128
1129
1130def _process_tiling_inputs(tile_shape_x, tile_shape_y, halo_x, halo_y):
1131    tile_shape = (tile_shape_x, tile_shape_y)
1132    halo = (halo_x, halo_y)
1133    # check if tile_shape/halo are not set: (0, 0)
1134    if all(item in (0, None) for item in tile_shape):
1135        tile_shape = None
1136    # check if at least 1 param is given
1137    elif tile_shape[0] == 0 or tile_shape[1] == 0:
1138        max_val = max(tile_shape[0], tile_shape[1])
1139        if max_val < 256:  # at least tile shape >256
1140            max_val = 256
1141        tile_shape = (max_val, max_val)
1142    # if both inputs given, check if smaller than 256
1143    elif tile_shape[0] != 0 and tile_shape[1] != 0:
1144        if tile_shape[0] < 256:
1145            tile_shape = (256, tile_shape[1])  # Create a new tuple
1146        if tile_shape[1] < 256:
1147            tile_shape = (tile_shape[0], 256)  # Create a new tuple with modified value
1148    if all(item in (0, None) for item in halo):
1149        if tile_shape is not None:
1150            halo = (0, 0)
1151        else:
1152            halo = None
1153    # check if at least 1 param is given
1154    elif halo[0] != 0 or halo[1] != 0:
1155        max_val = max(halo[0], halo[1])
1156        # don't apply halo if there is no tiling
1157        if tile_shape is None:
1158            halo = None
1159        else:
1160            halo = (max_val, max_val)
1161    return tile_shape, halo
1162
1163
1164class EmbeddingWidget(_WidgetBase):
1165    def __init__(self, parent=None):
1166        super().__init__(parent=parent)
1167
1168        # Create a nested layout for the sections.
1169        # Section 1: Image and Model.
1170        section1_layout = QtWidgets.QHBoxLayout()
1171        section1_layout.addLayout(self._create_image_section())
1172        section1_layout.addLayout(self._create_model_section())  # Creates the model family widget section.
1173        self.layout().addLayout(section1_layout)
1174
1175        # Section 2: Settings (collapsible).
1176        self.layout().addWidget(self._create_settings_widget())
1177
1178        # Section 3: The button to trigger the embedding computation.
1179        self.run_button = QtWidgets.QPushButton("Compute Embeddings")
1180        self.run_button.clicked.connect(self._initialize_image)
1181        self.run_button.clicked.connect(self.__call__)
1182        self.run_button.setToolTip(get_tooltip("embedding", "run_button"))
1183        self.layout().addWidget(self.run_button)
1184
1185    def _initialize_image(self):
1186        state = AnnotatorState()
1187        layer = self.image_selection.get_value()
1188
1189        # This is encountered when there is no image layer available / selected.
1190        # In this case, we need not specify other image-level parameters to the state. Hence, we skip them.
1191        # NOTE: On code-level, this happens as the first step when "Compute Embedding" click is triggered.
1192        if layer is None:
1193            return
1194
1195        image_shape = layer.data.shape
1196        image_scale = tuple(layer.scale)
1197        state.image_shape = image_shape
1198        state.image_scale = image_scale
1199        state.image_name = layer.name
1200
1201    def _create_image_section(self):
1202        image_section = QtWidgets.QVBoxLayout()
1203        image_layer_widget = QtWidgets.QLabel("Image Layer:")
1204        # image_layer_widget.setToolTip(get_tooltip("embedding", "image")) #  this adds tooltip to label
1205        image_section.addWidget(image_layer_widget)
1206
1207        # Setting a napari layer in QT, see:
1208        # https://github.com/pyapp-kit/magicgui/blob/main/docs/examples/napari/napari_combine_qt.py
1209        self.image_selection = create_widget(annotation=napari.layers.Image)
1210        self.image_selection.native.setToolTip(get_tooltip("embedding", "image"))
1211        image_section.addWidget(self.image_selection.native)
1212
1213        return image_section
1214
1215    def _update_model(self, state):
1216        _model_type = state.predictor.model_type if self.custom_weights else self.model_type
1217
1218        # Provide a detailed message for the model family and model size per chosen combination.
1219        msg = "Computed embeddings for "
1220        if self.custom_weights:  # Whether the user provided a filepath to custom finetuned model weights.
1221            msg += f"the model located at '{os.path.abspath(self.custom_weights)}' "
1222            msg += f"of size '{self._model_size_map[_model_type[4]]}'."
1223        else:
1224            msg += f"the '{self.model_family}' model of size '{self.model_size}'."
1225
1226        show_info(msg)
1227
1228        state = AnnotatorState()
1229        # Update the widget itself. This is necessary because we may have loaded
1230        # some settings from the embedding file and have to reflect them in the widget.
1231        vutil._sync_embedding_widget(
1232            self,
1233            model_type=_model_type,
1234            save_path=self.embeddings_save_path,
1235            checkpoint_path=self.custom_weights,
1236            device=self.device,
1237            tile_shape=[self.tile_x, self.tile_y],
1238            halo=[self.halo_x, self.halo_y]
1239        )
1240
1241        # Set the default settings for this model in the autosegment widget if it is part of
1242        # the currently used plugin.
1243        if "autosegment" in state.widgets:
1244            with_decoder = state.decoder is not None
1245            vutil._sync_autosegment_widget(
1246                state.widgets["autosegment"], _model_type, self.custom_weights, update_decoder=with_decoder
1247            )
1248            # Load the AMG/AIS state if we have a 3d segmentation plugin.
1249            if state.widgets["autosegment"].volumetric and with_decoder:
1250                state.amg_state = vutil._load_is_state(state.embedding_path)
1251            elif state.widgets["autosegment"].volumetric and not with_decoder:
1252                state.amg_state = vutil._load_amg_state(state.embedding_path)
1253
1254        # Set the default settings for this model in the nd-segmentation widget if it is part of
1255        # the currently used plugin.
1256        if "segment_nd" in state.widgets:
1257            vutil._sync_ndsegment_widget(state.widgets["segment_nd"], _model_type, self.custom_weights)
1258
1259    def _create_settings_widget(self):
1260        setting_values = QtWidgets.QWidget()
1261        setting_values.setToolTip(get_tooltip("embedding", "settings"))
1262        setting_values.setLayout(QtWidgets.QVBoxLayout())
1263
1264        # Add the model size widget section.
1265        layout = self._create_model_size_section()
1266        setting_values.layout().addLayout(layout)
1267
1268        # Create UI for the device.
1269        self.device = "auto"
1270        device_options = ["auto"] + util._available_devices()
1271
1272        self.device_dropdown, layout = self._add_choice_param(
1273            "device", self.device, device_options, tooltip=get_tooltip("embedding", "device")
1274        )
1275        setting_values.layout().addLayout(layout)
1276
1277        # Create UI for the save path.
1278        self.embeddings_save_path = None
1279        self.embeddings_save_path_param, layout = self._add_path_param(
1280            "embeddings_save_path", self.embeddings_save_path, "directory", title="embeddings save path:",
1281            tooltip=get_tooltip("embedding", "embeddings_save_path")
1282        )
1283        setting_values.layout().addLayout(layout)
1284
1285        # Create UI for the custom weights.
1286        self.custom_weights = None
1287        self.custom_weights_param, layout = self._add_path_param(
1288            "custom_weights", self.custom_weights, "file", title="custom weights path:",
1289            tooltip=get_tooltip("embedding", "custom_weights")
1290        )
1291        setting_values.layout().addLayout(layout)
1292
1293        # Create UI for the tile shape.
1294        self.tile_x, self.tile_y = 0, 0
1295        self.tile_x_param, self.tile_y_param, layout = self._add_shape_param(
1296            ("tile_x", "tile_y"), (self.tile_x, self.tile_y), min_val=0, max_val=2048, step=16,
1297            tooltip=get_tooltip("embedding", "tiling")
1298        )
1299        setting_values.layout().addLayout(layout)
1300
1301        # Create UI for the halo.
1302        self.halo_x, self.halo_y = 0, 0
1303        self.halo_x_param, self.halo_y_param, layout = self._add_shape_param(
1304            ("halo_x", "halo_y"), (self.halo_x, self.halo_y), min_val=0, max_val=512,
1305            tooltip=get_tooltip("embedding", "halo")
1306        )
1307        setting_values.layout().addLayout(layout)
1308
1309        # Create UI for the choice of automatic segmentation mode.
1310        self.automatic_segmentation_mode = "auto"
1311        auto_seg_options = ["auto", "amg", "ais"]
1312        self.automatic_segmentation_mode_dropdown, layout = self._add_choice_param(
1313            "automatic_segmentation_mode", self.automatic_segmentation_mode, auto_seg_options,
1314            title="automatic segmentation mode", tooltip=get_tooltip("embedding", "automatic_segmentation_mode")
1315        )
1316        setting_values.layout().addLayout(layout)
1317
1318        settings = _make_collapsible(setting_values, title="Embedding Settings")
1319        return settings
1320
1321    def _validate_inputs(self):
1322        """Validates the inputs for the annotation process and returns a dictionary
1323        containing information for message generation, or False if no messages are needed.
1324
1325        This function performs the following checks:
1326
1327        - If an `embeddings_save_path` is provided:
1328            - Validates the image data signature by comparing it with the signature
1329            of the image data in the viewer's selection.
1330            - Checks for existing embeddings at the specified path.
1331                - If existing embeddings are found, it attempts to load parameters
1332                like tile shape, halo, and model type from the Zarr attributes.
1333                - An informational message is generated based on the loaded parameters.
1334                - If loading existing embeddings fails, an error message is generated.
1335                - If no existing embeddings are found, an informational message is generated.
1336        - If no `embeddings_save_path` is provided, the function returns None.
1337
1338        Returns:
1339            bool: True if the computation should be aborted, otherwise False.
1340        """
1341
1342        # Check if we have an existing input image to compute the embeddings.
1343        image = self.image_selection.get_value()
1344        if image is None:
1345            return _generate_message("error", "No image has been selected.")
1346
1347        # Check if we have an existing embedding path.
1348        # If yes we check the data signature of these embeddings against the selected image
1349        # and we ask the user if they want to load these embeddings.
1350        if self.embeddings_save_path and os.listdir(self.embeddings_save_path):
1351            try:
1352                f = zarr.open(self.embeddings_save_path, mode="a")
1353
1354                # Validate that the embeddings are complete.
1355                # Note: 'input_size' is the last value set in the attrs of f,
1356                # so we can use it as a proxy to check if the embeddings are fully computed
1357                if "input_size" not in f.attrs:
1358                    msg = (f"The embeddings at {self.embeddings_save_path} are incomplete. "
1359                           "Specify a different path or remove them.")
1360                    return _generate_message("error", msg)
1361
1362                # Validate image data signature.
1363                if "data_signature" in f.attrs:
1364                    image = self.image_selection.get_value()
1365                    img_signature = util._compute_data_signature(image.data)
1366                    if img_signature != f.attrs["data_signature"]:
1367                        msg = f"The embeddings don't match with the image: {img_signature} {f.attrs['data_signature']}"
1368                        return _generate_message("error", msg)
1369
1370                # Load existing parameters.
1371                self.model_type = f.attrs.get("model_name", f.attrs["model_type"])
1372                if "tile_shape" in f.attrs and f.attrs["tile_shape"] is not None:
1373                    self.tile_x, self.tile_y = f.attrs["tile_shape"]
1374                    self.halo_x, self.halo_y = f.attrs["halo"]
1375                    val_results = {
1376                        "message_type": "info",
1377                        "message": (f"Load embeddings for model: {self.model_type} with tile shape: "
1378                                    f"{self.tile_x}, {self.tile_y} and halo: {self.halo_x}, {self.halo_y}.")
1379                    }
1380                else:
1381                    self.tile_x, self.tile_y = 0, 0
1382                    self.halo_x, self.halo_y = 0, 0
1383                    val_results = {
1384                        "message_type": "info",
1385                        "message": f"Load embeddings for model: {self.model_type}."
1386                    }
1387
1388                return _generate_message(val_results["message_type"], val_results["message"])
1389
1390            except RuntimeError as e:
1391                val_results = {
1392                    "message_type": "error",
1393                    "message": f"Failed to load image embeddings: {e}"
1394                }
1395                return _generate_message(val_results["message_type"], val_results["message"])
1396
1397        # Otherwise we either don't have an embedding path or it is empty. We can proceed in both cases.
1398        return False
1399
1400    def _validate_existing_embeddings(self, state):
1401        if state.image_embeddings is None:
1402            return False
1403        else:
1404            val_results = {
1405                "message_type": "info",
1406                "message": "Embeddings have already been precomputed. Press OK to recompute the embeddings."
1407            }
1408            return _generate_message(val_results["message_type"], val_results["message"])
1409
1410    def __call__(self, skip_validate=False):
1411        self._validate_model_type_and_custom_weights()
1412
1413        # Validate user inputs.
1414        if not skip_validate and self._validate_inputs():
1415            return
1416
1417        # Get the image.
1418        image = self.image_selection.get_value()
1419
1420        # Update the image embeddings:
1421        state = AnnotatorState()
1422        if self._validate_existing_embeddings(state):
1423            # Whether embeddings already exist to control existing objects in layers.
1424            state.skip_recomputing_embeddings = True
1425            return
1426
1427        state.skip_recomputing_embeddings = False
1428        # Reset the state.
1429        state.reset_state()
1430
1431        # Get image dimensions.
1432        if image.rgb:
1433            ndim = image.data.ndim - 1
1434            state.image_shape = image.data.shape[:-1]
1435        else:
1436            ndim = image.data.ndim
1437            state.image_shape = image.data.shape
1438
1439        # Set layer scale
1440        state.image_scale = tuple(image.scale)
1441
1442        # Process tile_shape and halo, set other data.
1443        tile_shape, halo = _process_tiling_inputs(self.tile_x, self.tile_y, self.halo_x, self.halo_y)
1444        save_path = None if self.embeddings_save_path == "" else self.embeddings_save_path
1445        image_data = image.data
1446
1447        # Set up progress bar and signals for using it within a threadworker.
1448        pbar, pbar_signals = _create_pbar_for_threadworker()
1449
1450        # @thread_worker()
1451        def compute_image_embedding():
1452
1453            def pbar_init(total, description):
1454                pbar_signals.pbar_total.emit(total)
1455                pbar_signals.pbar_description.emit(description)
1456
1457            # Whether to prefer decoder.
1458            # With 'amg', it is set to 'False', else it is 'True' for the default 'auto' and 'ais' mode.
1459            prefer_decoder = True
1460            if self.automatic_segmentation_mode == "amg":
1461                prefer_decoder = False
1462
1463            state.initialize_predictor(
1464                image_data, model_type=self.model_type, save_path=save_path, ndim=ndim,
1465                device=self.device, checkpoint_path=self.custom_weights, tile_shape=tile_shape, halo=halo,
1466                prefer_decoder=prefer_decoder, pbar_init=pbar_init,
1467                pbar_update=lambda update: pbar_signals.pbar_update.emit(update),
1468            )
1469            pbar_signals.pbar_stop.emit()
1470
1471        compute_image_embedding()
1472        self._update_model(state)
1473        # worker = compute_image_embedding()
1474        # worker.returned.connect(self._update_model)
1475        # worker.start()
1476        # return worker
1477
1478
1479#
1480# Functionality and widget for nd segmentation.
1481#
1482
1483
def _update_lineage(viewer):
    """Update the lineage after recording a division event.

    Two daughter tracks are registered for the current track and the track id choices
    of the tracking widget and of the prompt layers are extended accordingly.

    Args:
        viewer: The napari viewer that contains the 'point_prompts' and 'prompts' layers.
    """
    state = AnnotatorState()
    tracking_widget = state.widgets["tracking"]

    # The dividing track must be known and must not have any daughters yet.
    mother = state.current_track_id
    assert mother in state.lineage
    assert len(state.lineage[mother]) == 0

    # The daughters get the two ids that follow the id of the mother track.
    daughters = [mother + 1, mother + 2]
    state.lineage[mother] = daughters
    for daughter in daughters:
        state.lineage[daughter] = []

    # Update the choices in the track_id menu so that it contains the new track ids.
    track_ids = [str(track_id) for track_id in state.lineage.keys()]
    tracking_widget[1].choices = track_ids

    # Each prompt layer gets its own copy of the track id choices.
    for layer_name in ("point_prompts", "prompts"):
        viewer.layers[layer_name].property_choices["track_id"] = list(track_ids)
1506
1507
class SegmentNDWidget(_WidgetBase):
    """Widget for extending the prompt-based segmentation to all slices of a volume or all frames of a timeseries.

    With `tracking=True` the current track is followed across the frames,
    otherwise the current object is segmented throughout the volume.
    In both cases the result is written to the 'current_object' layer of the viewer.
    """

    def __init__(self, viewer, tracking, parent=None):
        """Create the settings and the run button.

        Args:
            viewer: The napari viewer. Its 'point_prompts', 'prompts' and 'current_object' layers are used.
            tracking: Whether to run tracking (True) or volumetric segmentation (False).
            parent: The optional parent, which is passed on to the base class.
        """
        super().__init__(parent=parent)
        self._viewer = viewer
        self.tracking = tracking

        # Add the settings.
        self.settings = self._create_settings()
        self.layout().addWidget(self.settings)

        # Add the run button.
        button_title = "Segment All Frames [Shift-S]" if self.tracking else "Segment All Slices [Shift-S]"
        self.run_button = QtWidgets.QPushButton(button_title)
        self.run_button.clicked.connect(self.__call__)
        self.layout().addWidget(self.run_button)

    def _create_settings(self):
        """Create the collapsible settings widget and set the default values of all parameters.

        Returns:
            The collapsible widget that contains the UI elements for the settings.
        """
        setting_values = QtWidgets.QWidget()
        setting_values.setToolTip(get_tooltip("segmentnd", "settings"))
        setting_values.setLayout(QtWidgets.QVBoxLayout())

        # Create the UI for the projection modes.
        self.projection = "single_point"
        self.projection_dropdown, layout = self._add_choice_param(
            "projection", self.projection, PROJECTION_MODES, tooltip=get_tooltip("segmentnd", "projection_dropdown")
        )
        setting_values.layout().addLayout(layout)

        # Create the UI element for the IOU threshold.
        self.iou_threshold = 0.5
        self.iou_threshold_param, layout = self._add_float_param(
            "iou_threshold", self.iou_threshold, tooltip=get_tooltip("segmentnd", "iou_threshold")
        )
        setting_values.layout().addLayout(layout)

        # Create the UI element for the box extension.
        self.box_extension = 0.05
        self.box_extension_param, layout = self._add_float_param(
            "box_extension", self.box_extension, tooltip=get_tooltip("segmentnd", "box_extension")
        )
        setting_values.layout().addLayout(layout)

        # Create the UI element for the motion smoothing (if we have the tracking widget).
        # Note that 'self.motion_smoothing' only exists in this case; it is only read by '_run_tracking'.
        if self.tracking:
            self.motion_smoothing = 0.5
            self.motion_smoothing_param, layout = self._add_float_param(
                "motion_smoothing", self.motion_smoothing, tooltip=get_tooltip("segmentnd", "motion_smoothing")
            )
            setting_values.layout().addLayout(layout)

        settings = _make_collapsible(setting_values, title="Segmentation Settings")
        return settings

    def _run_tracking(self):
        """Track the current object across all frames and write the result to the 'current_object' layer.

        If a division is detected for a track that does not have daughters yet, the lineage is updated.
        """
        state = AnnotatorState()
        pbar, pbar_signals = _create_pbar_for_threadworker()

        # @thread_worker
        def tracking_impl():
            shape = state.image_shape

            pbar_signals.pbar_total.emit(shape[0])
            pbar_signals.pbar_description.emit("Track object")

            # Step 1: Segment all slices with prompts.
            seg, slices, _, stop_upper = vutil.segment_slices_with_prompts(
                state.predictor, self._viewer.layers["point_prompts"], self._viewer.layers["prompts"],
                state.image_embeddings, shape, track_id=state.current_track_id,
                update_progress=lambda update: pbar_signals.pbar_update.emit(update),
            )

            # Step 2: Track the object starting from the lowest annotated slice.
            seg, has_division = vutil.track_from_prompts(
                self._viewer.layers["point_prompts"], self._viewer.layers["prompts"], seg,
                state.predictor, slices, state.image_embeddings, stop_upper,
                threshold=self.iou_threshold, projection=self.projection,
                motion_smoothing=self.motion_smoothing,
                box_extension=self.box_extension,
                update_progress=lambda update: pbar_signals.pbar_update.emit(update),
            )

            pbar_signals.pbar_stop.emit()
            return seg, has_division

        def update_segmentation(ret_val):
            seg, has_division = ret_val
            # If a division has occurred and it's the first time it occurred for this track
            # then we need to create the two daughter tracks and update the lineage.
            if has_division and (len(state.lineage[state.current_track_id]) == 0):
                _update_lineage(self._viewer)

            # Clear the old track mask.
            self._viewer.layers["current_object"].data[
                self._viewer.layers["current_object"].data == state.current_track_id
            ] = 0
            # Set the new object mask.
            self._viewer.layers["current_object"].data[seg == 1] = state.current_track_id
            self._viewer.layers["current_object"].refresh()

        # The thread workers are disabled for now, so the implementation is run synchronously.
        ret_val = tracking_impl()
        update_segmentation(ret_val)
        # worker = tracking_impl()
        # worker.returned.connect(update_segmentation)
        # worker.start()
        # return worker

    def _run_volumetric_segmentation(self):
        """Segment the current object in the volume and write the result to the 'current_object' layer.

        The z-range returned by the volumetric segmentation is stored in `state.z_range`.
        """
        pbar, pbar_signals = _create_pbar_for_threadworker()

        # @thread_worker
        def volumetric_segmentation_impl():
            state = AnnotatorState()
            shape = state.image_shape

            pbar_signals.pbar_total.emit(shape[0])
            pbar_signals.pbar_description.emit("Segment object")

            # Step 1: Segment all slices with prompts.
            seg, slices, stop_lower, stop_upper = vutil.segment_slices_with_prompts(
                state.predictor, self._viewer.layers["point_prompts"], self._viewer.layers["prompts"],
                state.image_embeddings, shape,
                update_progress=lambda update: pbar_signals.pbar_update.emit(update),
            )

            # Step 2: Segment the rest of the volume based on projecting prompts.
            seg, (z_min, z_max) = segment_mask_in_volume(
                seg, state.predictor, state.image_embeddings, slices,
                stop_lower, stop_upper,
                iou_threshold=self.iou_threshold, projection=self.projection,
                box_extension=self.box_extension,
                update_progress=lambda update: pbar_signals.pbar_update.emit(update),
            )
            pbar_signals.pbar_stop.emit()

            state.z_range = (z_min, z_max)
            return seg

        def update_segmentation(seg):
            self._viewer.layers["current_object"].data = seg
            self._viewer.layers["current_object"].refresh()

        # The thread workers are disabled for now, so the implementation is run synchronously.
        # The layer is updated via the same callback that the worker would use.
        seg = volumetric_segmentation_impl()
        update_segmentation(seg)
        # worker = volumetric_segmentation_impl()
        # worker.returned.connect(update_segmentation)
        # worker.start()
        # return worker

    def __call__(self):
        """Run tracking or volumetric segmentation, depending on `self.tracking`.

        Nothing is run if `_validate_embeddings` or `_validate_layers` returns a truthy value.

        Returns:
            The return value of `_run_tracking` or `_run_volumetric_segmentation`,
            which is currently None since the thread workers are disabled.
        """
        if _validate_embeddings(self._viewer):
            return None
        if _validate_layers(self._viewer):
            return None

        if self.tracking:
            return self._run_tracking()
        else:
            return self._run_volumetric_segmentation()
1667
1668
1669#
1670# The functionality and widgets for automatic segmentation.
1671#
1672
1673
1674# Messy amg state handling, would be good to refactor this properly at some point.
def _handle_amg_state(state, i, pbar_init, pbar_update):
    """Make sure that `state.amg` exists and is initialized for the image or for the slice `i`.

    Args:
        state: The annotator state. Its `amg`, `amg_state`, `predictor`, `decoder`,
            `image_embeddings` and `image_shape` attributes are used.
        i: The index of the slice to initialize. It must not be None if `state.amg_state` is set
            and it must be None if `state.amg_state` is None and the amg still has to be initialized.
        pbar_init: Callback for initializing a progress bar or None. It is passed on to
            `state.amg.initialize`, which is called with `verbose=True` only if the callback is given.
        pbar_update: Callback for updating the progress bar, passed on to `state.amg.initialize`.
    """
    # Create the segmentation generator if it does not exist yet.
    if state.amg is None:
        state.amg = instance_segmentation.get_instance_segmentation_generator(
            state.predictor, is_tiled=state.image_embeddings["input_size"] is None, decoder=state.decoder
        )
    amg = state.amg
    image_shape = state.image_shape

    slice_states = state.amg_state

    # Without 'amg_state' (2d segmentation) we just check if the amg is initialized or not.
    if slice_states is None:
        if not amg.is_initialized:
            assert i is None
            # We don't need to pass the actual image data here, since the embeddings are passed.
            # (The image data is only used by the amg to compute image embeddings, so not needed here.)
            placeholder = np.zeros(image_shape, dtype="uint8")
            amg.initialize(
                placeholder, image_embeddings=state.image_embeddings,
                verbose=pbar_init is not None, pbar_init=pbar_init, pbar_update=pbar_update
            )
        return

    # Further optimization: refactor parts of this so that we can also use it
    # in the automatic 3d segmentation function.
    # For 3D we store the amg state in a dict and check if it is computed already.
    assert i is not None
    if i in slice_states:
        amg.set_state(slice_states[i])
        return

    # The state for this slice is not available yet, so it is computed and stored in the dict.
    placeholder = np.zeros(image_shape[-2:], dtype="uint8")
    amg.initialize(
        placeholder, image_embeddings=state.image_embeddings, i=i,
        verbose=pbar_init is not None, pbar_init=pbar_init, pbar_update=pbar_update,
    )
    current = amg.get_state()
    slice_states[i] = current

    # Save the state as a pickle file if a cache folder is registered in the dict.
    cache_folder = slice_states.get("cache_folder", None)
    if cache_folder is not None:
        with open(os.path.join(cache_folder, f"state-{i}.pkl"), "wb") as f:
            pickle.dump(current, f)

    # Save the predictions to an hdf5 file if a cache path is registered in the dict.
    cache_path = slice_states.get("cache_path", None)
    if cache_path is not None:
        with h5py.File(cache_path, "a") as f:
            group = f.create_group(f"state-{i}")
            for name in ("foreground", "boundary_distances", "center_distances"):
                group.create_dataset(name, data=current[name], compression="gzip")
1726
1727
def _instance_segmentation_impl(min_object_size, i=None, pbar_init=None, pbar_update=None, **kwargs):
    """Run the automatic instance segmentation for the image or for the slice `i`.

    Args:
        min_object_size: Not used in the body of this function.
            NOTE(review): the decoder settings pass the value as 'min_size' in `kwargs` instead.
        i: The index of the slice to segment, or None for 2d segmentation.
        pbar_init: Optional callback for initializing a progress bar, passed to `_handle_amg_state`.
        pbar_update: Optional callback for updating the progress bar, passed to `_handle_amg_state`.
        kwargs: The keyword arguments for `state.amg.generate`.

    Returns:
        The result of `state.amg.generate`, which is asserted to be a numpy array.
    """
    annotator_state = AnnotatorState()
    # Make sure that the segmentation generator is initialized for this image / slice.
    _handle_amg_state(annotator_state, i, pbar_init, pbar_update)

    segmentation = annotator_state.amg.generate(**kwargs)
    assert isinstance(segmentation, np.ndarray)
    return segmentation
1734
1735
class AutoSegmentWidget(_WidgetBase):
    """Widget for running the automatic instance segmentation and for setting its parameters.

    The segmentation result is written to the 'auto_segmentation' layer of the viewer.
    """

    def __init__(self, viewer, with_decoder, volumetric, parent=None):
        """Store the configuration and create the UI elements.

        Args:
            viewer: The napari viewer.
            with_decoder: Whether to use the decoder-based settings (center and boundary distance
                thresholds) instead of the AMG settings (IOU, stability score and NMS thresholds).
            volumetric: Whether the data is volumetric. If True, the switch for applying the segmentation
                to the volume and the settings 'gap_closing' and 'min_extent' are added.
            parent: The optional parent, which is passed on to the base class.
        """
        super().__init__(parent)

        self._viewer = viewer
        self.with_decoder = with_decoder
        self.volumetric = volumetric
        self._create_widget()

    def _create_widget(self):
        """Add the (optional) volume switch, the settings and the run button to the layout."""
        # Add the switch for segmenting the slice vs. the volume if we have a volume.
        if self.volumetric:
            self.layout().addWidget(self._create_volumetric_switch())

        # Add the nested settings widget.
        self.settings = self._create_settings()
        self.layout().addWidget(self.settings)

        # Add the run button.
        self.run_button = QtWidgets.QPushButton("Automatic Segmentation")
        self.run_button.clicked.connect(self.__call__)
        self.run_button.setToolTip(get_tooltip("autosegment", "run_button"))
        self.layout().addWidget(self.run_button)

    def _reset_segmentation_mode(self, with_decoder):
        """Rebuild the widget if the segmentation mode has changed.

        Args:
            with_decoder: The new value for `self.with_decoder`.
        """
        # If we already have the same segmentation mode we don't need to do anything.
        if with_decoder == self.with_decoder:
            return

        # Otherwise we change the value of with_decoder.
        self.with_decoder = with_decoder

        # Then we clear the whole widget.
        layout = self.layout()
        while layout.count():
            child = layout.takeAt(0)
            if child.widget():
                child.widget().deleteLater()

        # And then we reset it.
        self._create_widget()

    def _create_volumetric_switch(self):
        """Create the switch for applying the segmentation to the whole volume (off by default).

        Returns:
            The return value of `_add_boolean_param`, which is added to the layout as a widget.
        """
        self.apply_to_volume = False
        return self._add_boolean_param(
            "apply_to_volume", self.apply_to_volume, title="Apply to Volume",
            tooltip=get_tooltip("autosegment", "apply_to_volume")
        )

    def _add_common_settings(self, settings):
        """Add the settings that are shared by both segmentation modes.

        Args:
            settings: The widget to whose layout the UI elements are added.
        """
        # Create the UI element for min object size.
        self.min_object_size = 100
        self.min_object_size_param, layout = self._add_int_param(
            "min_object_size", self.min_object_size, min_val=0, max_val=int(1e4),
            tooltip=get_tooltip("autosegment", "min_object_size")
        )
        settings.layout().addLayout(layout)

        # Add extra settings for volumetric segmentation: gap_closing and min_extent.
        if self.volumetric:
            self.gap_closing = 2
            self.gap_closing_param, layout = self._add_int_param(
                "gap_closing", self.gap_closing, min_val=0, max_val=10,
                tooltip=get_tooltip("autosegment", "gap_closing")
            )
            settings.layout().addLayout(layout)

            self.min_extent = 2
            self.min_extent_param, layout = self._add_int_param(
                "min_extent", self.min_extent, min_val=0, max_val=10,
                tooltip=get_tooltip("autosegment", "min_extent")
            )
            settings.layout().addLayout(layout)

    def _ais_settings(self):
        """Create the settings for the decoder-based segmentation.

        Returns:
            The widget with the UI elements for the distance thresholds and the common settings.
        """
        settings = QtWidgets.QWidget()
        settings.setLayout(QtWidgets.QVBoxLayout())

        # Create the UI element for center_distance_threshold.
        self.center_distance_thresh = 0.5
        self.center_distance_thresh_param, layout = self._add_float_param(
            "center_distance_thresh", self.center_distance_thresh,
            tooltip=get_tooltip("autosegment", "center_distance_thresh")
        )
        settings.layout().addLayout(layout)

        # Create the UI element for boundary_distance_threshold.
        self.boundary_distance_thresh = 0.5
        self.boundary_distance_thresh_param, layout = self._add_float_param(
            "boundary_distance_thresh", self.boundary_distance_thresh,
            tooltip=get_tooltip("autosegment", "boundary_distance_thresh")
        )
        settings.layout().addLayout(layout)

        # Add min_object_size.
        self._add_common_settings(settings)

        return settings

    def _amg_settings(self):
        """Create the settings for the segmentation with the automatic mask generator (AMG).

        Returns:
            The widget with the UI elements for the AMG thresholds and the common settings.
        """
        settings = QtWidgets.QWidget()
        settings.setLayout(QtWidgets.QVBoxLayout())

        # Create the UI element for pred_iou_thresh.
        self.pred_iou_thresh = 0.88
        self.pred_iou_thresh_param, layout = self._add_float_param(
            "pred_iou_thresh", self.pred_iou_thresh,
            tooltip=get_tooltip("autosegment", "pred_iou_thresh")
        )
        settings.layout().addLayout(layout)

        # Create the UI element for stability score thresh.
        self.stability_score_thresh = 0.95
        self.stability_score_thresh_param, layout = self._add_float_param(
            "stability_score_thresh", self.stability_score_thresh,
            tooltip=get_tooltip("autosegment", "stability_score_thresh")
        )
        settings.layout().addLayout(layout)

        # Create the UI element for box nms thresh.
        self.box_nms_thresh = 0.7
        self.box_nms_thresh_param, layout = self._add_float_param(
            "box_nms_thresh", self.box_nms_thresh,
            tooltip=get_tooltip("autosegment", "box_nms_thresh")
        )
        settings.layout().addLayout(layout)

        # Add min_object_size.
        self._add_common_settings(settings)

        return settings

    def _create_settings(self):
        """Create the collapsible settings widget for the current segmentation mode.

        Returns:
            The collapsible widget that contains the settings.
        """
        setting_values = self._ais_settings() if self.with_decoder else self._amg_settings()
        settings = _make_collapsible(setting_values, title="Automatic Segmentation Settings")
        return settings

    def _empty_segmentation_warning(self):
        """Show an error message with hints for the case that the segmentation result is empty.

        Returns:
            The return value of `_generate_message`.
        """
        # The sentences are separated by a space so that they don't run together in the message.
        msg = "The automatic segmentation result does not contain any objects."
        msg += " Setting a smaller value for 'min_object_size' may help."
        if not self.with_decoder:
            msg += " Setting smaller values for 'pred_iou_thresh' and 'stability_score_thresh' may also help."
        val_results = {"message_type": "error", "message": msg}
        return _generate_message(val_results["message_type"], val_results["message"])

    def _run_segmentation_2d(self, kwargs, i=None):
        """Run the automatic segmentation for a 2d image or for a single slice of a volume.

        Args:
            kwargs: The keyword arguments for the segmentation, see `__call__`.
            i: The index of the slice to segment. If None, the whole 'auto_segmentation' layer
                is replaced by the result, otherwise only the slice `i` of the layer is updated.
        """
        pbar, pbar_signals = _create_pbar_for_threadworker()

        # @thread_worker
        def seg_impl():
            def pbar_init(total, description):
                pbar_signals.pbar_total.emit(total)
                pbar_signals.pbar_description.emit(description)

            seg = _instance_segmentation_impl(
                self.min_object_size, i=i, pbar_init=pbar_init,
                pbar_update=lambda update: pbar_signals.pbar_update.emit(update),
                **kwargs
            )
            pbar_signals.pbar_stop.emit()
            return seg

        def update_segmentation(seg):
            is_empty = seg.max() == 0
            if is_empty:
                self._empty_segmentation_warning()

            if i is None:
                self._viewer.layers["auto_segmentation"].data = seg
            else:
                self._viewer.layers["auto_segmentation"].data[i] = seg
            self._viewer.layers["auto_segmentation"].refresh()

        # Validate all layers. Note that the return value of the validation is not checked here.
        _validate_layers(self._viewer, automatic_segmentation=True)

        # The thread workers are disabled for now, so the implementation is run synchronously.
        seg = seg_impl()
        update_segmentation(seg)
        # worker = seg_impl()
        # worker.returned.connect(update_segmentation)
        # worker.start()
        # return worker

    def _allow_segment_3d(self):
        """Check whether the segmentation of the full volume is allowed.

        We refuse to run 3D segmentation with the AMG unless we have a GPU or all embeddings
        are precomputed. Otherwise this would take too long.

        Returns:
            True if the volume can be segmented, False otherwise.
        """
        if self.with_decoder:
            return True
        state = AnnotatorState()
        predictor = state.predictor
        if str(predictor.device) == "cpu" or str(predictor.device) == "mps":
            n_slices = self._viewer.layers["auto_segmentation"].data.shape[0]
            # NOTE(review): presumably 'amg_state' holds one entry per slice plus an entry for the
            # cache location, so that more entries than slices means everything is precomputed - confirm.
            embeddings_are_precomputed = (state.amg_state is not None) and (len(state.amg_state) > n_slices)
            if not embeddings_are_precomputed:
                return False
        return True

    def _run_segmentation_3d(self, kwargs):
        """Segment all slices of the volume and merge the results into a 3d segmentation.

        Args:
            kwargs: The keyword arguments for the segmentation, see `__call__`.

        Returns:
            The return value of `_generate_message` if the volumetric segmentation is not allowed,
            otherwise None.
        """
        allow_segment_3d = self._allow_segment_3d()
        if not allow_segment_3d:
            val_results = {
                "message_type": "error",
                "message": "Volumetric segmentation with AMG is only supported if you have a GPU."
            }
            return _generate_message(val_results["message_type"], val_results["message"])

        pbar, pbar_signals = _create_pbar_for_threadworker()

        # @thread_worker
        def seg_impl():
            segmentation = np.zeros_like(self._viewer.layers["auto_segmentation"].data)
            offset = 0

            def pbar_init(total, description):
                pbar_signals.pbar_total.emit(total)
                pbar_signals.pbar_description.emit(description)

            pbar_init(segmentation.shape[0], "Segment volume")

            # Further optimization: parallelize if state is precomputed for all slices
            for i in range(segmentation.shape[0]):
                seg = _instance_segmentation_impl(self.min_object_size, i=i, **kwargs)
                seg_max = seg.max()
                if seg_max != 0:
                    # Offset the ids so that they are unique across the slices.
                    seg[seg != 0] += offset
                    offset = seg_max + offset
                    segmentation[i] = seg
                # The progress is updated for every slice, including slices without any objects.
                pbar_signals.pbar_update.emit(1)

            pbar_signals.pbar_reset.emit()
            # The progress callback ignores the update value and advances the progress bar by one.
            segmentation = merge_instance_segmentation_3d(
                segmentation, beta=0.5, gap_closing=self.gap_closing, min_z_extent=self.min_extent,
                verbose=True, pbar_init=pbar_init, pbar_update=lambda update: pbar_signals.pbar_update.emit(1),
            )
            pbar_signals.pbar_stop.emit()
            return segmentation

        def update_segmentation(segmentation):
            is_empty = segmentation.max() == 0
            if is_empty:
                self._empty_segmentation_warning()
            self._viewer.layers["auto_segmentation"].data = segmentation
            self._viewer.layers["auto_segmentation"].refresh()

        # The thread workers are disabled for now, so the implementation is run synchronously.
        seg = seg_impl()
        update_segmentation(seg)
        # worker = seg_impl()
        # worker.returned.connect(update_segmentation)
        # worker.start()
        # return worker

    def __call__(self):
        """Run the automatic segmentation with the current settings.

        Nothing is run if `_validate_embeddings` returns a truthy value.
        Afterwards the 'auto_segmentation' layer is selected.

        Returns:
            The return value of `_run_segmentation_3d` or `_run_segmentation_2d`.
        """
        if _validate_embeddings(self._viewer):
            return None

        # Collect the parameters for the current segmentation mode.
        if self.with_decoder:
            kwargs = {
                "center_distance_threshold": self.center_distance_thresh,
                "boundary_distance_threshold": self.boundary_distance_thresh,
                "min_size": self.min_object_size,
            }
        else:
            kwargs = {
                "pred_iou_thresh": self.pred_iou_thresh,
                "stability_score_thresh": self.stability_score_thresh,
                "box_nms_thresh": self.box_nms_thresh,
            }
        if self.volumetric and self.apply_to_volume:
            worker = self._run_segmentation_3d(kwargs)
        elif self.volumetric and not self.apply_to_volume:
            # Only segment the slice at the current position along the first axis of the viewer.
            i = int(self._viewer.dims.point[0])
            worker = self._run_segmentation_2d(kwargs, i=i)
        else:
            worker = self._run_segmentation_2d(kwargs)
        _select_layer(self._viewer, "auto_segmentation")
        return worker
2013
2014
class AutoTrackWidget(AutoSegmentWidget):
    """Widget for automatic tracking, which reuses the settings of the automatic segmentation widget.

    The switch 'apply_to_volume' of the parent class is used to choose between segmenting
    the current frame and tracking the whole timeseries.
    """

    def _create_tracking_switch(self):
        """Create the switch for tracking the whole timeseries (off by default).

        Returns:
            The return value of `_add_boolean_param`, which is added to the layout as a widget.
        """
        self.apply_to_volume = False
        return self._add_boolean_param(
            "apply_to_volume", self.apply_to_volume, title="Track Timeseries",
            tooltip=get_tooltip("autotrack", "run_tracking")
        )

    def _create_widget(self):
        """Add the tracking switch, the settings and the run button to the layout."""
        # Add the switch for segmenting the slice vs. tracking the timeseries.
        self.layout().addWidget(self._create_tracking_switch())

        # Add the nested settings widget.
        self.settings = self._create_settings()
        self.layout().addWidget(self.settings)

        # Add the run button.
        self.run_button = QtWidgets.QPushButton("Automatic Tracking")
        self.run_button.clicked.connect(self.__call__)
        self.run_button.setToolTip(get_tooltip("autotrack", "run_button"))
        self.layout().addWidget(self.run_button)

    def _run_segmentation_3d(self, kwargs):
        """Segment all frames of the timeseries and track the segmented objects across the frames.

        The tracking result is written to the 'auto_segmentation' layer
        and the lineages are stored in `state.lineage`.

        Args:
            kwargs: The keyword arguments for the segmentation of the individual frames.

        Returns:
            The return value of `_generate_message` if tracking is not allowed or if results from
            interactive tracking have already been committed, otherwise None.
        """
        allow_segment_3d = self._allow_segment_3d()
        if not allow_segment_3d:
            return _generate_message("error", "Tracking with AMG is only supported if you have a GPU.")

        state = AnnotatorState()
        if len(state.committed_lineages) > 0:
            return _generate_message(
                "error",
                "Automatic tracking can only be called if you haven't committed results from interactive tracking yet."
            )
        pbar, pbar_signals = _create_pbar_for_threadworker()

        # @thread_worker
        def seg_impl():
            image_name = state.get_image_name(self._viewer)
            timeseries = self._viewer.layers[image_name].data
            segmentation = np.zeros_like(self._viewer.layers["auto_segmentation"].data)
            offset = 0

            def pbar_init(total, description):
                pbar_signals.pbar_total.emit(total)
                pbar_signals.pbar_description.emit(description)

            pbar_init(segmentation.shape[0], "Run tracking")

            # Further optimization: parallelize if state is precomputed for all slices
            for i in range(segmentation.shape[0]):
                seg = _instance_segmentation_impl(self.min_object_size, i=i, **kwargs)
                seg_max = seg.max()
                if seg_max != 0:
                    # Offset the ids so that they are unique across the frames.
                    seg[seg != 0] += offset
                    offset = seg_max + offset
                    segmentation[i] = seg
                # The progress is updated for every frame, including frames without any objects.
                pbar_signals.pbar_update.emit(1)

            pbar_signals.pbar_reset.emit()
            # The progress callback ignores the update value and advances the progress bar by one.
            segmentation, lineages = track_across_frames(
                timeseries, segmentation,
                verbose=True, pbar_init=pbar_init,
                pbar_update=lambda update: pbar_signals.pbar_update.emit(1),
            )
            pbar_signals.pbar_stop.emit()
            return (segmentation, lineages)

        def update_segmentation(result):
            segmentation, lineages = result
            is_empty = segmentation.max() == 0
            if is_empty:
                self._empty_segmentation_warning()

            state = AnnotatorState()
            state.lineage = lineages

            self._viewer.layers["auto_segmentation"].data = segmentation
            self._viewer.layers["auto_segmentation"].refresh()

        # The thread workers are disabled for now, so the implementation is run synchronously.
        result = seg_impl()
        update_segmentation(result)
        # worker = seg_impl()
        # worker.returned.connect(update_segmentation)
        # worker.start()
        # return worker
class PBarSignals(PyQt6.QtCore.QObject):
class PBarSignals(QObject):
    """Collection of Qt signals used to drive a progress bar.

    The long-running computations in this module emit these signals instead of
    touching a progress bar directly.
    """
    # Emitted with the total number of steps when a progress bar is initialized.
    pbar_total = Signal(int)
    # Emitted with the number of steps by which the progress bar should advance.
    pbar_update = Signal(int)
    # Emitted with the description text for the progress bar.
    pbar_description = Signal(str)
    # Emitted without payload once a computation has finished.
    pbar_stop = Signal()
    # Emitted without payload between two consecutive phases of a computation.
    pbar_reset = Signal()

QObject(parent: Optional[QObject] = None)

def pbar_total(unknown):

pyqtSignal(*types, name: str = ..., revision: int = ..., arguments: Sequence = ...) -> PYQT_SIGNAL

types is normally a sequence of individual types. Each type is either a type object or a string that is the name of a C++ type. Alternatively each type could itself be a sequence of types each describing a different overloaded signal. name is the optional C++ name of the signal. If it is not specified then the name of the class attribute that is bound to the signal is used. revision is the optional revision of the signal that is exported to QML. If it is not specified then 0 is used. arguments is the optional sequence of the names of the signal's arguments.

def pbar_update(unknown):

pyqtSignal(*types, name: str = ..., revision: int = ..., arguments: Sequence = ...) -> PYQT_SIGNAL

types is normally a sequence of individual types. Each type is either a type object or a string that is the name of a C++ type. Alternatively each type could itself be a sequence of types each describing a different overloaded signal. name is the optional C++ name of the signal. If it is not specified then the name of the class attribute that is bound to the signal is used. revision is the optional revision of the signal that is exported to QML. If it is not specified then 0 is used. arguments is the optional sequence of the names of the signal's arguments.

def pbar_description(unknown):

pyqtSignal(*types, name: str = ..., revision: int = ..., arguments: Sequence = ...) -> PYQT_SIGNAL

types is normally a sequence of individual types. Each type is either a type object or a string that is the name of a C++ type. Alternatively each type could itself be a sequence of types each describing a different overloaded signal. name is the optional C++ name of the signal. If it is not specified then the name of the class attribute that is bound to the signal is used. revision is the optional revision of the signal that is exported to QML. If it is not specified then 0 is used. arguments is the optional sequence of the names of the signal's arguments.

def pbar_stop(unknown):

pyqtSignal(*types, name: str = ..., revision: int = ..., arguments: Sequence = ...) -> PYQT_SIGNAL

types is normally a sequence of individual types. Each type is either a type object or a string that is the name of a C++ type. Alternatively each type could itself be a sequence of types each describing a different overloaded signal. name is the optional C++ name of the signal. If it is not specified then the name of the class attribute that is bound to the signal is used. revision is the optional revision of the signal that is exported to QML. If it is not specified then 0 is used. arguments is the optional sequence of the names of the signal's arguments.

def pbar_reset(unknown):

pyqtSignal(*types, name: str = ..., revision: int = ..., arguments: Sequence = ...) -> PYQT_SIGNAL

types is normally a sequence of individual types. Each type is either a type object or a string that is the name of a C++ type. Alternatively each type could itself be a sequence of types each describing a different overloaded signal. name is the optional C++ name of the signal. If it is not specified then the name of the class attribute that is bound to the signal is used. revision is the optional revision of the signal that is exported to QML. If it is not specified then 0 is used. arguments is the optional sequence of the names of the signal's arguments.

class InfoDialog(PyQt6.QtWidgets.QDialog):
class InfoDialog(QtWidgets.QDialog):
    """Dialog that shows a message together with an "OK" and a "Cancel" button.

    Clicking "OK" accepts the dialog, clicking "Cancel" rejects it.
    """

    def __init__(self, title, message):
        """Build the dialog.

        Args:
            title: The window title of the dialog.
            message: The text displayed above the buttons.
        """
        super().__init__()
        self.setWindowTitle(title)

        # Both buttons report back through `button_clicked`, which decides on accept / reject.
        ok_button = QtWidgets.QPushButton("OK")
        ok_button.clicked.connect(lambda: self.button_clicked(ok_button))

        abort_button = QtWidgets.QPushButton("Cancel")
        abort_button.clicked.connect(lambda: self.button_clicked(abort_button))

        # The buttons are placed side-by-side in their own row.
        button_row = QtWidgets.QHBoxLayout()
        button_row.addWidget(ok_button)
        button_row.addWidget(abort_button)

        # The message goes on top, the button row below it.
        main_layout = QtWidgets.QVBoxLayout()
        main_layout.addWidget(QtWidgets.QLabel(message))
        main_layout.addLayout(button_row)
        self.setLayout(main_layout)

    def button_clicked(self, button):
        """Close the dialog depending on which button was clicked.

        Args:
            button: The clicked button. The dialog is accepted if its text is "OK"
                and rejected for any other button.
        """
        confirmed = button.text() == "OK"
        if not confirmed:
            self.reject()
            return
        self.accept()

QDialog(parent: Optional[QWidget] = None, flags: Qt.WindowType = Qt.WindowFlags())

InfoDialog(title, message)
373    def __init__(self, title, message):
374        super().__init__()
375        self.setWindowTitle(title)
376
377        layout = QtWidgets.QVBoxLayout()
378        layout.addWidget(QtWidgets.QLabel(message))
379
380        # Add buttons
381        button_box = QtWidgets.QHBoxLayout()  # Use QHBoxLayout for buttons side-by-side
382        accept_button = QtWidgets.QPushButton("OK")
383        accept_button.clicked.connect(lambda: self.button_clicked(accept_button))  # Connect to clicked signal
384        button_box.addWidget(accept_button)
385
386        cancel_button = QtWidgets.QPushButton("Cancel")
387        cancel_button.clicked.connect(lambda: self.button_clicked(cancel_button))  # Connect to clicked signal
388        button_box.addWidget(cancel_button)
389
390        layout.addLayout(button_box)
391        self.setLayout(layout)
def button_clicked(self, button):
393    def button_clicked(self, button):
394        if button.text() == "OK":
395            self.accept()  # Accept the dialog
396        else:
397            self.reject()  # Reject the dialog (Cancel)
clear = MagicFactory(function=<function clear>, call_button='Clear Annotations [Shift + C]')

Widget for clearing the current annotations.

Arguments:
  • viewer: The napari viewer.
clear_volume = MagicFactory(function=<function clear_volume>, call_button='Clear Annotations [Shift + C]')

Widget for clearing the current annotations in 3D.

Arguments:
  • viewer: The napari viewer.
  • all_slices: Choose whether to clear the annotations for all or only the current slice.
clear_track = MagicFactory(function=<function clear_track>, call_button='Clear Annotations [Shift + C]')

Widget for clearing all tracking annotations and state.

Arguments:
  • viewer: The napari viewer.
  • all_frames: Choose whether to clear the annotations for all or only the current frame.
commit = MagicFactory(function=<function commit>, call_button='Commit [C]', param_options={'layer': {'choices': ['current_object', 'auto_segmentation'], 'tooltip': "The layer to commit. Either 'current_object' to commit results from prompt-based segmentation or 'auto_segmentation' to commit results from automatic segmentation."}, 'preserve_mode': {'choices': ['objects', 'pixels', 'none'], 'tooltip': "The mode for preserving already committed objects. Either 'objects' to preserve on a per-object level, 'pixels' to preserve on a per-pixel level, or 'none' to not preserve."}, 'commit_path': {'mode': 'd', 'tooltip': 'The path to a zarr file for saving committed objects, prompts and other segmentation settings.'}})

Widget for committing the segmented objects from automatic or interactive segmentation.

Arguments:
  • viewer: The napari viewer.
  • layer: Select the layer to commit. Can be either 'current_object' to commit interactive segmentation results, or 'auto_segmentation' to commit automatic segmentation results.
  • preserve_mode: The mode for preserving already committed objects, in order to prevent overwriting them by a new commit. Supports the modes 'objects', which preserves on the object level and is the default, 'pixels', which preserves on the pixel level, or 'none', which does not preserve committed objects.
  • preservation_threshold: The overlap threshold for preserving objects. This is only used if preserve_mode is set to 'objects'.
  • commit_path: Select a file path where the committed results and prompts will be saved. This feature is still experimental.
commit_track = MagicFactory(function=<function commit_track>, call_button='Commit [C]', param_options={'layer': {'choices': ['current_object', 'auto_segmentation']}, 'preserve_mode': {'choices': ['objects', 'pixels', 'none']}, 'commit_path': {'mode': 'd'}})

Widget for committing the objects from interactive tracking.

Arguments:
  • viewer: The napari viewer.
  • layer: Select the layer to commit. Can be either 'current_object' to commit interactive segmentation results, or 'auto_segmentation' to commit automatic segmentation results.
  • preserve_mode: The mode for preserving already committed objects, in order to prevent overwriting them by a new commit. Supports the modes 'objects', which preserves on the object level and is the default, 'pixels', which preserves on the pixel level, or 'none', which does not preserve committed objects.
  • preservation_threshold: The overlap threshold for preserving objects. This is only used if preserve_mode is set to 'objects'.
  • commit_path: Select a file path where the committed results and prompts will be saved. This feature is still experimental.
def create_prompt_menu(points_layer, labels, menu_name='prompt', label_name='label'):
def create_prompt_menu(points_layer, labels, menu_name="prompt", label_name="label"):
    """Create the menu for toggling point prompt labels.

    The menu and the points layer are kept in sync in both directions: a change of the
    layer's current properties updates the menu, and a change of the menu value updates
    the layer's current properties.

    Args:
        points_layer: The points layer whose `current_properties` hold the prompt label.
        labels: The choices offered by the menu.
        menu_name: The label displayed for the menu.
        label_name: The key under which the label is stored in `current_properties`.

    Returns:
        The container widget holding the menu.
    """
    dropdown = ComboBox(label=menu_name, choices=labels, tooltip=get_tooltip("prompt_menu", "labels"))
    container = Container(widgets=[dropdown])

    def sync_menu_from_layer(event):
        # Only write to the menu if the value differs, so that no redundant change is triggered.
        layer_label = str(points_layer.current_properties[label_name][0])
        if dropdown.value != layer_label:
            dropdown.value = layer_label

    def sync_layer_from_menu(new_label):
        # The properties are re-assigned as a whole after updating the label entry.
        properties = points_layer.current_properties
        properties[label_name] = np.array([new_label])
        points_layer.current_properties = properties
        points_layer.refresh_colors()

    points_layer.events.current_properties.connect(sync_menu_from_layer)
    dropdown.changed.connect(sync_layer_from_menu)

    return container

Create the menu for toggling point prompt labels.

settings_widget = MagicFactory(function=<function settings_widget>, call_button='Update settings', param_options={'cache_directory': {'mode': 'd'}})

Widget to update global micro_sam settings.

Arguments:
  • cache_directory: Select the path for the micro_sam cache directory. $HOME/.cache/micro_sam.
segment = MagicFactory(function=<function segment>, call_button='Segment Object [S]')

Segment object(s) for the current prompts.

Arguments:
  • viewer: The napari viewer.
  • batched: Choose if you want to segment multiple objects with point prompts.
segment_slice = MagicFactory(function=<function segment_slice>, call_button='Segment Slice [S]')

Segment object for the current prompts.

Arguments:
  • viewer: The napari viewer.
segment_frame = MagicFactory(function=<function segment_frame>, call_button='Segment Frame [S]')

Segment object for the current prompts.

Arguments:
  • viewer: The napari viewer.
class EmbeddingWidget(_WidgetBase):
class EmbeddingWidget(_WidgetBase):
    """Widget for computing the image embeddings of a selected napari image layer.

    The widget consists of an image selection, a model section, a collapsible settings
    section (device, save path, custom weights, tile shape, halo and automatic
    segmentation mode) and a button that triggers the embedding computation.
    """

    def __init__(self, parent=None):
        """Assemble the widget sections and connect the run button.

        Args:
            parent: The optional parent widget, forwarded to the base class.
        """
        super().__init__(parent=parent)

        # Create a nested layout for the sections.
        # Section 1: Image and Model.
        section1_layout = QtWidgets.QHBoxLayout()
        section1_layout.addLayout(self._create_image_section())
        section1_layout.addLayout(self._create_model_section())  # Creates the model family widget section.
        self.layout().addLayout(section1_layout)

        # Section 2: Settings (collapsible).
        self.layout().addWidget(self._create_settings_widget())

        # Section 3: The button to trigger the embedding computation.
        # The image-level state is set first, afterwards the embeddings are computed.
        self.run_button = QtWidgets.QPushButton("Compute Embeddings")
        self.run_button.clicked.connect(self._initialize_image)
        self.run_button.clicked.connect(self.__call__)
        self.run_button.setToolTip(get_tooltip("embedding", "run_button"))
        self.layout().addWidget(self.run_button)

    def _initialize_image(self):
        """Store shape, scale and name of the selected image layer in the annotator state."""
        state = AnnotatorState()
        layer = self.image_selection.get_value()

        # This is encountered when there is no image layer available / selected.
        # In this case, we need not specify other image-level parameters to the state. Hence, we skip them.
        # NOTE: On code-level, this happens as the first step when "Compute Embedding" click is triggered.
        if layer is None:
            return

        state.image_shape = layer.data.shape
        state.image_scale = tuple(layer.scale)
        state.image_name = layer.name

    def _create_image_section(self):
        """Create the layout with the label and the selection widget for the image layer.

        Returns:
            The layout of the image section.
        """
        image_section = QtWidgets.QVBoxLayout()
        image_layer_widget = QtWidgets.QLabel("Image Layer:")
        # image_layer_widget.setToolTip(get_tooltip("embedding", "image")) #  this adds tooltip to label
        image_section.addWidget(image_layer_widget)

        # Setting a napari layer in QT, see:
        # https://github.com/pyapp-kit/magicgui/blob/main/docs/examples/napari/napari_combine_qt.py
        self.image_selection = create_widget(annotation=napari.layers.Image)
        self.image_selection.native.setToolTip(get_tooltip("embedding", "image"))
        image_section.addWidget(self.image_selection.native)

        return image_section

    def _update_model(self, state):
        """Report the computed embeddings and synchronize the widgets with the model settings.

        Args:
            state: The annotator state. Its predictor provides the model type
                if custom weights are used.
        """
        _model_type = state.predictor.model_type if self.custom_weights else self.model_type

        # Provide a detailed message for the model family and model size per chosen combination.
        msg = "Computed embeddings for "
        if self.custom_weights:  # Whether the user provided a filepath to custom finetuned model weights.
            msg += f"the model located at '{os.path.abspath(self.custom_weights)}' "
            msg += f"of size '{self._model_size_map[_model_type[4]]}'."
        else:
            msg += f"the '{self.model_family}' model of size '{self.model_size}'."

        show_info(msg)

        # NOTE(review): this rebinds the `state` argument; presumably AnnotatorState is a
        # singleton, so both refer to the same object - confirm before removing.
        state = AnnotatorState()
        # Update the widget itself. This is necessary because we may have loaded
        # some settings from the embedding file and have to reflect them in the widget.
        vutil._sync_embedding_widget(
            self,
            model_type=_model_type,
            save_path=self.embeddings_save_path,
            checkpoint_path=self.custom_weights,
            device=self.device,
            tile_shape=[self.tile_x, self.tile_y],
            halo=[self.halo_x, self.halo_y]
        )

        # Set the default settings for this model in the autosegment widget if it is part of
        # the currently used plugin.
        if "autosegment" in state.widgets:
            with_decoder = state.decoder is not None
            vutil._sync_autosegment_widget(
                state.widgets["autosegment"], _model_type, self.custom_weights, update_decoder=with_decoder
            )
            # Load the AMG/AIS state if we have a 3d segmentation plugin.
            if state.widgets["autosegment"].volumetric and with_decoder:
                state.amg_state = vutil._load_is_state(state.embedding_path)
            elif state.widgets["autosegment"].volumetric and not with_decoder:
                state.amg_state = vutil._load_amg_state(state.embedding_path)

        # Set the default settings for this model in the nd-segmentation widget if it is part of
        # the currently used plugin.
        if "segment_nd" in state.widgets:
            vutil._sync_ndsegment_widget(state.widgets["segment_nd"], _model_type, self.custom_weights)

    def _create_settings_widget(self):
        """Create the collapsible widget with the embedding settings.

        This also initializes the attributes that hold the setting values
        (`device`, `embeddings_save_path`, `custom_weights`, `tile_x`, `tile_y`,
        `halo_x`, `halo_y` and `automatic_segmentation_mode`) with their defaults.

        Returns:
            The collapsible settings widget.
        """
        setting_values = QtWidgets.QWidget()
        setting_values.setToolTip(get_tooltip("embedding", "settings"))
        setting_values.setLayout(QtWidgets.QVBoxLayout())

        # Add the model size widget section.
        layout = self._create_model_size_section()
        setting_values.layout().addLayout(layout)

        # Create UI for the device.
        self.device = "auto"
        device_options = ["auto"] + util._available_devices()

        self.device_dropdown, layout = self._add_choice_param(
            "device", self.device, device_options, tooltip=get_tooltip("embedding", "device")
        )
        setting_values.layout().addLayout(layout)

        # Create UI for the save path.
        self.embeddings_save_path = None
        self.embeddings_save_path_param, layout = self._add_path_param(
            "embeddings_save_path", self.embeddings_save_path, "directory", title="embeddings save path:",
            tooltip=get_tooltip("embedding", "embeddings_save_path")
        )
        setting_values.layout().addLayout(layout)

        # Create UI for the custom weights.
        self.custom_weights = None
        self.custom_weights_param, layout = self._add_path_param(
            "custom_weights", self.custom_weights, "file", title="custom weights path:",
            tooltip=get_tooltip("embedding", "custom_weights")
        )
        setting_values.layout().addLayout(layout)

        # Create UI for the tile shape.
        self.tile_x, self.tile_y = 0, 0
        self.tile_x_param, self.tile_y_param, layout = self._add_shape_param(
            ("tile_x", "tile_y"), (self.tile_x, self.tile_y), min_val=0, max_val=2048, step=16,
            tooltip=get_tooltip("embedding", "tiling")
        )
        setting_values.layout().addLayout(layout)

        # Create UI for the halo.
        self.halo_x, self.halo_y = 0, 0
        self.halo_x_param, self.halo_y_param, layout = self._add_shape_param(
            ("halo_x", "halo_y"), (self.halo_x, self.halo_y), min_val=0, max_val=512,
            tooltip=get_tooltip("embedding", "halo")
        )
        setting_values.layout().addLayout(layout)

        # Create UI for the choice of automatic segmentation mode.
        self.automatic_segmentation_mode = "auto"
        auto_seg_options = ["auto", "amg", "ais"]
        self.automatic_segmentation_mode_dropdown, layout = self._add_choice_param(
            "automatic_segmentation_mode", self.automatic_segmentation_mode, auto_seg_options,
            title="automatic segmentation mode", tooltip=get_tooltip("embedding", "automatic_segmentation_mode")
        )
        setting_values.layout().addLayout(layout)

        settings = _make_collapsible(setting_values, title="Embedding Settings")
        return settings

    def _validate_inputs(self):
        """Validate the inputs for the embedding computation.

        This function performs the following checks:

        - If no image is selected, an error message is generated.
        - If `embeddings_save_path` points to an existing, non-empty directory:
            - The embeddings are checked for completeness via the 'input_size' attribute.
              An error message is generated if they are incomplete.
            - If a data signature is stored, it is compared with the signature of the
              selected image. An error message is generated if they do not match.
            - The model type, tile shape and halo are loaded from the zarr attributes and
              an informational message about the loaded embeddings is generated.
            - If opening the embeddings raises a RuntimeError, an error message is generated.
        - If `embeddings_save_path` is not set, does not exist yet or is empty,
          no message is generated.

        Returns:
            The return value of `_generate_message` if a message was generated, which is
            truthy if the computation should be aborted. False if no message was needed.
        """
        # Check if we have an existing input image to compute the embeddings.
        image = self.image_selection.get_value()
        if image is None:
            return _generate_message("error", "No image has been selected.")

        # Without a save path, or with a save path that does not exist yet or is empty, there
        # are no embeddings to validate and we can proceed. Checking for the directory first
        # avoids an exception from listing a path that will only be created later.
        save_path = self.embeddings_save_path
        if not (save_path and os.path.isdir(save_path) and os.listdir(save_path)):
            return False

        # Otherwise we check the data signature of the existing embeddings against the
        # selected image and we ask the user if they want to load these embeddings.
        try:
            f = zarr.open(save_path, mode="a")
            attrs = f.attrs

            # Validate that the embeddings are complete.
            # Note: 'input_size' is the last value set in the attrs of f,
            # so we can use it as a proxy to check if the embeddings are fully computed
            if "input_size" not in attrs:
                msg = (f"The embeddings at {save_path} are incomplete. "
                       "Specify a different path or remove them.")
                return _generate_message("error", msg)

            # Validate image data signature.
            if "data_signature" in attrs:
                img_signature = util._compute_data_signature(image.data)
                if img_signature != attrs["data_signature"]:
                    msg = f"The embeddings don't match with the image: {img_signature} {attrs['data_signature']}"
                    return _generate_message("error", msg)

            # Load existing parameters. The fallback to 'model_type' is only looked up if
            # 'model_name' is missing, so that embeddings which only store 'model_name' work.
            self.model_type = attrs["model_name"] if "model_name" in attrs else attrs["model_type"]
            if "tile_shape" in attrs and attrs["tile_shape"] is not None:
                self.tile_x, self.tile_y = attrs["tile_shape"]
                self.halo_x, self.halo_y = attrs["halo"]
                msg = (f"Load embeddings for model: {self.model_type} with tile shape: "
                       f"{self.tile_x}, {self.tile_y} and halo: {self.halo_x}, {self.halo_y}.")
            else:
                self.tile_x, self.tile_y = 0, 0
                self.halo_x, self.halo_y = 0, 0
                msg = f"Load embeddings for model: {self.model_type}."

            return _generate_message("info", msg)

        except RuntimeError as e:
            return _generate_message("error", f"Failed to load image embeddings: {e}")

    def _validate_existing_embeddings(self, state):
        """Ask for confirmation before recomputing embeddings that are already in the state.

        Args:
            state: The annotator state that may hold precomputed image embeddings.

        Returns:
            False if the state holds no image embeddings. Otherwise the return value
            of `_generate_message` for the confirmation message.
        """
        if state.image_embeddings is None:
            return False
        return _generate_message(
            "info", "Embeddings have already been precomputed. Press OK to recompute the embeddings."
        )

    def __call__(self, skip_validate=False):
        """Compute the image embeddings for the selected image and update the widgets.

        Args:
            skip_validate: Whether to skip the validation of the user inputs.
        """
        self._validate_model_type_and_custom_weights()

        # Validate user inputs.
        if not skip_validate and self._validate_inputs():
            return

        # Get the image.
        image = self.image_selection.get_value()

        # Update the image embeddings:
        state = AnnotatorState()
        if self._validate_existing_embeddings(state):
            # Whether embeddings already exist to control existing objects in layers.
            state.skip_recomputing_embeddings = True
            return

        state.skip_recomputing_embeddings = False
        # Reset the state.
        state.reset_state()

        # Get image dimensions. For RGB images the trailing channel axis is not counted.
        if image.rgb:
            ndim = image.data.ndim - 1
            state.image_shape = image.data.shape[:-1]
        else:
            ndim = image.data.ndim
            state.image_shape = image.data.shape

        # Set layer scale
        state.image_scale = tuple(image.scale)

        # Process tile_shape and halo, set other data.
        tile_shape, halo = _process_tiling_inputs(self.tile_x, self.tile_y, self.halo_x, self.halo_y)
        save_path = None if self.embeddings_save_path == "" else self.embeddings_save_path
        image_data = image.data

        # Set up progress bar and signals for using it within a threadworker.
        pbar, pbar_signals = _create_pbar_for_threadworker()

        # @thread_worker()
        def compute_image_embedding():

            def pbar_init(total, description):
                pbar_signals.pbar_total.emit(total)
                pbar_signals.pbar_description.emit(description)

            # Whether to prefer decoder.
            # With 'amg', it is set to 'False', else it is 'True' for the default 'auto' and 'ais' mode.
            prefer_decoder = self.automatic_segmentation_mode != "amg"

            state.initialize_predictor(
                image_data, model_type=self.model_type, save_path=save_path, ndim=ndim,
                device=self.device, checkpoint_path=self.custom_weights, tile_shape=tile_shape, halo=halo,
                prefer_decoder=prefer_decoder, pbar_init=pbar_init,
                pbar_update=lambda update: pbar_signals.pbar_update.emit(update),
            )
            pbar_signals.pbar_stop.emit()

        # The computation runs synchronously; the thread worker variant is kept below for reference.
        compute_image_embedding()
        self._update_model(state)
        # worker = compute_image_embedding()
        # worker.returned.connect(self._update_model)
        # worker.start()
        # return worker

QWidget(parent: Optional[QWidget] = None, flags: Qt.WindowType = Qt.WindowFlags())

EmbeddingWidget(parent=None)
1166    def __init__(self, parent=None):
1167        super().__init__(parent=parent)
1168
1169        # Create a nested layout for the sections.
1170        # Section 1: Image and Model.
1171        section1_layout = QtWidgets.QHBoxLayout()
1172        section1_layout.addLayout(self._create_image_section())
1173        section1_layout.addLayout(self._create_model_section())  # Creates the model family widget section.
1174        self.layout().addLayout(section1_layout)
1175
1176        # Section 2: Settings (collapsible).
1177        self.layout().addWidget(self._create_settings_widget())
1178
1179        # Section 3: The button to trigger the embedding computation.
1180        self.run_button = QtWidgets.QPushButton("Compute Embeddings")
1181        self.run_button.clicked.connect(self._initialize_image)
1182        self.run_button.clicked.connect(self.__call__)
1183        self.run_button.setToolTip(get_tooltip("embedding", "run_button"))
1184        self.layout().addWidget(self.run_button)
run_button
class SegmentNDWidget(_WidgetBase):
class SegmentNDWidget(_WidgetBase):
    """Widget that extends the prompt-based segmentation to all slices or frames.

    Depending on ``tracking`` the widget either tracks the current object across
    frames (`_run_tracking`) or segments it throughout the volume
    (`_run_volumetric_segmentation`).
    """

    def __init__(self, viewer, tracking, parent=None):
        """Build the settings section and the run button.

        Args:
            viewer: The napari viewer whose layers are read and updated.
            tracking: Whether the widget runs tracking (True) or volumetric segmentation (False).
            parent: Optional parent widget, forwarded to the base class.
        """
        super().__init__(parent=parent)
        self._viewer = viewer
        self.tracking = tracking

        # The collapsible settings go on top.
        self.settings = self._create_settings()
        self.layout().addWidget(self.settings)

        # The run button goes below; only its title depends on the mode.
        if self.tracking:
            button_title = "Segment All Frames [Shift-S]"
        else:
            button_title = "Segment All Slices [Shift-S]"
        self.run_button = QtWidgets.QPushButton(button_title)
        self.run_button.clicked.connect(self.__call__)
        self.layout().addWidget(self.run_button)

    def _create_settings(self):
        """Create the collapsible settings section and set the default parameter values.

        Sets ``projection``, ``iou_threshold`` and ``box_extension`` (plus
        ``motion_smoothing`` for tracking) on the widget, together with the
        corresponding UI elements.

        Returns:
            The collapsible widget wrapping all settings.
        """
        container = QtWidgets.QWidget()
        container.setToolTip(get_tooltip("segmentnd", "settings"))
        container.setLayout(QtWidgets.QVBoxLayout())

        # The dropdown for the projection mode.
        self.projection = "single_point"
        self.projection_dropdown, row = self._add_choice_param(
            "projection", self.projection, PROJECTION_MODES,
            tooltip=get_tooltip("segmentnd", "projection_dropdown"),
        )
        container.layout().addLayout(row)

        # The float parameters all follow the same pattern: attribute '<name>' holds the value,
        # attribute '<name>_param' the UI element. Motion smoothing only exists for tracking.
        float_defaults = [("iou_threshold", 0.5), ("box_extension", 0.05)]
        if self.tracking:
            float_defaults.append(("motion_smoothing", 0.5))

        for name, default in float_defaults:
            setattr(self, name, default)
            param_widget, row = self._add_float_param(
                name, getattr(self, name), tooltip=get_tooltip("segmentnd", name)
            )
            setattr(self, f"{name}_param", param_widget)
            container.layout().addLayout(row)

        return _make_collapsible(container, title="Segmentation Settings")

    def _run_tracking(self):
        """Track the current object across frames and write the result to 'current_object'."""
        state = AnnotatorState()
        pbar, pbar_signals = _create_pbar_for_threadworker()

        def report_progress(update):
            pbar_signals.pbar_update.emit(update)

        # @thread_worker
        def tracking_impl():
            shape = state.image_shape
            pbar_signals.pbar_total.emit(shape[0])
            pbar_signals.pbar_description.emit("Track object")

            point_layer = self._viewer.layers["point_prompts"]
            box_layer = self._viewer.layers["prompts"]

            # Step 1: Segment all slices with prompts.
            seg, slices, _, stop_upper = vutil.segment_slices_with_prompts(
                state.predictor, point_layer, box_layer,
                state.image_embeddings, shape, track_id=state.current_track_id,
                update_progress=report_progress,
            )

            # Step 2: Track the object starting from the lowest annotated slice.
            seg, has_division = vutil.track_from_prompts(
                point_layer, box_layer, seg,
                state.predictor, slices, state.image_embeddings, stop_upper,
                threshold=self.iou_threshold, projection=self.projection,
                motion_smoothing=self.motion_smoothing,
                box_extension=self.box_extension,
                update_progress=report_progress,
            )

            pbar_signals.pbar_stop.emit()
            return seg, has_division

        def update_segmentation(ret_val):
            seg, has_division = ret_val
            track_id = state.current_track_id

            # The daughter tracks and the lineage are only created the first time
            # a division is found for this track.
            first_division = has_division and (len(state.lineage[track_id]) == 0)
            if first_division:
                _update_lineage(self._viewer)

            layer = self._viewer.layers["current_object"]
            # Remove the previous mask of this track before writing the new one.
            layer.data[layer.data == state.current_track_id] = 0
            layer.data[seg == 1] = state.current_track_id
            layer.refresh()

        update_segmentation(tracking_impl())
        # worker = tracking_impl()
        # worker.returned.connect(update_segmentation)
        # worker.start()
        # return worker

    def _run_volumetric_segmentation(self):
        """Segment the current object in the volume and write the result to 'current_object'."""
        pbar, pbar_signals = _create_pbar_for_threadworker()

        def report_progress(update):
            pbar_signals.pbar_update.emit(update)

        # @thread_worker
        def volumetric_segmentation_impl():
            state = AnnotatorState()
            shape = state.image_shape

            pbar_signals.pbar_total.emit(shape[0])
            pbar_signals.pbar_description.emit("Segment object")

            # Step 1: Segment all slices with prompts.
            seg, slices, stop_lower, stop_upper = vutil.segment_slices_with_prompts(
                state.predictor, self._viewer.layers["point_prompts"], self._viewer.layers["prompts"],
                state.image_embeddings, shape,
                update_progress=report_progress,
            )

            # Step 2: Segment the rest of the volume based on projecting prompts.
            seg, z_range = segment_mask_in_volume(
                seg, state.predictor, state.image_embeddings, slices,
                stop_lower, stop_upper,
                iou_threshold=self.iou_threshold, projection=self.projection,
                box_extension=self.box_extension,
                update_progress=report_progress,
            )
            z_min, z_max = z_range
            pbar_signals.pbar_stop.emit()

            # Store the extent of the segmentation along the first axis in the state.
            state.z_range = (z_min, z_max)
            return seg

        def update_segmentation(seg):
            layer = self._viewer.layers["current_object"]
            layer.data = seg
            layer.refresh()

        update_segmentation(volumetric_segmentation_impl())
        # worker = volumetric_segmentation_impl()
        # worker.returned.connect(update_segmentation)
        # worker.start()
        # return worker

    def __call__(self):
        """Run tracking or volumetric segmentation, depending on the widget mode.

        Nothing is run if the embedding validation or the layer validation returns a truthy value.
        """
        # The layer validation is only reached if the embedding validation passed.
        if _validate_embeddings(self._viewer) or _validate_layers(self._viewer):
            return None

        run = self._run_tracking if self.tracking else self._run_volumetric_segmentation
        return run()

QWidget(parent: Optional[QWidget] = None, flags: Qt.WindowType = Qt.WindowFlags())

SegmentNDWidget(viewer, tracking, parent=None)
1510    def __init__(self, viewer, tracking, parent=None):
1511        super().__init__(parent=parent)
1512        self._viewer = viewer
1513        self.tracking = tracking
1514
1515        # Add the settings.
1516        self.settings = self._create_settings()
1517        self.layout().addWidget(self.settings)
1518
1519        # Add the run button.
1520        button_title = "Segment All Frames [Shift-S]" if self.tracking else "Segment All Slices [Shift-S]"
1521        self.run_button = QtWidgets.QPushButton(button_title)
1522        self.run_button.clicked.connect(self.__call__)
1523        self.layout().addWidget(self.run_button)
tracking
settings
run_button
class AutoSegmentWidget(_WidgetBase):
class AutoSegmentWidget(_WidgetBase):
    """Widget for running the automatic instance segmentation.

    The widget exposes one of two parameter sets: the settings for the decoder-based
    segmentation if ``with_decoder`` is True, and the AMG settings otherwise.
    For volumetric data it additionally offers a switch to segment the whole volume
    instead of only the current slice.
    """

    def __init__(self, viewer, with_decoder, volumetric, parent=None):
        """Store the configuration and build the widget.

        Args:
            viewer: The napari viewer; results are written to its 'auto_segmentation' layer.
            with_decoder: Whether to show the decoder-based settings (True) or the AMG settings (False).
            volumetric: Whether the data is volumetric; this adds the volume switch and extra settings.
            parent: Optional parent widget, forwarded to the base class.
        """
        super().__init__(parent)

        self._viewer = viewer
        self.with_decoder = with_decoder
        self.volumetric = volumetric
        self._create_widget()

    def _create_widget(self):
        """Add all UI elements to the layout: volume switch (if volumetric), settings and run button."""
        # Add the switch for segmenting the slice vs. the volume if we have a volume.
        if self.volumetric:
            self.layout().addWidget(self._create_volumetric_switch())

        # Add the nested settings widget.
        self.settings = self._create_settings()
        self.layout().addWidget(self.settings)

        # Add the run button.
        self.run_button = QtWidgets.QPushButton("Automatic Segmentation")
        self.run_button.clicked.connect(self.__call__)
        self.run_button.setToolTip(get_tooltip("autosegment", "run_button"))
        self.layout().addWidget(self.run_button)

    def _reset_segmentation_mode(self, with_decoder):
        """Switch between the decoder-based and the AMG settings and rebuild the widget.

        Args:
            with_decoder: The new value for ``self.with_decoder``. Nothing happens if it is unchanged.
        """
        # If we already have the same segmentation mode we don't need to do anything.
        if with_decoder == self.with_decoder:
            return

        # Otherwise we change the value of with_decoder.
        self.with_decoder = with_decoder

        # Then we clear the whole widget. The widgets are only scheduled for deletion
        # (deleteLater) after they have been taken out of the layout.
        layout = self.layout()
        while layout.count():
            child = layout.takeAt(0)
            if child.widget():
                child.widget().deleteLater()

        # And then we reset it. Note that this also resets all parameters to their default values.
        self._create_widget()

    def _create_volumetric_switch(self):
        """Create the checkbox that selects between segmenting the current slice and the volume."""
        self.apply_to_volume = False
        return self._add_boolean_param(
            "apply_to_volume", self.apply_to_volume, title="Apply to Volume",
            tooltip=get_tooltip("autosegment", "apply_to_volume")
        )

    def _add_common_settings(self, settings):
        """Add the settings shared by both segmentation modes to the given settings widget.

        Args:
            settings: The widget whose layout receives the UI elements.
        """
        # Create the UI element for min object size.
        self.min_object_size = 100
        self.min_object_size_param, layout = self._add_int_param(
            "min_object_size", self.min_object_size, min_val=0, max_val=int(1e4),
            tooltip=get_tooltip("autosegment", "min_object_size")
        )
        settings.layout().addLayout(layout)

        # Add extra settings for volumetric segmentation: gap_closing and min_extent.
        if self.volumetric:
            self.gap_closing = 2
            self.gap_closing_param, layout = self._add_int_param(
                "gap_closing", self.gap_closing, min_val=0, max_val=10,
                tooltip=get_tooltip("autosegment", "gap_closing")
            )
            settings.layout().addLayout(layout)

            self.min_extent = 2
            self.min_extent_param, layout = self._add_int_param(
                "min_extent", self.min_extent, min_val=0, max_val=10,
                tooltip=get_tooltip("autosegment", "min_extent")
            )
            settings.layout().addLayout(layout)

    def _ais_settings(self):
        """Create the settings widget for the decoder-based segmentation.

        Returns:
            The widget holding the distance thresholds and the common settings.
        """
        settings = QtWidgets.QWidget()
        settings.setLayout(QtWidgets.QVBoxLayout())

        # Create the UI element for center_distance_threshold.
        self.center_distance_thresh = 0.5
        self.center_distance_thresh_param, layout = self._add_float_param(
            "center_distance_thresh", self.center_distance_thresh,
            tooltip=get_tooltip("autosegment", "center_distance_thresh")
        )
        settings.layout().addLayout(layout)

        # Create the UI element for boundary_distance_threshold.
        self.boundary_distance_thresh = 0.5
        self.boundary_distance_thresh_param, layout = self._add_float_param(
            "boundary_distance_thresh", self.boundary_distance_thresh,
            tooltip=get_tooltip("autosegment", "boundary_distance_thresh")
        )
        settings.layout().addLayout(layout)

        # Add min_object_size (and the volumetric settings, if applicable).
        self._add_common_settings(settings)

        return settings

    def _amg_settings(self):
        """Create the settings widget for the AMG-based segmentation.

        Returns:
            The widget holding the AMG thresholds and the common settings.
        """
        settings = QtWidgets.QWidget()
        settings.setLayout(QtWidgets.QVBoxLayout())

        # Create the UI element for pred_iou_thresh.
        self.pred_iou_thresh = 0.88
        self.pred_iou_thresh_param, layout = self._add_float_param(
            "pred_iou_thresh", self.pred_iou_thresh,
            tooltip=get_tooltip("autosegment", "pred_iou_thresh")
        )
        settings.layout().addLayout(layout)

        # Create the UI element for stability score thresh.
        self.stability_score_thresh = 0.95
        self.stability_score_thresh_param, layout = self._add_float_param(
            "stability_score_thresh", self.stability_score_thresh,
            tooltip=get_tooltip("autosegment", "stability_score_thresh")
        )
        settings.layout().addLayout(layout)

        # Create the UI element for box nms thresh.
        self.box_nms_thresh = 0.7
        self.box_nms_thresh_param, layout = self._add_float_param(
            "box_nms_thresh", self.box_nms_thresh,
            tooltip=get_tooltip("autosegment", "box_nms_thresh")
        )
        settings.layout().addLayout(layout)

        # Add min_object_size (and the volumetric settings, if applicable).
        self._add_common_settings(settings)

        return settings

    def _create_settings(self):
        """Create the collapsible settings section for the current segmentation mode."""
        setting_values = self._ais_settings() if self.with_decoder else self._amg_settings()
        settings = _make_collapsible(setting_values, title="Automatic Segmentation Settings")
        return settings

    def _empty_segmentation_warning(self):
        """Show a message telling the user that the segmentation result is empty.

        Returns:
            The return value of `_generate_message`.
        """
        # The sentences are separated by a space so that the message reads as normal text.
        msg = "The automatic segmentation result does not contain any objects."
        msg += " Setting a smaller value for 'min_object_size' may help."
        # These two thresholds only exist in the AMG settings.
        if not self.with_decoder:
            msg += " Setting smaller values for 'pred_iou_thresh' and 'stability_score_thresh' may also help."
        return _generate_message("error", msg)

    def _run_segmentation_2d(self, kwargs, i=None):
        """Run the automatic segmentation for a 2d image or a single slice.

        Args:
            kwargs: The keyword arguments for `_instance_segmentation_impl`, as assembled in `__call__`.
            i: The index of the slice to segment. If None, the result replaces the data of the
                'auto_segmentation' layer, otherwise it is written to index ``i`` of that data.
        """
        pbar, pbar_signals = _create_pbar_for_threadworker()

        # @thread_worker
        def seg_impl():
            def pbar_init(total, description):
                pbar_signals.pbar_total.emit(total)
                pbar_signals.pbar_description.emit(description)

            seg = _instance_segmentation_impl(
                self.min_object_size, i=i, pbar_init=pbar_init,
                pbar_update=lambda update: pbar_signals.pbar_update.emit(update),
                **kwargs
            )
            pbar_signals.pbar_stop.emit()
            return seg

        def update_segmentation(seg):
            is_empty = seg.max() == 0
            if is_empty:
                self._empty_segmentation_warning()

            if i is None:
                self._viewer.layers["auto_segmentation"].data = seg
            else:
                self._viewer.layers["auto_segmentation"].data[i] = seg
            self._viewer.layers["auto_segmentation"].refresh()

        # Validate all layers.
        # NOTE(review): the return value is not checked here, unlike in other callers of
        # '_validate_layers' in this file. Confirm that this is intended.
        _validate_layers(self._viewer, automatic_segmentation=True)

        seg = seg_impl()
        update_segmentation(seg)
        # worker = seg_impl()
        # worker.returned.connect(update_segmentation)
        # worker.start()
        # return worker

    def _allow_segment_3d(self):
        """Check whether the segmentation may be applied to the full volume.

        We refuse to run 3D segmentation with the AMG unless we have a GPU or all embeddings
        are precomputed. Otherwise this would take too long.

        Returns:
            True if the volumetric segmentation is allowed, False otherwise.
        """
        if self.with_decoder:
            return True
        state = AnnotatorState()
        predictor = state.predictor
        if str(predictor.device) == "cpu" or str(predictor.device) == "mps":
            n_slices = self._viewer.layers["auto_segmentation"].data.shape[0]
            # NOTE(review): the strict '>' suggests that 'amg_state' holds at least one entry
            # besides the per-slice entries when everything is precomputed. Confirm.
            embeddings_are_precomputed = (state.amg_state is not None) and (len(state.amg_state) > n_slices)
            if not embeddings_are_precomputed:
                return False
        return True

    def _run_segmentation_3d(self, kwargs):
        """Run the automatic segmentation slice by slice and merge the results across slices.

        Args:
            kwargs: The keyword arguments for `_instance_segmentation_impl`, as assembled in `__call__`.

        Returns:
            The return value of `_generate_message` if the volumetric segmentation is not allowed,
            None otherwise.
        """
        allow_segment_3d = self._allow_segment_3d()
        if not allow_segment_3d:
            return _generate_message(
                "error", "Volumetric segmentation with AMG is only supported if you have a GPU."
            )

        pbar, pbar_signals = _create_pbar_for_threadworker()

        # @thread_worker
        def seg_impl():
            segmentation = np.zeros_like(self._viewer.layers["auto_segmentation"].data)
            offset = 0

            def pbar_init(total, description):
                pbar_signals.pbar_total.emit(total)
                pbar_signals.pbar_description.emit(description)

            pbar_init(segmentation.shape[0], "Segment volume")

            # Further optimization: parallelize if state is precomputed for all slices
            for i in range(segmentation.shape[0]):
                seg = _instance_segmentation_impl(self.min_object_size, i=i, **kwargs)
                seg_max = seg.max()
                if seg_max != 0:
                    # Offset the ids so that they are unique across slices.
                    seg[seg != 0] += offset
                    offset = seg_max + offset
                    segmentation[i] = seg
                # Advance the progress bar for every slice, also for the ones without any objects.
                pbar_signals.pbar_update.emit(1)

            # The progress bar is reset and then re-used for merging the segmentation.
            pbar_signals.pbar_reset.emit()
            segmentation = merge_instance_segmentation_3d(
                segmentation, beta=0.5,  gap_closing=self.gap_closing, min_z_extent=self.min_extent,
                verbose=True, pbar_init=pbar_init, pbar_update=lambda update: pbar_signals.pbar_update.emit(1),
            )
            pbar_signals.pbar_stop.emit()
            return segmentation

        def update_segmentation(segmentation):
            is_empty = segmentation.max() == 0
            if is_empty:
                self._empty_segmentation_warning()
            self._viewer.layers["auto_segmentation"].data = segmentation
            self._viewer.layers["auto_segmentation"].refresh()

        seg = seg_impl()
        update_segmentation(seg)
        # worker = seg_impl()
        # worker.returned.connect(update_segmentation)
        # worker.start()
        # return worker

    def __call__(self):
        """Run the automatic segmentation with the current settings.

        Depending on the widget state this segments the full volume, the current slice
        or the 2d image. Afterwards the 'auto_segmentation' layer is selected.

        Returns:
            None if the embedding validation returns a truthy value, otherwise the return
            value of the segmentation function that was run.
        """
        if _validate_embeddings(self._viewer):
            return None

        # Collect the parameters of the active segmentation mode.
        if self.with_decoder:
            kwargs = {
                "center_distance_threshold": self.center_distance_thresh,
                "boundary_distance_threshold": self.boundary_distance_thresh,
                "min_size": self.min_object_size,
            }
        else:
            kwargs = {
                "pred_iou_thresh": self.pred_iou_thresh,
                "stability_score_thresh": self.stability_score_thresh,
                "box_nms_thresh": self.box_nms_thresh,
            }
        if self.volumetric and self.apply_to_volume:
            worker = self._run_segmentation_3d(kwargs)
        elif self.volumetric and not self.apply_to_volume:
            # Segment only the slice at the current position of the first viewer axis.
            i = int(self._viewer.dims.point[0])
            worker = self._run_segmentation_2d(kwargs, i=i)
        else:
            worker = self._run_segmentation_2d(kwargs)
        _select_layer(self._viewer, "auto_segmentation")
        return worker

QWidget(parent: Optional[QWidget] = None, flags: Qt.WindowType = Qt.WindowFlags())

AutoSegmentWidget(viewer, with_decoder, volumetric, parent=None)
1738    def __init__(self, viewer, with_decoder, volumetric, parent=None):
1739        super().__init__(parent)
1740
1741        self._viewer = viewer
1742        self.with_decoder = with_decoder
1743        self.volumetric = volumetric
1744        self._create_widget()
with_decoder
volumetric
class AutoTrackWidget(AutoSegmentWidget):
class AutoTrackWidget(AutoSegmentWidget):
    """Widget for automatic tracking, built on top of the automatic segmentation widget.

    The widget replaces the volume switch by a switch for tracking the timeseries and
    overrides `_run_segmentation_3d` to track the segmented objects across frames.
    """

    def _create_tracking_switch(self):
        """Create the checkbox that selects between segmenting the current frame and tracking.

        The value is stored in ``apply_to_volume``.
        """
        self.apply_to_volume = False
        return self._add_boolean_param(
            "apply_to_volume", self.apply_to_volume, title="Track Timeseries",
            tooltip=get_tooltip("autotrack", "run_tracking")
        )

    def _create_widget(self):
        """Add all UI elements to the layout: tracking switch, settings and run button."""
        # Add the switch for segmenting the slice vs. tracking the timeseries.
        self.layout().addWidget(self._create_tracking_switch())

        # Add the nested settings widget.
        self.settings = self._create_settings()
        self.layout().addWidget(self.settings)

        # Add the run button.
        self.run_button = QtWidgets.QPushButton("Automatic Tracking")
        self.run_button.clicked.connect(self.__call__)
        self.run_button.setToolTip(get_tooltip("autotrack", "run_button"))
        self.layout().addWidget(self.run_button)

    def _run_segmentation_3d(self, kwargs):
        """Segment all frames and track the segmented objects across the timeseries.

        The tracking result is written to the 'auto_segmentation' layer and the lineages
        are stored in the annotator state.

        Args:
            kwargs: The keyword arguments passed on to `_instance_segmentation_impl`.

        Returns:
            The return value of `_generate_message` if tracking cannot be run, None otherwise.
        """
        allow_segment_3d = self._allow_segment_3d()
        if not allow_segment_3d:
            return _generate_message("error", "Tracking with AMG is only supported if you have a GPU.")

        state = AnnotatorState()
        # Automatic tracking is refused once lineages have been committed.
        if len(state.committed_lineages) > 0:
            return _generate_message(
                "error",
                "Automatic tracking can only be called if you haven't committed results from interactive tracking yet."
            )
        pbar, pbar_signals = _create_pbar_for_threadworker()

        # @thread_worker
        def seg_impl():
            image_name = state.get_image_name(self._viewer)
            timeseries = self._viewer.layers[image_name].data
            segmentation = np.zeros_like(self._viewer.layers["auto_segmentation"].data)
            offset = 0

            def pbar_init(total, description):
                pbar_signals.pbar_total.emit(total)
                pbar_signals.pbar_description.emit(description)

            pbar_init(segmentation.shape[0], "Run tracking")

            # Further optimization: parallelize if state is precomputed for all slices
            for i in range(segmentation.shape[0]):
                seg = _instance_segmentation_impl(self.min_object_size, i=i, **kwargs)
                seg_max = seg.max()
                if seg_max != 0:
                    # Offset the ids so that they are unique across frames.
                    seg[seg != 0] += offset
                    offset = seg_max + offset
                    segmentation[i] = seg
                # Advance the progress bar for every frame, also for the ones without any objects.
                pbar_signals.pbar_update.emit(1)

            # The progress bar is reset and then re-used for the tracking step.
            pbar_signals.pbar_reset.emit()
            segmentation, lineages = track_across_frames(
                timeseries, segmentation,
                verbose=True, pbar_init=pbar_init,
                pbar_update=lambda update: pbar_signals.pbar_update.emit(1),
            )
            pbar_signals.pbar_stop.emit()
            return (segmentation, lineages)

        def update_segmentation(result):
            segmentation, lineages = result
            is_empty = segmentation.max() == 0
            if is_empty:
                self._empty_segmentation_warning()

            state = AnnotatorState()
            state.lineage = lineages

            self._viewer.layers["auto_segmentation"].data = segmentation
            self._viewer.layers["auto_segmentation"].refresh()

        result = seg_impl()
        update_segmentation(result)
        # worker = seg_impl()
        # worker.returned.connect(update_segmentation)
        # worker.start()
        # return worker

QWidget(parent: Optional[QWidget] = None, flags: Qt.WindowType = Qt.WindowFlags())