micro_sam.sam_annotator._widgets
Implements the widgets used in the annotation plugins.
"""Implements the widgets used in the annotation plugins.
"""

import os
import gc
import json
import multiprocessing as mp
import pickle
from pathlib import Path
from typing import Optional

import h5py
import zarr
import napari
import numpy as np

try:
    import z5py
except ImportError:
    z5py = None

from bioimage_cpp.utils import segmentation_overlap

import elf.parallel

from qtpy import QtWidgets
from qtpy.QtCore import QObject, Signal
from superqt import QCollapsible
from napari.utils.notifications import show_info
from magicgui import magic_factory
from magicgui.widgets import ComboBox, Container, create_widget
# We have disabled the thread workers for now because they result in a
# massive slowdown in napari >= 0.5.
# See also https://forum.image.sc/t/napari-thread-worker-leads-to-massive-slowdown/103786
# from napari.qt.threading import thread_worker
from napari.utils import progress

from . import util as vutil
from ._tooltips import get_tooltip
from ._state import AnnotatorState
from .. import instance_segmentation, util
from ..multi_dimensional_segmentation import (
    segment_mask_in_volume, merge_instance_segmentation_3d, track_across_frames, PROJECTION_MODES, get_napari_track_data
)


#
# Convenience functionality for creating QT UI and manipulating the napari viewer.
#


def _select_layer(viewer, layer_name):
    """Make the layer called ``layer_name`` the only selected layer of the napari viewer."""
    viewer.layers.selection.select_only(viewer.layers[layer_name])


def _make_collapsible(widget, title):
    """Wrap ``widget`` in a collapsible section labelled ``title``.

    Args:
        widget: The Qt widget to put into the collapsible section.
        title: The title shown on the collapsible header.

    Returns:
        A plain ``QWidget`` whose vertical layout contains the collapsible section.
    """
    parent_widget = QtWidgets.QWidget()
    parent_widget.setLayout(QtWidgets.QVBoxLayout())
    collapsible = QCollapsible(title, parent_widget)
    collapsible.addWidget(widget)
    parent_widget.layout().addWidget(collapsible)
    return parent_widget


# Base class for a widget with convenience functionality for adding parameters.
class _WidgetBase(QtWidgets.QWidget):
    """Base class for the annotator widgets, with helpers for creating parameter inputs.

    The ``_add_*_param`` helpers create a Qt input widget initialized with ``value``. Each one connects the
    widget's change signal so that the attribute ``self.<name>`` follows the user input. The helpers do not
    assign ``self.<name>`` themselves, so the attribute has to be initialized separately.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setLayout(QtWidgets.QVBoxLayout())

    @staticmethod
    def _add_param_label(layout, name, title=None, tooltip=None):
        """Add a label showing ``title`` (or ``name`` if no title is given) to ``layout``.

        The label gets the optional tooltip. Returns the created label.
        """
        label = QtWidgets.QLabel(title or name)
        if tooltip:
            label.setToolTip(tooltip)
        layout.addWidget(label)
        return label

    def _add_boolean_param(self, name, value, title=None, tooltip=None):
        """Create a checkbox that keeps the boolean attribute ``self.<name>`` in sync.

        Returns:
            The checkbox. It is not added to any layout.
        """
        checkbox = QtWidgets.QCheckBox(name if title is None else title)
        checkbox.setChecked(value)
        # 'toggled' emits a bool, whereas 'stateChanged' emits the integer check state (0 or 2).
        # Using 'toggled' keeps the attribute a proper boolean.
        checkbox.toggled.connect(lambda checked: setattr(self, name, checked))
        if tooltip:
            checkbox.setToolTip(tooltip)
        return checkbox

    def _add_string_param(self, name, value, title=None, placeholder=None, layout=None, tooltip=None):
        """Create a labelled line edit that keeps the string attribute ``self.<name>`` in sync.

        Returns:
            The line edit and the layout it was added to (a new ``QHBoxLayout`` if ``layout`` is None).
        """
        if layout is None:
            layout = QtWidgets.QHBoxLayout()
        self._add_param_label(layout, name, title, tooltip)
        param = QtWidgets.QLineEdit()
        param.setText(value)
        if placeholder is not None:
            param.setPlaceholderText(placeholder)
        param.textChanged.connect(lambda val: setattr(self, name, val))
        if tooltip:
            param.setToolTip(tooltip)
        layout.addWidget(param)
        return param, layout

    def _add_float_param(self, name, value, title=None, min_val=0.0, max_val=1.0, decimals=2,
                         step=0.01, layout=None, tooltip=None):
        """Create a labelled double spin box that keeps the float attribute ``self.<name>`` in sync.

        Returns:
            The spin box and the layout it was added to (a new ``QHBoxLayout`` if ``layout`` is None).
        """
        if layout is None:
            layout = QtWidgets.QHBoxLayout()
        self._add_param_label(layout, name, title, tooltip)
        param = QtWidgets.QDoubleSpinBox()
        # Set the precision first. QDoubleSpinBox rounds the range and the value to the current number of
        # decimals, so setting them before the decimals would e.g. truncate a minimum of 0.001 to 0.0.
        param.setDecimals(decimals)
        param.setRange(min_val, max_val)
        param.setValue(value)
        param.setSingleStep(step)
        param.valueChanged.connect(lambda val: setattr(self, name, val))
        if tooltip:
            param.setToolTip(tooltip)
        layout.addWidget(param)
        return param, layout

    def _add_int_param(self, name, value, min_val, max_val, title=None, step=1, layout=None, tooltip=None):
        """Create a labelled spin box that keeps the integer attribute ``self.<name>`` in sync.

        Returns:
            The spin box and the layout it was added to (a new ``QHBoxLayout`` if ``layout`` is None).
        """
        if layout is None:
            layout = QtWidgets.QHBoxLayout()
        self._add_param_label(layout, name, title, tooltip)
        param = QtWidgets.QSpinBox()
        param.setRange(min_val, max_val)
        param.setValue(value)
        param.setSingleStep(step)
        param.valueChanged.connect(lambda val: setattr(self, name, val))
        if tooltip:
            param.setToolTip(tooltip)
        layout.addWidget(param)
        return param, layout

    def _add_choice_param(self, name, value, options, title=None, layout=None, update=None, tooltip=None):
        """Create a labelled dropdown with the given ``options``, initially showing ``value``.

        If ``update`` is None, the attribute ``self.<name>`` is set to the selected option's text.
        Otherwise ``update`` is connected to the dropdown's ``currentIndexChanged`` signal instead.

        Returns:
            The dropdown and the layout it was added to (a new ``QHBoxLayout`` if ``layout`` is None).
        """
        if layout is None:
            layout = QtWidgets.QHBoxLayout()
        self._add_param_label(layout, name, title, tooltip)

        # Create the dropdown menu via QComboBox, set the available values.
        dropdown = QtWidgets.QComboBox()
        dropdown.addItems(options)
        if update is None:
            def _set_choice(index):
                # Read the text from the dropdown rather than indexing 'options'. This keeps the attribute
                # correct after the items are replaced (see '_update_model_type'). Index -1 (no selection) is
                # ignored, because indexing with -1 would silently pick the last option.
                if index >= 0:
                    setattr(self, name, dropdown.itemText(index))

            dropdown.currentIndexChanged.connect(_set_choice)
        else:
            dropdown.currentIndexChanged.connect(update)

        # Select the initial value ('findText' returns -1 if the value is not one of the options).
        dropdown.setCurrentIndex(dropdown.findText(value))

        if tooltip:
            dropdown.setToolTip(tooltip)

        layout.addWidget(dropdown)
        return dropdown, layout

    def _add_shape_param(self, names, values, min_val, max_val, step=1, title=None, tooltip=None):
        """Create two integer inputs side by side, e.g. for a 2D shape.

        They are bound to the attributes ``names[0]`` and ``names[1]``.

        Returns:
            The two spin boxes and the horizontal layout containing them.
        """
        layout = QtWidgets.QHBoxLayout()

        x_layout = QtWidgets.QVBoxLayout()
        x_param, _ = self._add_int_param(
            names[0], values[0], min_val=min_val, max_val=max_val, layout=x_layout, step=step,
            title=title[0] if title is not None else title, tooltip=tooltip
        )
        layout.addLayout(x_layout)

        y_layout = QtWidgets.QVBoxLayout()
        y_param, _ = self._add_int_param(
            names[1], values[1], min_val=min_val, max_val=max_val, layout=y_layout, step=step,
            title=title[1] if title is not None else title, tooltip=tooltip
        )
        layout.addLayout(y_layout)

        return x_param, y_param, layout

    def _add_path_param(self, name, value, select_type, title=None, placeholder=None, tooltip=None):
        """Create a labelled path textbox with buttons to pick a file and/or a directory.

        The textbox keeps the attribute ``self.<name>`` in sync.

        Args:
            select_type: One of 'directory', 'file' or 'both'; decides which selection buttons are added.

        Returns:
            The textbox and the horizontal layout containing the label, the textbox and the buttons.
        """
        assert select_type in ("directory", "file", "both")

        layout = QtWidgets.QHBoxLayout()
        self._add_param_label(layout, name, title, tooltip)

        path_textbox = QtWidgets.QLineEdit()
        # Show an empty textbox instead of the text "None" if no path is set yet.
        path_textbox.setText("" if value is None else str(value))
        if placeholder is not None:
            path_textbox.setPlaceholderText(placeholder)
        path_textbox.textChanged.connect(lambda val: setattr(self, name, val))
        if tooltip:
            path_textbox.setToolTip(tooltip)

        layout.addWidget(path_textbox)

        def add_path_button(select_type, tooltip=None):
            # Adjust button text.
            button_text = f"Select {select_type.capitalize()}"
            path_button = QtWidgets.QPushButton(button_text)

            # Call the matching '_get_<select_type>_path' method, which writes the selection into the textbox.
            path_button.clicked.connect(lambda: getattr(self, f"_get_{select_type}_path")(name, path_textbox))
            if tooltip:
                path_button.setToolTip(tooltip)
            layout.addWidget(path_button)

        if select_type == "both":
            add_path_button("file", tooltip=tooltip)
            add_path_button("directory", tooltip=tooltip)
        else:
            add_path_button(select_type, tooltip=tooltip)

        return path_textbox, layout

    def _get_directory_path(self, name, textbox, tooltip=None):
        """Open a directory dialog and write the chosen directory into ``textbox``.

        ``name`` and ``tooltip`` are not used here; they are kept for backward compatibility.
        """
        directory = QtWidgets.QFileDialog.getExistingDirectory(
            self, "Select Directory", "", QtWidgets.QFileDialog.ShowDirsOnly
        )
        # An empty result means the dialog was cancelled, which is not an error.
        if not directory:
            return
        if Path(directory).is_dir():
            textbox.setText(str(directory))
        else:
            # Handle the case where the selected path is not a directory
            print("Invalid directory selected. Please try again.")

    def _get_file_path(self, name, textbox, tooltip=None):
        """Open a file dialog and write the chosen file path into ``textbox``.

        ``name`` and ``tooltip`` are not used here; they are kept for backward compatibility.
        """
        file_path, _ = QtWidgets.QFileDialog.getOpenFileName(
            self, "Select File", "", "All Files (*)"
        )
        # An empty result means the dialog was cancelled, which is not an error.
        if not file_path:
            return
        if Path(file_path).is_file():
            textbox.setText(str(file_path))
        else:
            # Handle the case where the selected path is not a file
            print("Invalid file selected. Please try again.")

    def _get_model_size_options(self):
        """Compute the model sizes available for ``self.model_family``.

        Sets ``self.model_size_options`` to the list of size labels ('tiny', 'base', 'large', 'huge') and
        ``self.model_size_mapping`` to a dict from size label to the actual model name.
        """
        # We store the actual model names mapped to UI labels.
        self.model_size_mapping = {}
        if self.model_family == "Natural Images (SAM)":
            self.model_size_options = list(self._model_size_map.values())
            self.model_size_mapping = {label: f"vit_{key}" for key, label in self._model_size_map.items()}
        else:
            model_suffix = self.supported_dropdown_maps[self.model_family]
            self.model_size_options = []

            for option in self.model_options:
                if not option.endswith(model_suffix):
                    continue
                # Extract the model size character (e.g. 'b' for 'vit_b_lm') on-the-fly.
                key = next((k for k in self._model_size_map if f"vit_{k}" in option), None)
                if key:
                    size_label = self._model_size_map[key]
                    self.model_size_options.append(size_label)
                    self.model_size_mapping[size_label] = option  # Store the actual model name.

        # We ensure a consistent order of model sizes ('tiny' to 'huge').
        self.model_size_options.sort(key=lambda x: ["tiny", "base", "large", "huge"].index(x))

    def _update_model_type(self):
        """Rebuild the model size dropdown for the current model family and update ``self.model_type``."""
        # Get currently selected model size (before clearing dropdown)
        current_selection = self.model_size_dropdown.currentText()
        self._get_model_size_options()  # Update model size options dynamically

        # NOTE: We need to prevent recursive updates for this step temporarily.
        # The 'finally' makes sure the signals are re-enabled even if something below fails.
        self.model_size_dropdown.blockSignals(True)
        try:
            # Let's clear and recreate the dropdown.
            self.model_size_dropdown.clear()
            self.model_size_dropdown.addItems(self.model_size_options)

            # We restore the previous selection, if still valid.
            if current_selection in self.model_size_options:
                self.model_size = current_selection
            elif self.model_size_options:  # Default to the first available model size
                self.model_size = self.model_size_options[0]

            # Let's map the selection to the correct model type (eg. "tiny" -> "vit_t")
            size_key = next(
                (k for k, v in self._model_size_map.items() if v == self.model_size), "b"
            )
            self.model_type = f"vit_{size_key}" + self.supported_dropdown_maps[self.model_family]

            self.model_size_dropdown.setCurrentText(self.model_size)  # Apply the selected text to the dropdown

            # We force a refresh for UI here.
            self.model_size_dropdown.update()
        finally:
            # NOTE: And finally, we should re-enable signals again.
            self.model_size_dropdown.blockSignals(False)

    def _create_model_section(self, default_model: str = util._DEFAULT_MODEL, create_layout: bool = True):
        """Create the model family dropdown.

        Args:
            default_model: Name of the default model. Its suffix after 'vit_x' selects the default family.
            create_layout: Whether to add the dropdown to a new vertical layout (instead of a horizontal one).

        Returns:
            The layout containing the model family dropdown.
        """
        # Create a list of support dropdown values and correspond them to suffixes.
        self.supported_dropdown_maps = {
            "Natural Images (SAM)": "",
            "Light Microscopy": "_lm",
            "Electron Microscopy": "_em_organelles",
            "Medical Imaging": "_medical_imaging",
            "Histopathology": "_histopathology",
        }

        # NOTE: The available options for all are either 'tiny', 'base', 'large' or 'huge'.
        self._model_size_map = {"t": "tiny", "b": "base", "l": "large", "h": "huge"}

        self._default_model_choice = default_model
        # Derive the default model family from the suffix of the default model name (e.g. 'vit_b_lm' -> '_lm').
        self.model_family = {v: k for k, v in self.supported_dropdown_maps.items()}[self._default_model_choice[5:]]

        kwargs = {}
        if create_layout:
            layout = QtWidgets.QVBoxLayout()
            kwargs["layout"] = layout

        # NOTE: We stick to the base variant for each model family.
        # i.e. 'Natural Images (SAM)', 'Light Microscopy', 'Electron Microscopy', 'Medical_Imaging', 'Histopathology'.
        self.model_family_dropdown, layout = self._add_choice_param(
            "model_family", self.model_family, list(self.supported_dropdown_maps.keys()),
            title="Model:", tooltip=get_tooltip("embedding", "model_family"), **kwargs,
        )
        self.model_family_dropdown.currentTextChanged.connect(self._update_model_type)
        return layout

    def _create_model_size_section(self):
        """Create the model size dropdown for the current model family.

        Returns:
            The layout containing the dropdown.
        """
        # Create UI for the model size.
        # This would combine with the chosen 'self.model_family' and depend on 'self._default_model_choice'.
        # The character at index 4 is the size key of the default model, e.g. 'b' for 'vit_b_lm'.
        self.model_size = self._model_size_map[self._default_model_choice[4]]

        # Get all model options.
        self.model_options = list(util.models().urls.keys())
        # Filter out the decoders from the model list.
        self.model_options = [model for model in self.model_options if not model.endswith("decoder")]

        # Now, we get the available sizes per model family.
        self._get_model_size_options()

        self.model_size_dropdown, layout = self._add_choice_param(
            "model_size", self.model_size, self.model_size_options,
            title="model size:", tooltip=get_tooltip("embedding", "model_size"),
        )
        self.model_size_dropdown.currentTextChanged.connect(self._update_model_type)
        return layout

    def _validate_model_type_and_custom_weights(self):
        """Set ``self.model_type`` from the selected family and size.

        If custom weights are given, this also clears the family dropdown's displayed text.
        """
        # Let's get all model combination stuff into the desired `model_type` structure.
        # The first letter of the size label ('tiny', 'base', ...) is the size key ('t', 'b', ...).
        self.model_type = "vit_" + self.model_size[0] + self.supported_dropdown_maps[self.model_family]

        # For 'custom_weights', we remove the displayed text on top of the drop-down menu.
        if self.custom_weights:
            # NOTE: We prevent recursive updates for this step temporarily.
            self.model_family_dropdown.blockSignals(True)
            try:
                self.model_family_dropdown.setCurrentIndex(-1)  # This removes the displayed text.
                self.model_family_dropdown.update()
            finally:
                # NOTE: And re-enable signals again.
                self.model_family_dropdown.blockSignals(False)


# Custom signals for managing progress updates.
363class PBarSignals(QObject): 364 pbar_total = Signal(int) 365 pbar_update = Signal(int) 366 pbar_description = Signal(str) 367 pbar_stop = Signal() 368 pbar_reset = Signal() 369 370 371class InfoDialog(QtWidgets.QDialog): 372 def __init__(self, title, message): 373 super().__init__() 374 self.setWindowTitle(title) 375 376 layout = QtWidgets.QVBoxLayout() 377 layout.addWidget(QtWidgets.QLabel(message)) 378 379 # Add buttons 380 button_box = QtWidgets.QHBoxLayout() # Use QHBoxLayout for buttons side-by-side 381 accept_button = QtWidgets.QPushButton("OK") 382 accept_button.clicked.connect(lambda: self.button_clicked(accept_button)) # Connect to clicked signal 383 button_box.addWidget(accept_button) 384 385 cancel_button = QtWidgets.QPushButton("Cancel") 386 cancel_button.clicked.connect(lambda: self.button_clicked(cancel_button)) # Connect to clicked signal 387 button_box.addWidget(cancel_button) 388 389 layout.addLayout(button_box) 390 self.setLayout(layout) 391 392 def button_clicked(self, button): 393 if button.text() == "OK": 394 self.accept() # Accept the dialog 395 else: 396 self.reject() # Reject the dialog (Cancel) 397 398 399# Set up the progress bar. We handle this via custom signals that are passed as callbacks to the 400# function that does the actual work. We need callbacks for initializing the progress bar, 401# updating it and for stopping the progress bar. 402def _create_pbar_for_threadworker(): 403 pbar = progress() 404 pbar_signals = PBarSignals() 405 pbar_signals.pbar_total.connect(lambda total: setattr(pbar, "total", total)) 406 pbar_signals.pbar_update.connect(lambda update: pbar.update(update)) 407 pbar_signals.pbar_description.connect(lambda description: pbar.set_description(description)) 408 pbar_signals.pbar_stop.connect(lambda: pbar.close()) 409 pbar_signals.pbar_reset.connect(lambda: pbar.reset()) 410 return pbar, pbar_signals 411 412 413def _reset_tracking_state(viewer): 414 """Reset the tracking state. 
415 416 This helper function is needed by the widgets clear_track and by commit_track. 417 """ 418 state = AnnotatorState() 419 420 # Reset the lineage and track id. 421 state.current_track_id = 1 422 state.lineage = {1: []} 423 424 # Reset the layer properties. 425 viewer.layers["point_prompts"].property_choices["track_id"] = ["1"] 426 viewer.layers["prompts"].property_choices["track_id"] = ["1"] 427 428 # Reset the choices in the track_id menu. 429 state.widgets["tracking"][1].value = "1" 430 state.widgets["tracking"][1].choices = ["1"] 431 432 433# 434# Widgets implemented with magicgui. 435# 436 437 438@magic_factory(call_button="Clear Annotations [Shift + C]") 439def clear(viewer: "napari.viewer.Viewer") -> None: 440 """Widget for clearing the current annotations. 441 442 Args: 443 viewer: The napari viewer. 444 """ 445 vutil.clear_annotations(viewer) 446 447 # Perform garbage collection. 448 gc.collect() 449 450 451@magic_factory(call_button="Clear Annotations [Shift + C]") 452def clear_volume(viewer: "napari.viewer.Viewer", all_slices: bool = True) -> None: 453 """Widget for clearing the current annotations in 3D. 454 455 Args: 456 viewer: The napari viewer. 457 all_slices: Choose whether to clear the annotations for all or only the current slice. 458 """ 459 if all_slices: 460 vutil.clear_annotations(viewer) 461 else: 462 i = int(viewer.dims.point[0]) 463 vutil.clear_annotations_slice(viewer, i=i) 464 465 # Perform garbage collection. 466 gc.collect() 467 468 469@magic_factory(call_button="Clear Annotations [Shift + C]") 470def clear_track(viewer: "napari.viewer.Viewer", all_frames: bool = True) -> None: 471 """Widget for clearing all tracking annotations and state. 472 473 Args: 474 viewer: The napari viewer. 475 all_frames: Choose whether to clear the annotations for all or only the current frame. 
476 """ 477 if all_frames: 478 _reset_tracking_state(viewer) 479 vutil.clear_annotations(viewer) 480 else: 481 i = int(viewer.dims.point[0]) 482 vutil.clear_annotations_slice(viewer, i=i) 483 484 # Perform garbage collection. 485 gc.collect() 486 487 488def _mask_matched_objects(seg, prev_seg, preservation_threshold): 489 prev_ids = np.unique(prev_seg) 490 ovlp = segmentation_overlap(prev_seg, seg) 491 492 mask_ids, prev_mask_ids = [], [] 493 for prev_id in prev_ids: 494 ovlp_table = ovlp.overlaps_for_label_a(prev_id) 495 seg_ids, overlaps = ovlp_table["label"], ovlp_table["count"] 496 if seg_ids[0] != 0 and overlaps[0] >= preservation_threshold: 497 mask_ids.append(seg_ids[0]) 498 prev_mask_ids.append(prev_id) 499 500 preserve_mask = np.logical_or(np.isin(seg, mask_ids), np.isin(prev_seg, prev_mask_ids)) 501 return preserve_mask 502 503 504def _commit_impl(viewer, layer, preserve_mode, preservation_threshold): 505 state = AnnotatorState() 506 507 # Check whether all layers exist as expected or create new ones automatically. 508 state.annotator._require_layers(layer_choices=[layer, "committed_objects"]) 509 510 # Check if we have a z_range. If yes, use it to set a bounding box. 511 if state.z_range is None: 512 bb = np.s_[:] 513 else: 514 z_min, z_max = state.z_range 515 bb = np.s_[z_min:(z_max+1)] 516 517 # Cast the dtype of the segmentation we work with correctly. 518 # Otherwise we run into type conversion errors later. 519 dtype = viewer.layers["committed_objects"].data.dtype 520 seg = viewer.layers[layer].data[bb].astype(dtype) 521 shape = seg.shape 522 523 # We parallelize these operations because they take quite long for large volumes. 524 525 # Compute the max id in the commited objects. 
526 # id_offset = int(viewer.layers["committed_objects"].data.max()) 527 full_shape = viewer.layers["committed_objects"].data.shape 528 id_offset = int( 529 elf.parallel.max(viewer.layers["committed_objects"].data, block_shape=util.get_block_shape(full_shape)) 530 ) 531 532 # Compute the mask for the current object. 533 # mask = seg != 0 534 mask = np.zeros(seg.shape, dtype="bool") 535 mask = elf.parallel.apply_operation( 536 seg, 0, np.not_equal, out=mask, block_shape=util.get_block_shape(shape) 537 ) 538 if preserve_mode != "none": 539 prev_seg = viewer.layers["committed_objects"].data[bb] 540 # The mode 'pixels' corresponds to a naive implementation where only committed pixels are preserved. 541 preserve_mask = prev_seg != 0 542 # If the preserve mask is empty we don't need to do anything else here, because we don't have prev objects. 543 if preserve_mask.sum() != 0: 544 # In the mode 'objects' we preserve committed objects instead, by comparing the overlaps 545 # of already committed and newly committed objects. 546 if preserve_mode == "objects": 547 preserve_mask = _mask_matched_objects(seg, prev_seg, preservation_threshold) 548 mask[preserve_mask] = 0 549 550 # Write the current object to committed objects. 
551 seg[mask] += id_offset 552 viewer.layers["committed_objects"].data[bb][mask] = seg[mask] 553 viewer.layers["committed_objects"].refresh() 554 555 return id_offset, seg, mask, bb 556 557 558def _get_auto_segmentation_options(state, object_ids): 559 widget = state.widgets["autosegment"] 560 561 segmentation_options = {"object_ids": [int(object_id) for object_id in object_ids]} 562 if widget.with_decoder: 563 segmentation_options["boundary_distance_thresh"] = widget.boundary_distance_thresh 564 segmentation_options["center_distance_thresh"] = widget.center_distance_thresh 565 else: 566 segmentation_options["pred_iou_thresh"] = widget.pred_iou_thresh 567 segmentation_options["stability_score_thresh"] = widget.stability_score_thresh 568 segmentation_options["box_nms_thresh"] = widget.box_nms_thresh 569 570 segmentation_options["min_object_size"] = widget.min_object_size 571 if widget.volumetric: 572 segmentation_options["apply_to_volume"] = widget.apply_to_volume 573 segmentation_options["gap_closing"] = widget.gap_closing 574 segmentation_options["min_extent"] = widget.min_extent 575 576 return segmentation_options 577 578 579def _get_promptable_segmentation_options(state, object_ids): 580 segmentation_options = {"object_ids": [int(object_id) for object_id in object_ids]} 581 is_tracking = False 582 if "segment_nd" in state.widgets: 583 widget = state.widgets["segment_nd"] 584 segmentation_options["projection"] = widget.projection 585 segmentation_options["iou_threshold"] = widget.iou_threshold 586 segmentation_options["box_extension"] = widget.box_extension 587 if widget.tracking: 588 segmentation_options["motion_smoothing"] = widget.motion_smoothing 589 is_tracking = True 590 return segmentation_options, is_tracking 591 592 593def _commit_to_file(path, viewer, layer, seg, mask, bb, extra_attrs=None): 594 595 if z5py is None: 596 raise RuntimeError( 597 "Committing annotations to file requires z5py, which is only available via conda. 
" 598 "Install it with 'conda install -c conda-forge z5py'." 599 ) 600 601 # NOTE: zarr-python is quite inefficient and writes empty blocks. 602 # So we have to use z5py here. 603 604 # Deal with issues z5py has with empty folders and require the json. 605 if os.path.exists(path): 606 required_json = os.path.join(path, ".zgroup") 607 if not os.path.exists(required_json): 608 with open(required_json, "w") as f: 609 json.dump({"zarr_format": 2}, f) 610 611 f = z5py.ZarrFile(path, "a") 612 state = AnnotatorState() 613 614 def _save_signature(f, data_signature): 615 embeds = state.widgets["embeddings"] 616 tile_shape, halo = _process_tiling_inputs(embeds.tile_x, embeds.tile_y, embeds.halo_x, embeds.halo_y) 617 signature = util._get_embedding_signature( 618 input_=None, # We don't need this because we pass the data signature. 619 predictor=state.predictor, 620 tile_shape=tile_shape, 621 halo=halo, 622 data_signature=data_signature, 623 ) 624 for key, val in signature.items(): 625 f.attrs[key] = val 626 627 # If the data signature is saved in the file already, 628 # then we check if saved data signature and data signature of our image agree. 629 # If not, this file was used for committing objects from another file. 630 if "data_signature" in f.attrs: 631 saved_signature = f.attrs["data_signature"] 632 current_signature = state.data_signature 633 if saved_signature != current_signature: # Signatures disagree. 634 msg = f"The commit_path {path} was already used for saving annotations for different image data:\n" 635 msg += f"The data signatures are different: {saved_signature} != {current_signature}.\n" 636 msg += "Press 'Ok' to remove the data already stored in that file and continue annotation.\n" 637 msg += "Otherwise please select a different file path." 
638 skip_clear = _generate_message("info", msg) 639 if skip_clear: 640 return 641 else: 642 f = z5py.ZarrFile(path, "w") 643 _save_signature(f, current_signature) 644 # Otherwise (data signature not saved yet), write the current signature. 645 else: 646 _save_signature(f, state.data_signature) 647 648 # Write the segmentation. 649 full_shape = viewer.layers["committed_objects"].data.shape 650 block_shape = util.get_block_shape(full_shape) 651 ds = f.require_dataset( 652 "committed_objects", shape=full_shape, chunks=block_shape, compression="gzip", dtype=seg.dtype 653 ) 654 ds.n_threads = mp.cpu_count() 655 data = ds[bb] 656 data[mask] = seg[mask] 657 ds[bb] = data 658 659 # Write additional information to attrs. 660 if extra_attrs is not None: 661 f.attrs.update(extra_attrs) 662 663 # Get the commit history and the objects that are being commited. 664 commit_history = f.attrs.get("commit_history", []) 665 object_ids = np.unique(seg[mask]) 666 667 # We committed an automatic segmentation. 668 if layer == "auto_segmentation": 669 # Save the settings of the segmentation widget. 670 segmentation_options = _get_auto_segmentation_options(state, object_ids) 671 commit_history.append({"auto_segmentation": segmentation_options}) 672 673 # Write the commit history. 674 f.attrs["commit_history"] = commit_history 675 676 # If we run commit from the automatic segmentation we don't have 677 # any prompts and so don't need to commit anything else. 
678 return 679 680 segmentation_options, is_tracking = _get_promptable_segmentation_options(state, object_ids) 681 commit_history.append({"current_object": segmentation_options}) 682 683 def write_prompts(object_id, prompts, point_prompts, point_labels, track_state=None): 684 g = f.create_group(f"prompts/{object_id}") 685 if prompts is not None and len(prompts) > 0: 686 data = np.array(prompts) 687 g.create_dataset("prompts", data=data, shape=data.shape, chunks=data.shape) 688 if point_prompts is not None and len(point_prompts) > 0: 689 g.create_dataset("point_prompts", data=point_prompts, shape=data.shape, chunks=point_prompts.shape) 690 ds = g.create_dataset("point_labels", data=point_labels, shape=data.shape, chunks=point_labels.shape) 691 if track_state is not None: 692 ds.attrs["track_state"] = track_state.tolist() 693 694 # Get the prompts from the layers. 695 prompts = viewer.layers["prompts"].data 696 point_layer = viewer.layers["point_prompts"] 697 point_prompts = point_layer.data 698 point_labels = point_layer.properties["label"] 699 if len(point_prompts) > 0: 700 point_labels = np.array([1 if label == "positive" else 0 for label in point_labels]) 701 assert len(point_prompts) == len(point_labels), \ 702 f"Number of point prompts and labels disagree: {len(point_prompts)} != {len(point_labels)}" 703 704 # Commit the prompts for all the objects in the commit. 705 if len(object_ids) == 1: # We only have a single object. 706 write_prompts(object_ids[0], prompts, point_prompts, point_labels) 707 708 elif is_tracking: # We have multiple objects from tracking a lineage with divisions. 
709 track_ids_points = np.array(point_layer.properties["track_id"]) 710 track_ids_prompts = np.array(viewer.layers["prompts"].properties["track_id"]) 711 712 unique_track_ids = np.unique(track_ids_points) 713 assert len(unique_track_ids) == len(object_ids) 714 track_state = np.array(point_layer.properties["state"]) 715 for track_id, object_id in zip(unique_track_ids, object_ids): 716 this_prompts = None if len(prompts) == 0 else prompts[track_ids_prompts == track_id] 717 point_mask = track_ids_points == track_id 718 this_points, this_labels, this_track_state = \ 719 point_prompts[point_mask], point_labels[point_mask], track_state[point_mask] 720 write_prompts(object_id, this_prompts, this_points, this_labels, track_state=this_track_state) 721 722 else: # We have multiple objects, which are the result from batched interactive segmentation. 723 # Note: we can't match exact object ids to their prompts, for batched segmentation. 724 # We first write the objects from box prompts, then from point prompts. 725 n_prompts, n_points = len(prompts), len(point_prompts) 726 assert n_prompts + n_points == len(object_ids), \ 727 f"Number of prompts and objects disagree: {n_prompts} + {n_points} != {len(object_ids)}" 728 for i, object_id in enumerate(object_ids): 729 if i < n_prompts: 730 this_prompts, this_points, this_labels = prompts[i:i+1], None, None 731 else: 732 j = i - n_prompts 733 this_prompts, this_points, this_labels = None, point_prompts[j:j+1], point_labels[j:j+1] 734 write_prompts(object_id, this_prompts, this_points, this_labels) 735 736 # Write the commit history. 
737 f.attrs["commit_history"] = commit_history 738 739 740@magic_factory( 741 call_button="Commit [C]", 742 layer={"choices": ["current_object", "auto_segmentation"], "tooltip": get_tooltip("commit", "layer")}, 743 preserve_mode={"choices": ["objects", "pixels", "none"], "tooltip": get_tooltip("commit", "preserve_mode")}, 744 commit_path={"mode": "d", "tooltip": get_tooltip("commit", "commit_path")}, 745) 746def commit( 747 viewer: "napari.viewer.Viewer", 748 layer: str = "current_object", 749 preserve_mode: str = "objects", 750 preservation_threshold: float = 0.75, 751 commit_path: Optional[Path] = None, 752) -> None: 753 """Widget for committing the segmented objects from automatic or interactive segmentation. 754 755 Args: 756 viewer: The napari viewer. 757 layer: Select the layer to commit. Can be either 'current_object' to commit interacitve segmentation results. 758 Or 'auto_segmentation' to commit automatic segmentation results. 759 preserve_mode: The mode for preserving already committed objects, in order to prevent over-writing 760 them by a new commit. Supports the modes 'objects', which preserves on the object level and is the default, 761 'pixels', which preserves on the pixel-level, or 'none', which does not preserve commited objects. 762 preservation_threshold: The overlap threshold for preserving objects. This is only used if 763 preservation_mode is set to 'objects'. 764 commit_path: Select a file path where the committed results and prompts will be saved. 765 This feature is still experimental. 766 """ 767 # Commit the segmentation layer. 
768 _, seg, mask, bb = _commit_impl(viewer, layer, preserve_mode, preservation_threshold) 769 770 if commit_path is not None: 771 _commit_to_file(commit_path, viewer, layer, seg, mask, bb) 772 773 if layer == "current_object": 774 vutil.clear_annotations(viewer) 775 else: 776 viewer.layers["auto_segmentation"].data = np.zeros( 777 viewer.layers["auto_segmentation"].data.shape, dtype="uint32" 778 ) 779 viewer.layers["auto_segmentation"].refresh() 780 _select_layer(viewer, "committed_objects") 781 782 # Perform garbage collection 783 gc.collect() 784 785 786@magic_factory( 787 call_button="Commit [C]", 788 layer={"choices": ["current_object", "auto_segmentation"]}, 789 preserve_mode={"choices": ["objects", "pixels", "none"]}, 790 commit_path={"mode": "d"}, # choose a directory 791) 792def commit_track( 793 viewer: "napari.viewer.Viewer", 794 layer: str = "current_object", 795 preserve_mode: str = "objects", 796 preservation_threshold: float = 0.75, 797 commit_path: Optional[Path] = None, 798) -> None: 799 """Widget for committing the objects from interactive tracking. 800 801 Args: 802 viewer: The napari viewer. 803 layer: Select the layer to commit. Can be either 'current_object' to commit interacitve segmentation results. 804 Or 'auto_segmentation' to commit automatic segmentation results. 805 preserve_mode: The mode for preserving already committed objects, in order to prevent over-writing 806 them by a new commit. Supports the modes 'objects', which preserves on the object level and is the default, 807 'pixels', which preserves on the pixel-level, or 'none', which does not preserve commited objects. 808 preservation_threshold: The overlap threshold for preserving objects. This is only used if 809 preservation_mode is set to 'objects'. 810 commit_path: Select a file path where the committed results and prompts will be saved. 811 This feature is still experimental. 812 """ 813 # Commit the segmentation layer. 
814 id_offset, seg, mask, bb = _commit_impl(viewer, layer, preserve_mode, preservation_threshold) 815 816 # Update the lineages. 817 state = AnnotatorState() 818 lineage = state.lineage 819 820 if isinstance(lineage, list): # This is a list of lineages from auto-tracking. 821 assert id_offset == 0 822 assert len(state.committed_lineages) == 0 823 state.committed_lineages.extend(lineage) 824 else: # This is a single lineage from interactive tracking. 825 updated_lineage = { 826 parent + id_offset: [child + id_offset for child in children] for parent, children in state.lineage.items() 827 } 828 state.committed_lineages.append(updated_lineage) 829 830 if commit_path is not None: 831 _commit_to_file( 832 commit_path, viewer, layer, seg, mask, bb, 833 extra_attrs={"committed_lineages": state.committed_lineages} 834 ) 835 836 if layer == "current_object": 837 vutil.clear_annotations(viewer) 838 839 # Create / update the tracking layer. 840 layer_name = "tracks" 841 segmentation = viewer.layers["committed_objects"].data 842 track_data, parent_graph = get_napari_track_data(segmentation, state.committed_lineages) 843 if layer_name in viewer.layers: 844 layer = viewer.layers[layer_name] 845 layer.data = track_data 846 layer.graph = parent_graph 847 else: 848 viewer.add_tracks(track_data, name=layer_name, graph=parent_graph) 849 850 # Reset the tracking state. 851 _reset_tracking_state(viewer) 852 853 # Perform garbage collection. 
854 gc.collect() 855 856 857def create_prompt_menu(points_layer, labels, menu_name="prompt", label_name="label"): 858 """Create the menu for toggling point prompt labels.""" 859 label_menu = ComboBox(label=menu_name, choices=labels, tooltip=get_tooltip("prompt_menu", "labels")) 860 label_widget = Container(widgets=[label_menu]) 861 862 def update_label_menu(event): 863 new_label = str(points_layer.current_properties[label_name][0]) 864 if new_label != label_menu.value: 865 label_menu.value = new_label 866 867 points_layer.events.current_properties.connect(update_label_menu) 868 869 def label_changed(new_label): 870 current_properties = points_layer.current_properties 871 current_properties[label_name] = np.array([new_label]) 872 points_layer.current_properties = current_properties 873 points_layer.refresh_colors() 874 875 label_menu.changed.connect(label_changed) 876 877 return label_widget 878 879 880@magic_factory( 881 call_button="Update settings", 882 cache_directory={"mode": "d"}, # choose a directory 883) 884def settings_widget(cache_directory: Optional[Path] = util.get_cache_directory()) -> None: 885 """Widget to update global micro_sam settings. 886 887 Args: 888 cache_directory: Select the path for the micro_sam cache directory. `$HOME/.cache/micro_sam`. 889 """ 890 os.environ["MICROSAM_CACHEDIR"] = str(cache_directory) 891 print(f"micro-sam cache directory set to: {cache_directory}") 892 893 894def _generate_message(message_type: str, message: str) -> bool: 895 """ 896 Displays a message dialog based on the provided message type. 897 898 Args: 899 message_type: The type of message to display. Valid options are: 900 - "error": Displays a critical error message with an "Ok" button. 901 - "info": Displays an informational message in a separate dialog box. 902 The user can dismiss it by either clicking "Ok" or closing the dialog. 903 message: The message content to be displayed in the dialog. 
904 905 Returns: 906 A flag indicating whether the user aborted the operation based on the 907 message type. This flag is only set for "info" messages where the user 908 can choose to cancel (rejected). 909 910 Raises: 911 ValueError: If an invalid message type is provided. 912 """ 913 # Set button text and behavior based on message type 914 if message_type == "error": 915 QtWidgets.QMessageBox.critical(None, "Error", message, QtWidgets.QMessageBox.Ok) 916 abort = True 917 return abort 918 elif message_type == "info": 919 info_dialog = InfoDialog(title="Validation Message", message=message) 920 result = info_dialog.exec_() 921 if result == QtWidgets.QDialog.Rejected: # Check for cancel 922 abort = True # Set flag directly in calling function 923 return abort 924 else: 925 raise ValueError(f"Invalid message type {message_type}") 926 927 928def _validate_embeddings(viewer: "napari.viewer.Viewer"): 929 state = AnnotatorState() 930 if state.image_embeddings is None: 931 msg = "Image embeddings are not yet computed. Press 'Compute Embeddings' to compute them for your image." 932 return _generate_message("error", msg) 933 else: 934 return False 935 936 # This code is for checking the data signature of the current image layer and the data signature 937 # of the embeddings. However, the code has some disadvantages, for example assuming the position of the 938 # image layer and also having to compute the data signature every time. 939 # That's why we are not using this for now, but may want to revisit this in the future. 
See: 940 # https://github.com/computational-cell-analytics/micro-sam/issues/504 941 942 # embeddings_save_path = state.embedding_path 943 # embedding_data_signature = None 944 # image = None 945 # if isinstance(viewer.layers[0], napari.layers.Image): # Assuming the image layer is at index 0 946 # image = viewer.layers[0] 947 # else: 948 # # Handle the case where the first layer isn't an Image layer 949 # raise ValueError("Expected an Image layer in viewer.layers") 950 # img_signature = util._compute_data_signature(image.data) 951 # if embeddings_save_path is not None: 952 # # Check for existing embeddings 953 # if os.listdir(embeddings_save_path): 954 # try: 955 # with zarr.open(embeddings_save_path, "a") as f: 956 # # If data_signature exists, compare and return validation message 957 # if "data_signature" in f.attrs: 958 # embedding_data_signature = f.attrs["data_signature"] 959 # except RuntimeError as e: 960 # val_results = { 961 # "message_type": "error", 962 # "message": f"Failed to load image embeddings: {e}" 963 # } 964 # else: 965 # val_results = {"message_type": "info", "message": "No existing embeddings found at the specified path."} 966 # else: # load from state object 967 # embedding_data_signature = state.data_signature 968 # # compare image data signature with embedding data signature 969 # if img_signature != embedding_data_signature: 970 # val_results = { 971 # "message_type": "error", 972 # "message": f"The embeddings don't match with the image: {img_signature} {embedding_data_signature}" 973 # } 974 # else: 975 # val_results = None 976 # if val_results: 977 # return _generate_message(val_results["message_type"], val_results["message"]) 978 # else: 979 # return False 980 981 982def _validation_window_for_missing_layer(layer_choice): 983 if layer_choice == "committed_objects": 984 msg = "The 'committed_objects' layer to commit masks is missing. Please try to commit again." 985 else: 986 msg = f"The '{layer_choice}' layer to commit is missing. 
Please re-annotate and try again." 987 988 return _generate_message(message_type="error", message=msg) 989 990 991def _validate_layers(viewer: "napari.viewer.Viewer", automatic_segmentation: bool = False) -> bool: 992 # Check whether all layers exist as expected or create new ones automatically. 993 state = AnnotatorState() 994 state.annotator._require_layers() 995 996 if not automatic_segmentation: 997 # Check prompts layer. 998 if len(viewer.layers["prompts"].data) == 0 and len(viewer.layers["point_prompts"].data) == 0: 999 msg = "No prompts were given. Please provide prompts to run interactive segmentation." 1000 return _generate_message("error", msg) 1001 else: 1002 return False 1003 1004 1005@magic_factory(call_button="Segment Object [S]") 1006def segment(viewer: "napari.viewer.Viewer", batched: bool = False) -> None: 1007 """Segment object(s) for the current prompts. 1008 1009 Args: 1010 viewer: The napari viewer. 1011 batched: Choose if you want to segment multiple objects with point prompts. 1012 """ 1013 if _validate_embeddings(viewer): 1014 return None 1015 if _validate_layers(viewer): 1016 return None 1017 1018 shape = viewer.layers["current_object"].data.shape 1019 1020 # get the current box and point prompts 1021 boxes, masks = vutil.shape_layer_to_prompts(viewer.layers["prompts"], shape) 1022 points, labels = vutil.point_layer_to_prompts(viewer.layers["point_prompts"], with_stop_annotation=False) 1023 1024 predictor = AnnotatorState().predictor 1025 image_embeddings = AnnotatorState().image_embeddings 1026 seg = vutil.prompt_segmentation( 1027 predictor, points, labels, boxes, masks, shape, image_embeddings=image_embeddings, 1028 multiple_box_prompts=True, batched=batched, previous_segmentation=viewer.layers["current_object"].data, 1029 ) 1030 1031 # no prompts were given or prompts were invalid, skip segmentation 1032 if seg is None: 1033 print("You either haven't provided any prompts or invalid prompts. 
The segmentation will be skipped.") 1034 return 1035 1036 viewer.layers["current_object"].data = seg 1037 viewer.layers["current_object"].refresh() 1038 1039 1040@magic_factory(call_button="Segment Slice [S]") 1041def segment_slice(viewer: "napari.viewer.Viewer") -> None: 1042 """Segment object for to the current prompts. 1043 1044 Args: 1045 viewer: The napari viewer. 1046 """ 1047 if _validate_embeddings(viewer): 1048 return None 1049 if _validate_layers(viewer): 1050 return None 1051 1052 shape = viewer.layers["current_object"].data.shape[1:] 1053 1054 position_world = viewer.dims.point 1055 position = viewer.layers["point_prompts"].world_to_data(position_world) 1056 z = int(position[0]) 1057 1058 point_prompts = vutil.point_layer_to_prompts(viewer.layers["point_prompts"], z) 1059 # this is a stop prompt, we do nothing 1060 if not point_prompts: 1061 return 1062 1063 boxes, masks = vutil.shape_layer_to_prompts(viewer.layers["prompts"], shape, i=z) 1064 points, labels = point_prompts 1065 1066 state = AnnotatorState() 1067 seg = vutil.prompt_segmentation( 1068 state.predictor, points, labels, boxes, masks, shape, multiple_box_prompts=False, 1069 image_embeddings=state.image_embeddings, i=z, 1070 ) 1071 1072 # no prompts were given or prompts were invalid, skip segmentation 1073 if seg is None: 1074 print("You either haven't provided any prompts or invalid prompts. The segmentation will be skipped.") 1075 return 1076 1077 viewer.layers["current_object"].data[z] = seg 1078 viewer.layers["current_object"].refresh() 1079 1080 1081@magic_factory(call_button="Segment Frame [S]") 1082def segment_frame(viewer: "napari.viewer.Viewer") -> None: 1083 """Segment object for the current prompts. 1084 1085 Args: 1086 viewer: The napari viewer. 
1087 """ 1088 if _validate_embeddings(viewer): 1089 return None 1090 if _validate_layers(viewer): 1091 return None 1092 1093 state = AnnotatorState() 1094 shape = state.image_shape[1:] 1095 position = viewer.dims.point 1096 t = int(position[0]) 1097 1098 point_prompts = vutil.point_layer_to_prompts(viewer.layers["point_prompts"], i=t, track_id=state.current_track_id) 1099 # this is a stop prompt, we do nothing 1100 if not point_prompts: 1101 return 1102 1103 boxes, masks = vutil.shape_layer_to_prompts(viewer.layers["prompts"], shape, i=t, track_id=state.current_track_id) 1104 points, labels = point_prompts 1105 1106 seg = vutil.prompt_segmentation( 1107 state.predictor, points, labels, boxes, masks, shape, multiple_box_prompts=False, 1108 image_embeddings=state.image_embeddings, i=t 1109 ) 1110 1111 # no prompts were given or prompts were invalid, skip segmentation 1112 if seg is None: 1113 print("You either haven't provided any prompts or invalid prompts. The segmentation will be skipped.") 1114 return 1115 1116 # clear the old segmentation for this track_id 1117 old_mask = viewer.layers["current_object"].data[t] == state.current_track_id 1118 viewer.layers["current_object"].data[t][old_mask] = 0 1119 # set the new segmentation 1120 new_mask = seg.squeeze() == 1 1121 viewer.layers["current_object"].data[t][new_mask] = state.current_track_id 1122 viewer.layers["current_object"].refresh() 1123 1124 1125# 1126# Functionality and widget to compute the image embeddings. 
1127# 1128 1129 1130def _process_tiling_inputs(tile_shape_x, tile_shape_y, halo_x, halo_y): 1131 tile_shape = (tile_shape_x, tile_shape_y) 1132 halo = (halo_x, halo_y) 1133 # check if tile_shape/halo are not set: (0, 0) 1134 if all(item in (0, None) for item in tile_shape): 1135 tile_shape = None 1136 # check if at least 1 param is given 1137 elif tile_shape[0] == 0 or tile_shape[1] == 0: 1138 max_val = max(tile_shape[0], tile_shape[1]) 1139 if max_val < 256: # at least tile shape >256 1140 max_val = 256 1141 tile_shape = (max_val, max_val) 1142 # if both inputs given, check if smaller than 256 1143 elif tile_shape[0] != 0 and tile_shape[1] != 0: 1144 if tile_shape[0] < 256: 1145 tile_shape = (256, tile_shape[1]) # Create a new tuple 1146 if tile_shape[1] < 256: 1147 tile_shape = (tile_shape[0], 256) # Create a new tuple with modified value 1148 if all(item in (0, None) for item in halo): 1149 if tile_shape is not None: 1150 halo = (0, 0) 1151 else: 1152 halo = None 1153 # check if at least 1 param is given 1154 elif halo[0] != 0 or halo[1] != 0: 1155 max_val = max(halo[0], halo[1]) 1156 # don't apply halo if there is no tiling 1157 if tile_shape is None: 1158 halo = None 1159 else: 1160 halo = (max_val, max_val) 1161 return tile_shape, halo 1162 1163 1164class EmbeddingWidget(_WidgetBase): 1165 def __init__(self, parent=None): 1166 super().__init__(parent=parent) 1167 1168 # Create a nested layout for the sections. 1169 # Section 1: Image and Model. 1170 section1_layout = QtWidgets.QHBoxLayout() 1171 section1_layout.addLayout(self._create_image_section()) 1172 section1_layout.addLayout(self._create_model_section()) # Creates the model family widget section. 1173 self.layout().addLayout(section1_layout) 1174 1175 # Section 2: Settings (collapsible). 1176 self.layout().addWidget(self._create_settings_widget()) 1177 1178 # Section 3: The button to trigger the embedding computation. 
1179 self.run_button = QtWidgets.QPushButton("Compute Embeddings") 1180 self.run_button.clicked.connect(self._initialize_image) 1181 self.run_button.clicked.connect(self.__call__) 1182 self.run_button.setToolTip(get_tooltip("embedding", "run_button")) 1183 self.layout().addWidget(self.run_button) 1184 1185 def _initialize_image(self): 1186 state = AnnotatorState() 1187 layer = self.image_selection.get_value() 1188 1189 # This is encountered when there is no image layer available / selected. 1190 # In this case, we need not specify other image-level parameters to the state. Hence, we skip them. 1191 # NOTE: On code-level, this happens as the first step when "Compute Embedding" click is triggered. 1192 if layer is None: 1193 return 1194 1195 image_shape = layer.data.shape 1196 image_scale = tuple(layer.scale) 1197 state.image_shape = image_shape 1198 state.image_scale = image_scale 1199 state.image_name = layer.name 1200 1201 def _create_image_section(self): 1202 image_section = QtWidgets.QVBoxLayout() 1203 image_layer_widget = QtWidgets.QLabel("Image Layer:") 1204 # image_layer_widget.setToolTip(get_tooltip("embedding", "image")) # this adds tooltip to label 1205 image_section.addWidget(image_layer_widget) 1206 1207 # Setting a napari layer in QT, see: 1208 # https://github.com/pyapp-kit/magicgui/blob/main/docs/examples/napari/napari_combine_qt.py 1209 self.image_selection = create_widget(annotation=napari.layers.Image) 1210 self.image_selection.native.setToolTip(get_tooltip("embedding", "image")) 1211 image_section.addWidget(self.image_selection.native) 1212 1213 return image_section 1214 1215 def _update_model(self, state): 1216 _model_type = state.predictor.model_type if self.custom_weights else self.model_type 1217 1218 # Provide a detailed message for the model family and model size per chosen combination. 1219 msg = "Computed embeddings for " 1220 if self.custom_weights: # Whether the user provided a filepath to custom finetuned model weights. 
1221 msg += f"the model located at '{os.path.abspath(self.custom_weights)}' " 1222 msg += f"of size '{self._model_size_map[_model_type[4]]}'." 1223 else: 1224 msg += f"the '{self.model_family}' model of size '{self.model_size}'." 1225 1226 show_info(msg) 1227 1228 state = AnnotatorState() 1229 # Update the widget itself. This is necessary because we may have loaded 1230 # some settings from the embedding file and have to reflect them in the widget. 1231 vutil._sync_embedding_widget( 1232 self, 1233 model_type=_model_type, 1234 save_path=self.embeddings_save_path, 1235 checkpoint_path=self.custom_weights, 1236 device=self.device, 1237 tile_shape=[self.tile_x, self.tile_y], 1238 halo=[self.halo_x, self.halo_y] 1239 ) 1240 1241 # Set the default settings for this model in the autosegment widget if it is part of 1242 # the currently used plugin. 1243 if "autosegment" in state.widgets: 1244 with_decoder = state.decoder is not None 1245 vutil._sync_autosegment_widget( 1246 state.widgets["autosegment"], _model_type, self.custom_weights, update_decoder=with_decoder 1247 ) 1248 # Load the AMG/AIS state if we have a 3d segmentation plugin. 1249 if state.widgets["autosegment"].volumetric and with_decoder: 1250 state.amg_state = vutil._load_is_state(state.embedding_path) 1251 elif state.widgets["autosegment"].volumetric and not with_decoder: 1252 state.amg_state = vutil._load_amg_state(state.embedding_path) 1253 1254 # Set the default settings for this model in the nd-segmentation widget if it is part of 1255 # the currently used plugin. 1256 if "segment_nd" in state.widgets: 1257 vutil._sync_ndsegment_widget(state.widgets["segment_nd"], _model_type, self.custom_weights) 1258 1259 def _create_settings_widget(self): 1260 setting_values = QtWidgets.QWidget() 1261 setting_values.setToolTip(get_tooltip("embedding", "settings")) 1262 setting_values.setLayout(QtWidgets.QVBoxLayout()) 1263 1264 # Add the model size widget section. 
1265 layout = self._create_model_size_section() 1266 setting_values.layout().addLayout(layout) 1267 1268 # Create UI for the device. 1269 self.device = "auto" 1270 device_options = ["auto"] + util._available_devices() 1271 1272 self.device_dropdown, layout = self._add_choice_param( 1273 "device", self.device, device_options, tooltip=get_tooltip("embedding", "device") 1274 ) 1275 setting_values.layout().addLayout(layout) 1276 1277 # Create UI for the save path. 1278 self.embeddings_save_path = None 1279 self.embeddings_save_path_param, layout = self._add_path_param( 1280 "embeddings_save_path", self.embeddings_save_path, "directory", title="embeddings save path:", 1281 tooltip=get_tooltip("embedding", "embeddings_save_path") 1282 ) 1283 setting_values.layout().addLayout(layout) 1284 1285 # Create UI for the custom weights. 1286 self.custom_weights = None 1287 self.custom_weights_param, layout = self._add_path_param( 1288 "custom_weights", self.custom_weights, "file", title="custom weights path:", 1289 tooltip=get_tooltip("embedding", "custom_weights") 1290 ) 1291 setting_values.layout().addLayout(layout) 1292 1293 # Create UI for the tile shape. 1294 self.tile_x, self.tile_y = 0, 0 1295 self.tile_x_param, self.tile_y_param, layout = self._add_shape_param( 1296 ("tile_x", "tile_y"), (self.tile_x, self.tile_y), min_val=0, max_val=2048, step=16, 1297 tooltip=get_tooltip("embedding", "tiling") 1298 ) 1299 setting_values.layout().addLayout(layout) 1300 1301 # Create UI for the halo. 1302 self.halo_x, self.halo_y = 0, 0 1303 self.halo_x_param, self.halo_y_param, layout = self._add_shape_param( 1304 ("halo_x", "halo_y"), (self.halo_x, self.halo_y), min_val=0, max_val=512, 1305 tooltip=get_tooltip("embedding", "halo") 1306 ) 1307 setting_values.layout().addLayout(layout) 1308 1309 # Create UI for the choice of automatic segmentation mode. 
1310 self.automatic_segmentation_mode = "auto" 1311 auto_seg_options = ["auto", "amg", "ais"] 1312 self.automatic_segmentation_mode_dropdown, layout = self._add_choice_param( 1313 "automatic_segmentation_mode", self.automatic_segmentation_mode, auto_seg_options, 1314 title="automatic segmentation mode", tooltip=get_tooltip("embedding", "automatic_segmentation_mode") 1315 ) 1316 setting_values.layout().addLayout(layout) 1317 1318 settings = _make_collapsible(setting_values, title="Embedding Settings") 1319 return settings 1320 1321 def _validate_inputs(self): 1322 """Validates the inputs for the annotation process and returns a dictionary 1323 containing information for message generation, or False if no messages are needed. 1324 1325 This function performs the following checks: 1326 1327 - If an `embeddings_save_path` is provided: 1328 - Validates the image data signature by comparing it with the signature 1329 of the image data in the viewer's selection. 1330 - Checks for existing embeddings at the specified path. 1331 - If existing embeddings are found, it attempts to load parameters 1332 like tile shape, halo, and model type from the Zarr attributes. 1333 - An informational message is generated based on the loaded parameters. 1334 - If loading existing embeddings fails, an error message is generated. 1335 - If no existing embeddings are found, an informational message is generated. 1336 - If no `embeddings_save_path` is provided, the function returns None. 1337 1338 Returns: 1339 bool: True if the computation should be aborted, otherwise False. 1340 """ 1341 1342 # Check if we have an existing input image to compute the embeddings. 1343 image = self.image_selection.get_value() 1344 if image is None: 1345 return _generate_message("error", "No image has been selected.") 1346 1347 # Check if we have an existing embedding path. 
1348 # If yes we check the data signature of these embeddings against the selected image 1349 # and we ask the user if they want to load these embeddings. 1350 if self.embeddings_save_path and os.listdir(self.embeddings_save_path): 1351 try: 1352 f = zarr.open(self.embeddings_save_path, mode="a") 1353 1354 # Validate that the embeddings are complete. 1355 # Note: 'input_size' is the last value set in the attrs of f, 1356 # so we can use it as a proxy to check if the embeddings are fully computed 1357 if "input_size" not in f.attrs: 1358 msg = (f"The embeddings at {self.embeddings_save_path} are incomplete. " 1359 "Specify a different path or remove them.") 1360 return _generate_message("error", msg) 1361 1362 # Validate image data signature. 1363 if "data_signature" in f.attrs: 1364 image = self.image_selection.get_value() 1365 img_signature = util._compute_data_signature(image.data) 1366 if img_signature != f.attrs["data_signature"]: 1367 msg = f"The embeddings don't match with the image: {img_signature} {f.attrs['data_signature']}" 1368 return _generate_message("error", msg) 1369 1370 # Load existing parameters. 1371 self.model_type = f.attrs.get("model_name", f.attrs["model_type"]) 1372 if "tile_shape" in f.attrs and f.attrs["tile_shape"] is not None: 1373 self.tile_x, self.tile_y = f.attrs["tile_shape"] 1374 self.halo_x, self.halo_y = f.attrs["halo"] 1375 val_results = { 1376 "message_type": "info", 1377 "message": (f"Load embeddings for model: {self.model_type} with tile shape: " 1378 f"{self.tile_x}, {self.tile_y} and halo: {self.halo_x}, {self.halo_y}.") 1379 } 1380 else: 1381 self.tile_x, self.tile_y = 0, 0 1382 self.halo_x, self.halo_y = 0, 0 1383 val_results = { 1384 "message_type": "info", 1385 "message": f"Load embeddings for model: {self.model_type}." 
1386 } 1387 1388 return _generate_message(val_results["message_type"], val_results["message"]) 1389 1390 except RuntimeError as e: 1391 val_results = { 1392 "message_type": "error", 1393 "message": f"Failed to load image embeddings: {e}" 1394 } 1395 return _generate_message(val_results["message_type"], val_results["message"]) 1396 1397 # Otherwise we either don't have an embedding path or it is empty. We can proceed in both cases. 1398 return False 1399 1400 def _validate_existing_embeddings(self, state): 1401 if state.image_embeddings is None: 1402 return False 1403 else: 1404 val_results = { 1405 "message_type": "info", 1406 "message": "Embeddings have already been precomputed. Press OK to recompute the embeddings." 1407 } 1408 return _generate_message(val_results["message_type"], val_results["message"]) 1409 1410 def __call__(self, skip_validate=False): 1411 self._validate_model_type_and_custom_weights() 1412 1413 # Validate user inputs. 1414 if not skip_validate and self._validate_inputs(): 1415 return 1416 1417 # Get the image. 1418 image = self.image_selection.get_value() 1419 1420 # Update the image embeddings: 1421 state = AnnotatorState() 1422 if self._validate_existing_embeddings(state): 1423 # Whether embeddings already exist to control existing objects in layers. 1424 state.skip_recomputing_embeddings = True 1425 return 1426 1427 state.skip_recomputing_embeddings = False 1428 # Reset the state. 1429 state.reset_state() 1430 1431 # Get image dimensions. 1432 if image.rgb: 1433 ndim = image.data.ndim - 1 1434 state.image_shape = image.data.shape[:-1] 1435 else: 1436 ndim = image.data.ndim 1437 state.image_shape = image.data.shape 1438 1439 # Set layer scale 1440 state.image_scale = tuple(image.scale) 1441 1442 # Process tile_shape and halo, set other data. 
1443 tile_shape, halo = _process_tiling_inputs(self.tile_x, self.tile_y, self.halo_x, self.halo_y) 1444 save_path = None if self.embeddings_save_path == "" else self.embeddings_save_path 1445 image_data = image.data 1446 1447 # Set up progress bar and signals for using it within a threadworker. 1448 pbar, pbar_signals = _create_pbar_for_threadworker() 1449 1450 # @thread_worker() 1451 def compute_image_embedding(): 1452 1453 def pbar_init(total, description): 1454 pbar_signals.pbar_total.emit(total) 1455 pbar_signals.pbar_description.emit(description) 1456 1457 # Whether to prefer decoder. 1458 # With 'amg', it is set to 'False', else it is 'True' for the default 'auto' and 'ais' mode. 1459 prefer_decoder = True 1460 if self.automatic_segmentation_mode == "amg": 1461 prefer_decoder = False 1462 1463 state.initialize_predictor( 1464 image_data, model_type=self.model_type, save_path=save_path, ndim=ndim, 1465 device=self.device, checkpoint_path=self.custom_weights, tile_shape=tile_shape, halo=halo, 1466 prefer_decoder=prefer_decoder, pbar_init=pbar_init, 1467 pbar_update=lambda update: pbar_signals.pbar_update.emit(update), 1468 ) 1469 pbar_signals.pbar_stop.emit() 1470 1471 compute_image_embedding() 1472 self._update_model(state) 1473 # worker = compute_image_embedding() 1474 # worker.returned.connect(self._update_model) 1475 # worker.start() 1476 # return worker 1477 1478 1479# 1480# Functionality and widget for nd segmentation. 1481# 1482 1483 1484def _update_lineage(viewer): 1485 """Updated the lineage after recording a division event. 1486 This helper function is needed by 'track_object'. 
1487 """ 1488 state = AnnotatorState() 1489 tracking_widget = state.widgets["tracking"] 1490 1491 mother = state.current_track_id 1492 assert mother in state.lineage 1493 assert len(state.lineage[mother]) == 0 1494 1495 daughter1, daughter2 = state.current_track_id + 1, state.current_track_id + 2 1496 state.lineage[mother] = [daughter1, daughter2] 1497 state.lineage[daughter1] = [] 1498 state.lineage[daughter2] = [] 1499 1500 # Update the choices in the track_id menu so that it contains the new track ids. 1501 track_ids = list(map(str, state.lineage.keys())) 1502 tracking_widget[1].choices = track_ids 1503 1504 viewer.layers["point_prompts"].property_choices["track_id"] = [str(track_id) for track_id in track_ids] 1505 viewer.layers["prompts"].property_choices["track_id"] = [str(track_id) for track_id in track_ids] 1506 1507 1508class SegmentNDWidget(_WidgetBase): 1509 def __init__(self, viewer, tracking, parent=None): 1510 super().__init__(parent=parent) 1511 self._viewer = viewer 1512 self.tracking = tracking 1513 1514 # Add the settings. 1515 self.settings = self._create_settings() 1516 self.layout().addWidget(self.settings) 1517 1518 # Add the run button. 1519 button_title = "Segment All Frames [Shift-S]" if self.tracking else "Segment All Slices [Shift-S]" 1520 self.run_button = QtWidgets.QPushButton(button_title) 1521 self.run_button.clicked.connect(self.__call__) 1522 self.layout().addWidget(self.run_button) 1523 1524 def _create_settings(self): 1525 setting_values = QtWidgets.QWidget() 1526 setting_values.setToolTip(get_tooltip("segmentnd", "settings")) 1527 setting_values.setLayout(QtWidgets.QVBoxLayout()) 1528 1529 # Create the UI for the projection modes. 
1530 self.projection = "single_point" 1531 self.projection_dropdown, layout = self._add_choice_param( 1532 "projection", self.projection, PROJECTION_MODES, tooltip=get_tooltip("segmentnd", "projection_dropdown") 1533 ) 1534 setting_values.layout().addLayout(layout) 1535 1536 # Create the UI element for the IOU threshold. 1537 self.iou_threshold = 0.5 1538 self.iou_threshold_param, layout = self._add_float_param( 1539 "iou_threshold", self.iou_threshold, tooltip=get_tooltip("segmentnd", "iou_threshold") 1540 ) 1541 setting_values.layout().addLayout(layout) 1542 1543 # Create the UI element for the box extension. 1544 self.box_extension = 0.05 1545 self.box_extension_param, layout = self._add_float_param( 1546 "box_extension", self.box_extension, tooltip=get_tooltip("segmentnd", "box_extension") 1547 ) 1548 setting_values.layout().addLayout(layout) 1549 1550 # Create the UI element for the motion smoothing (if we have the tracking widget). 1551 if self.tracking: 1552 self.motion_smoothing = 0.5 1553 self.motion_smoothing_param, layout = self._add_float_param( 1554 "motion_smoothing", self.motion_smoothing, tooltip=get_tooltip("segmentnd", "motion_smoothing") 1555 ) 1556 setting_values.layout().addLayout(layout) 1557 1558 settings = _make_collapsible(setting_values, title="Segmentation Settings") 1559 return settings 1560 1561 def _run_tracking(self): 1562 state = AnnotatorState() 1563 pbar, pbar_signals = _create_pbar_for_threadworker() 1564 1565 # @thread_worker 1566 def tracking_impl(): 1567 shape = state.image_shape 1568 1569 pbar_signals.pbar_total.emit(shape[0]) 1570 pbar_signals.pbar_description.emit("Track object") 1571 1572 # Step 1: Segment all slices with prompts. 
1573 seg, slices, _, stop_upper = vutil.segment_slices_with_prompts( 1574 state.predictor, self._viewer.layers["point_prompts"], self._viewer.layers["prompts"], 1575 state.image_embeddings, shape, track_id=state.current_track_id, 1576 update_progress=lambda update: pbar_signals.pbar_update.emit(update), 1577 ) 1578 1579 # Step 2: Track the object starting from the lowest annotated slice. 1580 seg, has_division = vutil.track_from_prompts( 1581 self._viewer.layers["point_prompts"], self._viewer.layers["prompts"], seg, 1582 state.predictor, slices, state.image_embeddings, stop_upper, 1583 threshold=self.iou_threshold, projection=self.projection, 1584 motion_smoothing=self.motion_smoothing, 1585 box_extension=self.box_extension, 1586 update_progress=lambda update: pbar_signals.pbar_update.emit(update), 1587 ) 1588 1589 pbar_signals.pbar_stop.emit() 1590 return seg, has_division 1591 1592 def update_segmentation(ret_val): 1593 seg, has_division = ret_val 1594 # If a division has occurred and it's the first time it occurred for this track 1595 # then we need to create the two daughter tracks and update the lineage. 1596 if has_division and (len(state.lineage[state.current_track_id]) == 0): 1597 _update_lineage(self._viewer) 1598 1599 # Clear the old track mask. 1600 self._viewer.layers["current_object"].data[ 1601 self._viewer.layers["current_object"].data == state.current_track_id 1602 ] = 0 1603 # Set the new object mask. 
1604 self._viewer.layers["current_object"].data[seg == 1] = state.current_track_id 1605 self._viewer.layers["current_object"].refresh() 1606 1607 ret_val = tracking_impl() 1608 update_segmentation(ret_val) 1609 # worker = tracking_impl() 1610 # worker.returned.connect(update_segmentation) 1611 # worker.start() 1612 # return worker 1613 1614 def _run_volumetric_segmentation(self): 1615 pbar, pbar_signals = _create_pbar_for_threadworker() 1616 1617 # @thread_worker 1618 def volumetric_segmentation_impl(): 1619 state = AnnotatorState() 1620 shape = state.image_shape 1621 1622 pbar_signals.pbar_total.emit(shape[0]) 1623 pbar_signals.pbar_description.emit("Segment object") 1624 1625 # Step 1: Segment all slices with prompts. 1626 seg, slices, stop_lower, stop_upper = vutil.segment_slices_with_prompts( 1627 state.predictor, self._viewer.layers["point_prompts"], self._viewer.layers["prompts"], 1628 state.image_embeddings, shape, 1629 update_progress=lambda update: pbar_signals.pbar_update.emit(update), 1630 ) 1631 1632 # Step 2: Segment the rest of the volume based on projecting prompts. 
1633 seg, (z_min, z_max) = segment_mask_in_volume( 1634 seg, state.predictor, state.image_embeddings, slices, 1635 stop_lower, stop_upper, 1636 iou_threshold=self.iou_threshold, projection=self.projection, 1637 box_extension=self.box_extension, 1638 update_progress=lambda update: pbar_signals.pbar_update.emit(update), 1639 ) 1640 pbar_signals.pbar_stop.emit() 1641 1642 state.z_range = (z_min, z_max) 1643 return seg 1644 1645 def update_segmentation(seg): 1646 self._viewer.layers["current_object"].data = seg 1647 self._viewer.layers["current_object"].refresh() 1648 1649 seg = volumetric_segmentation_impl() 1650 self._viewer.layers["current_object"].data = seg 1651 self._viewer.layers["current_object"].refresh() 1652 # worker = volumetric_segmentation_impl() 1653 # worker.returned.connect(update_segmentation) 1654 # worker.start() 1655 # return worker 1656 1657 def __call__(self): 1658 if _validate_embeddings(self._viewer): 1659 return None 1660 if _validate_layers(self._viewer): 1661 return None 1662 1663 if self.tracking: 1664 return self._run_tracking() 1665 else: 1666 return self._run_volumetric_segmentation() 1667 1668 1669# 1670# The functionality and widgets for automatic segmentation. 1671# 1672 1673 1674# Messy amg state handling, would be good to refactor this properly at some point. 1675def _handle_amg_state(state, i, pbar_init, pbar_update): 1676 if state.amg is None: 1677 is_tiled = state.image_embeddings["input_size"] is None 1678 state.amg = instance_segmentation.get_instance_segmentation_generator( 1679 state.predictor, is_tiled=is_tiled, decoder=state.decoder 1680 ) 1681 1682 shape = state.image_shape 1683 1684 # Further optimization: refactor parts of this so that we can also use it in the automatic 3d segmentation fucnction 1685 # For 3D we store the amg state in a dict and check if it is computed already. 
1686 if state.amg_state is not None: 1687 assert i is not None 1688 if i in state.amg_state: 1689 amg_state_i = state.amg_state[i] 1690 state.amg.set_state(amg_state_i) 1691 1692 else: 1693 dummy_image = np.zeros(shape[-2:], dtype="uint8") 1694 state.amg.initialize( 1695 dummy_image, image_embeddings=state.image_embeddings, i=i, 1696 verbose=pbar_init is not None, pbar_init=pbar_init, pbar_update=pbar_update, 1697 ) 1698 amg_state_i = state.amg.get_state() 1699 state.amg_state[i] = amg_state_i 1700 1701 cache_folder = state.amg_state.get("cache_folder", None) 1702 if cache_folder is not None: 1703 cache_path = os.path.join(cache_folder, f"state-{i}.pkl") 1704 with open(cache_path, "wb") as f: 1705 pickle.dump(amg_state_i, f) 1706 1707 cache_path = state.amg_state.get("cache_path", None) 1708 if cache_path is not None: 1709 save_key = f"state-{i}" 1710 with h5py.File(cache_path, "a") as f: 1711 g = f.create_group(save_key) 1712 g.create_dataset("foreground", data=amg_state_i["foreground"], compression="gzip") 1713 g.create_dataset("boundary_distances", data=amg_state_i["boundary_distances"], compression="gzip") 1714 g.create_dataset("center_distances", data=amg_state_i["center_distances"], compression="gzip") 1715 1716 # Otherwise (2d segmentation) we just check if the amg is initialized or not. 1717 elif not state.amg.is_initialized: 1718 assert i is None 1719 # We don't need to pass the actual image data here, since the embeddings are passed. 1720 # (The image data is only used by the amg to compute image embeddings, so not needed here.) 
1721 dummy_image = np.zeros(shape, dtype="uint8") 1722 state.amg.initialize( 1723 dummy_image, image_embeddings=state.image_embeddings, 1724 verbose=pbar_init is not None, pbar_init=pbar_init, pbar_update=pbar_update 1725 ) 1726 1727 1728def _instance_segmentation_impl(min_object_size, i=None, pbar_init=None, pbar_update=None, **kwargs): 1729 state = AnnotatorState() 1730 _handle_amg_state(state, i, pbar_init, pbar_update) 1731 seg = state.amg.generate(**kwargs) 1732 assert isinstance(seg, np.ndarray) 1733 return seg 1734 1735 1736class AutoSegmentWidget(_WidgetBase): 1737 def __init__(self, viewer, with_decoder, volumetric, parent=None): 1738 super().__init__(parent) 1739 1740 self._viewer = viewer 1741 self.with_decoder = with_decoder 1742 self.volumetric = volumetric 1743 self._create_widget() 1744 1745 def _create_widget(self): 1746 # Add the switch for segmenting the slice vs. the volume if we have a volume. 1747 if self.volumetric: 1748 self.layout().addWidget(self._create_volumetric_switch()) 1749 1750 # Add the nested settings widget. 1751 self.settings = self._create_settings() 1752 self.layout().addWidget(self.settings) 1753 1754 # Add the run button. 1755 self.run_button = QtWidgets.QPushButton("Automatic Segmentation") 1756 self.run_button.clicked.connect(self.__call__) 1757 self.run_button.setToolTip(get_tooltip("autosegment", "run_button")) 1758 self.layout().addWidget(self.run_button) 1759 1760 def _reset_segmentation_mode(self, with_decoder): 1761 # If we already have the same segmentation mode we don't need to do anything. 1762 if with_decoder == self.with_decoder: 1763 return 1764 1765 # Otherwise we change the value of with_decoder. 1766 self.with_decoder = with_decoder 1767 1768 # Then we clear the whole widget. 1769 layout = self.layout() 1770 while layout.count(): 1771 child = layout.takeAt(0) 1772 if child.widget(): 1773 child.widget().deleteLater() 1774 1775 # And then we reset it. 
1776 self._create_widget() 1777 1778 def _create_volumetric_switch(self): 1779 self.apply_to_volume = False 1780 return self._add_boolean_param( 1781 "apply_to_volume", self.apply_to_volume, title="Apply to Volume", 1782 tooltip=get_tooltip("autosegment", "apply_to_volume") 1783 ) 1784 1785 def _add_common_settings(self, settings): 1786 # Create the UI element for min object size. 1787 self.min_object_size = 100 1788 self.min_object_size_param, layout = self._add_int_param( 1789 "min_object_size", self.min_object_size, min_val=0, max_val=int(1e4), 1790 tooltip=get_tooltip("autosegment", "min_object_size") 1791 ) 1792 settings.layout().addLayout(layout) 1793 1794 # Add extra settings for volumetric segmentation: gap_closing and min_extent. 1795 if self.volumetric: 1796 self.gap_closing = 2 1797 self.gap_closing_param, layout = self._add_int_param( 1798 "gap_closing", self.gap_closing, min_val=0, max_val=10, 1799 tooltip=get_tooltip("autosegment", "gap_closing") 1800 ) 1801 settings.layout().addLayout(layout) 1802 1803 self.min_extent = 2 1804 self.min_extent_param, layout = self._add_int_param( 1805 "min_extent", self.min_extent, min_val=0, max_val=10, 1806 tooltip=get_tooltip("autosegment", "min_extent") 1807 ) 1808 settings.layout().addLayout(layout) 1809 1810 def _ais_settings(self): 1811 settings = QtWidgets.QWidget() 1812 settings.setLayout(QtWidgets.QVBoxLayout()) 1813 1814 # Create the UI element for center_distance_threshold. 1815 self.center_distance_thresh = 0.5 1816 self.center_distance_thresh_param, layout = self._add_float_param( 1817 "center_distance_thresh", self.center_distance_thresh, 1818 tooltip=get_tooltip("autosegment", "center_distance_thresh") 1819 ) 1820 settings.layout().addLayout(layout) 1821 1822 # Create the UI element for boundary_distance_threshold. 
1823 self.boundary_distance_thresh = 0.5 1824 self.boundary_distance_thresh_param, layout = self._add_float_param( 1825 "boundary_distance_thresh", self.boundary_distance_thresh, 1826 tooltip=get_tooltip("autosegment", "boundary_distance_thresh") 1827 ) 1828 settings.layout().addLayout(layout) 1829 1830 # Add min_object_size. 1831 self._add_common_settings(settings) 1832 1833 return settings 1834 1835 def _amg_settings(self): 1836 settings = QtWidgets.QWidget() 1837 settings.setLayout(QtWidgets.QVBoxLayout()) 1838 1839 # Create the UI element for pred_iou_thresh. 1840 self.pred_iou_thresh = 0.88 1841 self.pred_iou_thresh_param, layout = self._add_float_param( 1842 "pred_iou_thresh", self.pred_iou_thresh, 1843 tooltip=get_tooltip("autosegment", "pred_iou_thresh") 1844 ) 1845 settings.layout().addLayout(layout) 1846 1847 # Create the UI element for stability score thresh. 1848 self.stability_score_thresh = 0.95 1849 self.stability_score_thresh_param, layout = self._add_float_param( 1850 "stability_score_thresh", self.stability_score_thresh, 1851 tooltip=get_tooltip("autosegment", "stability_score_thresh") 1852 ) 1853 settings.layout().addLayout(layout) 1854 1855 # Create the UI element for box nms thresh. 1856 self.box_nms_thresh = 0.7 1857 self.box_nms_thresh_param, layout = self._add_float_param( 1858 "box_nms_thresh", self.box_nms_thresh, 1859 tooltip=get_tooltip("autosegment", "box_nms_thresh") 1860 ) 1861 settings.layout().addLayout(layout) 1862 1863 # Add min_object_size. 1864 self._add_common_settings(settings) 1865 1866 return settings 1867 1868 def _create_settings(self): 1869 setting_values = self._ais_settings() if self.with_decoder else self._amg_settings() 1870 settings = _make_collapsible(setting_values, title="Automatic Segmentation Settings") 1871 return settings 1872 1873 def _empty_segmentation_warning(self): 1874 msg = "The automatic segmentation result does not contain any objects." 
1875 msg += "Setting a smaller value for 'min_object_size' may help." 1876 if not self.with_decoder: 1877 msg += "Setting smaller values for 'pred_iou_thresh' and 'stability_score_thresh' may also help." 1878 val_results = {"message_type": "error", "message": msg} 1879 return _generate_message(val_results["message_type"], val_results["message"]) 1880 1881 def _run_segmentation_2d(self, kwargs, i=None): 1882 pbar, pbar_signals = _create_pbar_for_threadworker() 1883 1884 # @thread_worker 1885 def seg_impl(): 1886 def pbar_init(total, description): 1887 pbar_signals.pbar_total.emit(total) 1888 pbar_signals.pbar_description.emit(description) 1889 1890 seg = _instance_segmentation_impl( 1891 self.min_object_size, i=i, pbar_init=pbar_init, 1892 pbar_update=lambda update: pbar_signals.pbar_update.emit(update), 1893 **kwargs 1894 ) 1895 pbar_signals.pbar_stop.emit() 1896 return seg 1897 1898 def update_segmentation(seg): 1899 is_empty = seg.max() == 0 1900 if is_empty: 1901 self._empty_segmentation_warning() 1902 1903 if i is None: 1904 self._viewer.layers["auto_segmentation"].data = seg 1905 else: 1906 self._viewer.layers["auto_segmentation"].data[i] = seg 1907 self._viewer.layers["auto_segmentation"].refresh() 1908 1909 # Validate all layers. 1910 _validate_layers(self._viewer, automatic_segmentation=True) 1911 1912 seg = seg_impl() 1913 update_segmentation(seg) 1914 # worker = seg_impl() 1915 # worker.returned.connect(update_segmentation) 1916 # worker.start() 1917 # return worker 1918 1919 # We refuse to run 3D segmentation with the AMG unless we have a GPU or all embeddings 1920 # are precomputed. Otherwise this would take too long. 
def _allow_segment_3d(self):
    """Check whether automatic 3D segmentation (or tracking) may be run.

    Segmentation with the instance segmentation decoder is always allowed. With the AMG
    (no decoder) on a 'cpu' or 'mps' device it is only allowed if the AMG state is already
    available for the slices, because computing it on the fly would take too long.

    Returns:
        True if the 3D segmentation may be run, False otherwise.
    """
    if self.with_decoder:
        return True

    state = AnnotatorState()
    predictor = state.predictor
    if str(predictor.device) not in ("cpu", "mps"):
        return True

    n_slices = self._viewer.layers["auto_segmentation"].data.shape[0]
    # NOTE(review): `state.amg_state` can also hold non-slice entries ('cache_folder' / 'cache_path'
    # are read from it in `_handle_amg_state`), which is presumably why strictly more than
    # `n_slices` entries are required here — confirm.
    return (state.amg_state is not None) and (len(state.amg_state) > n_slices)


def _run_segmentation_3d(self, kwargs):
    """Segment every slice of the volume and merge the per-slice results into 3D objects.

    Each slice is segmented independently with `_instance_segmentation_impl`. The per-slice
    ids are shifted by a running offset so that they are unique across the volume, and the
    slices are then merged with `merge_instance_segmentation_3d`. The result is written to
    the 'auto_segmentation' layer.

    Args:
        kwargs: Keyword arguments passed on to the segmentation of each slice.

    Returns:
        The result of `_generate_message` if 3D segmentation is not allowed
        (see `_allow_segment_3d`), otherwise None.
    """
    allow_segment_3d = self._allow_segment_3d()
    if not allow_segment_3d:
        val_results = {
            "message_type": "error",
            "message": "Volumetric segmentation with AMG is only supported if you have a GPU."
        }
        return _generate_message(val_results["message_type"], val_results["message"])

    # NOTE(review): `pbar` is not used below; presumably it is kept so that the progress bar
    # stays referenced while the segmentation runs — confirm before removing.
    pbar, pbar_signals = _create_pbar_for_threadworker()

    # @thread_worker
    def seg_impl():
        segmentation = np.zeros_like(self._viewer.layers["auto_segmentation"].data)
        offset = 0

        def pbar_init(total, description):
            pbar_signals.pbar_total.emit(total)
            pbar_signals.pbar_description.emit(description)

        pbar_init(segmentation.shape[0], "Segment volume")

        # Further optimization: parallelize if state is precomputed for all slices
        for i in range(segmentation.shape[0]):
            seg = _instance_segmentation_impl(self.min_object_size, i=i, **kwargs)
            seg_max = seg.max()
            # Only non-empty slices are written. Their ids are shifted by the running offset
            # so that they don't collide with the ids from previous slices.
            if seg_max != 0:
                seg[seg != 0] += offset
                offset = seg_max + offset
                segmentation[i] = seg
            # Advance the progress bar for every slice, including empty ones,
            # so that it actually reaches its total.
            pbar_signals.pbar_update.emit(1)

        pbar_signals.pbar_reset.emit()
        segmentation = merge_instance_segmentation_3d(
            segmentation, beta=0.5, gap_closing=self.gap_closing, min_z_extent=self.min_extent,
            verbose=True, pbar_init=pbar_init, pbar_update=lambda update: pbar_signals.pbar_update.emit(1),
        )
        pbar_signals.pbar_stop.emit()
        return segmentation

    def update_segmentation(segmentation):
        is_empty = segmentation.max() == 0
        if is_empty:
            self._empty_segmentation_warning()
        self._viewer.layers["auto_segmentation"].data = segmentation
        self._viewer.layers["auto_segmentation"].refresh()

    seg = seg_impl()
    update_segmentation(seg)
    # worker = seg_impl()
    # worker.returned.connect(update_segmentation)
    # worker.start()
    # return worker


def __call__(self):
    """Run the automatic segmentation with the current settings.

    Depending on the settings, this segments the whole volume (if `apply_to_volume` is set),
    only the current slice of a volume, or the 2D image.

    Returns:
        None if the embeddings are not valid. Otherwise, the return value of the
        segmentation function that was run.
    """
    if _validate_embeddings(self._viewer):
        return None

    if self.with_decoder:
        kwargs = {
            "center_distance_threshold": self.center_distance_thresh,
            "boundary_distance_threshold": self.boundary_distance_thresh,
            "min_size": self.min_object_size,
        }
    else:
        kwargs = {
            "pred_iou_thresh": self.pred_iou_thresh,
            "stability_score_thresh": self.stability_score_thresh,
            "box_nms_thresh": self.box_nms_thresh,
        }
    if self.volumetric and self.apply_to_volume:
        worker = self._run_segmentation_3d(kwargs)
    elif self.volumetric and not self.apply_to_volume:
        i = int(self._viewer.dims.point[0])
        worker = self._run_segmentation_2d(kwargs, i=i)
    else:
        worker = self._run_segmentation_2d(kwargs)
    _select_layer(self._viewer, "auto_segmentation")
    return worker


class AutoTrackWidget(AutoSegmentWidget):
    """Widget for automatic segmentation and tracking of a timeseries.

    Reuses the settings of `AutoSegmentWidget`, but shows a 'Track Timeseries' switch
    instead of the volumetric switch. Tracking links the per-frame segmentations with
    `track_across_frames` and stores the resulting lineages in the annotator state.
    """

    def _create_tracking_switch(self):
        self.apply_to_volume = False
        return self._add_boolean_param(
            "apply_to_volume", self.apply_to_volume, title="Track Timeseries",
            tooltip=get_tooltip("autotrack", "run_tracking")
        )

    def _create_widget(self):
        # Add the switch for segmenting the slice vs. tracking the timeseries.
        self.layout().addWidget(self._create_tracking_switch())

        # Add the nested settings widget.
        self.settings = self._create_settings()
        self.layout().addWidget(self.settings)

        # Add the run button.
        self.run_button = QtWidgets.QPushButton("Automatic Tracking")
        self.run_button.clicked.connect(self.__call__)
        self.run_button.setToolTip(get_tooltip("autotrack", "run_button"))
        self.layout().addWidget(self.run_button)

    def _run_segmentation_3d(self, kwargs):
        """Segment every frame and track the objects across frames.

        Args:
            kwargs: Keyword arguments passed on to the segmentation of each frame.

        Returns:
            The result of `_generate_message` if tracking is not allowed, or if results from
            interactive tracking were already committed. Otherwise None.
        """
        allow_segment_3d = self._allow_segment_3d()
        if not allow_segment_3d:
            return _generate_message("error", "Tracking with AMG is only supported if you have a GPU.")

        state = AnnotatorState()
        if len(state.committed_lineages) > 0:
            return _generate_message(
                "error",
                "Automatic tracking can only be called if you haven't committed results from interactive tracking yet."
            )
        # NOTE(review): `pbar` is not used below; presumably it is kept so that the progress bar
        # stays referenced while the tracking runs — confirm before removing.
        pbar, pbar_signals = _create_pbar_for_threadworker()

        # @thread_worker
        def seg_impl():
            image_name = state.get_image_name(self._viewer)
            timeseries = self._viewer.layers[image_name].data
            segmentation = np.zeros_like(self._viewer.layers["auto_segmentation"].data)
            offset = 0

            def pbar_init(total, description):
                pbar_signals.pbar_total.emit(total)
                pbar_signals.pbar_description.emit(description)

            pbar_init(segmentation.shape[0], "Run tracking")

            # Further optimization: parallelize if state is precomputed for all slices
            for i in range(segmentation.shape[0]):
                seg = _instance_segmentation_impl(self.min_object_size, i=i, **kwargs)
                seg_max = seg.max()
                # Only non-empty frames are written. Their ids are shifted by the running offset
                # so that they don't collide with the ids from previous frames.
                if seg_max != 0:
                    seg[seg != 0] += offset
                    offset = seg_max + offset
                    segmentation[i] = seg
                # Advance the progress bar for every frame, including empty ones,
                # so that it actually reaches its total.
                pbar_signals.pbar_update.emit(1)

            pbar_signals.pbar_reset.emit()
            segmentation, lineages = track_across_frames(
                timeseries, segmentation,
                verbose=True, pbar_init=pbar_init,
                pbar_update=lambda update: pbar_signals.pbar_update.emit(1),
            )
            pbar_signals.pbar_stop.emit()
            return (segmentation, lineages)

        def update_segmentation(result):
            segmentation, lineages = result
            is_empty = segmentation.max() == 0
            if is_empty:
                self._empty_segmentation_warning()

            state = AnnotatorState()
            state.lineage = lineages

            self._viewer.layers["auto_segmentation"].data = segmentation
            self._viewer.layers["auto_segmentation"].refresh()

        result = seg_impl()
        update_segmentation(result)
        # worker = seg_impl()
        # worker.returned.connect(update_segmentation)
        # worker.start()
        # return worker
class PBarSignals(QObject):
    """Qt signals that report progress from the segmentation / tracking code.

    The segmentation code only emits these signals. It obtains an instance through
    `_create_pbar_for_threadworker` (defined elsewhere in this file), which presumably
    connects them to the actual progress bar.
    """

    # Sets the total number of progress steps.
    pbar_total = Signal(int)
    # Advances the progress by the emitted number of steps.
    pbar_update = Signal(int)
    # Sets the text shown next to the progress.
    pbar_description = Signal(str)
    # Emitted once the work has finished.
    pbar_stop = Signal()
    # Emitted between two phases of work, e.g. before merging or tracking the per-slice results.
    pbar_reset = Signal()
QObject(parent: Optional[QObject] = None)
pyqtSignal(*types, name: str = ..., revision: int = ..., arguments: Sequence = ...) -> PYQT_SIGNAL
types is normally a sequence of individual types. Each type is either a type object or a string that is the name of a C++ type. Alternatively each type could itself be a sequence of types each describing a different overloaded signal. name is the optional C++ name of the signal. If it is not specified then the name of the class attribute that is bound to the signal is used. revision is the optional revision of the signal that is exported to QML. If it is not specified then 0 is used. arguments is the optional sequence of the names of the signal's arguments.
pyqtSignal(*types, name: str = ..., revision: int = ..., arguments: Sequence = ...) -> PYQT_SIGNAL
types is normally a sequence of individual types. Each type is either a type object or a string that is the name of a C++ type. Alternatively each type could itself be a sequence of types each describing a different overloaded signal. name is the optional C++ name of the signal. If it is not specified then the name of the class attribute that is bound to the signal is used. revision is the optional revision of the signal that is exported to QML. If it is not specified then 0 is used. arguments is the optional sequence of the names of the signal's arguments.
pyqtSignal(*types, name: str = ..., revision: int = ..., arguments: Sequence = ...) -> PYQT_SIGNAL
types is normally a sequence of individual types. Each type is either a type object or a string that is the name of a C++ type. Alternatively each type could itself be a sequence of types each describing a different overloaded signal. name is the optional C++ name of the signal. If it is not specified then the name of the class attribute that is bound to the signal is used. revision is the optional revision of the signal that is exported to QML. If it is not specified then 0 is used. arguments is the optional sequence of the names of the signal's arguments.
pyqtSignal(*types, name: str = ..., revision: int = ..., arguments: Sequence = ...) -> PYQT_SIGNAL
types is normally a sequence of individual types. Each type is either a type object or a string that is the name of a C++ type. Alternatively each type could itself be a sequence of types each describing a different overloaded signal. name is the optional C++ name of the signal. If it is not specified then the name of the class attribute that is bound to the signal is used. revision is the optional revision of the signal that is exported to QML. If it is not specified then 0 is used. arguments is the optional sequence of the names of the signal's arguments.
pyqtSignal(*types, name: str = ..., revision: int = ..., arguments: Sequence = ...) -> PYQT_SIGNAL
types is normally a sequence of individual types. Each type is either a type object or a string that is the name of a C++ type. Alternatively each type could itself be a sequence of types each describing a different overloaded signal. name is the optional C++ name of the signal. If it is not specified then the name of the class attribute that is bound to the signal is used. revision is the optional revision of the signal that is exported to QML. If it is not specified then 0 is used. arguments is the optional sequence of the names of the signal's arguments.
class InfoDialog(QtWidgets.QDialog):
    """Dialog that shows a message together with an 'OK' and a 'Cancel' button.

    Clicking 'OK' accepts the dialog; clicking 'Cancel' rejects it.

    Args:
        title: The window title of the dialog.
        message: The message shown in the dialog.
    """

    def __init__(self, title, message):
        super().__init__()
        self.setWindowTitle(title)

        main_layout = QtWidgets.QVBoxLayout()
        main_layout.addWidget(QtWidgets.QLabel(message))

        # Bind every button to its own click handler (avoids late binding of the loop variable).
        def _make_handler(button):
            return lambda: self.button_clicked(button)

        # The buttons are placed side-by-side in a horizontal row below the message.
        buttons_row = QtWidgets.QHBoxLayout()
        for label in ("OK", "Cancel"):
            btn = QtWidgets.QPushButton(label)
            btn.clicked.connect(_make_handler(btn))
            buttons_row.addWidget(btn)

        main_layout.addLayout(buttons_row)
        self.setLayout(main_layout)

    def button_clicked(self, button):
        """Accept the dialog if the 'OK' button was clicked, otherwise reject it."""
        if button.text() != "OK":
            self.reject()
            return
        self.accept()
QDialog(parent: Optional[QWidget] = None, flags: Qt.WindowType = Qt.WindowFlags())
def __init__(self, title, message):
    """Set up the dialog: a window title, a message label, and 'OK' / 'Cancel' buttons.

    Each button's click is forwarded to `self.button_clicked` with that button.

    Args:
        title: The window title of the dialog.
        message: The message shown in the dialog.
    """
    super().__init__()
    self.setWindowTitle(title)

    main_layout = QtWidgets.QVBoxLayout()
    main_layout.addWidget(QtWidgets.QLabel(message))

    # Bind every button to its own click handler (avoids late binding of the loop variable).
    def _make_handler(button):
        return lambda: self.button_clicked(button)

    # The buttons are placed side-by-side in a horizontal row below the message.
    buttons_row = QtWidgets.QHBoxLayout()
    for label in ("OK", "Cancel"):
        btn = QtWidgets.QPushButton(label)
        btn.clicked.connect(_make_handler(btn))
        buttons_row.addWidget(btn)

    main_layout.addLayout(buttons_row)
    self.setLayout(main_layout)
Widget for clearing the current annotations.
Arguments:
- viewer: The napari viewer.
Widget for clearing the current annotations in 3D.
Arguments:
- viewer: The napari viewer.
- all_slices: Choose whether to clear the annotations for all or only the current slice.
Widget for clearing all tracking annotations and state.
Arguments:
- viewer: The napari viewer.
- all_frames: Choose whether to clear the annotations for all or only the current frame.
Widget for committing the segmented objects from automatic or interactive segmentation.
Arguments:
- viewer: The napari viewer.
- layer: Select the layer to commit. Can be either 'current_object' to commit interactive segmentation results, or 'auto_segmentation' to commit automatic segmentation results.
- preserve_mode: The mode for preserving already committed objects, in order to prevent overwriting them with a new commit. Supports the modes 'objects', which preserves on the object level and is the default; 'pixels', which preserves on the pixel level; or 'none', which does not preserve committed objects.
- preservation_threshold: The overlap threshold for preserving objects. This is only used if preserve_mode is set to 'objects'.
- commit_path: Select a file path where the committed results and prompts will be saved. This feature is still experimental.
Widget for committing the objects from interactive tracking.
Arguments:
- viewer: The napari viewer.
- layer: Select the layer to commit. Can be either 'current_object' to commit interactive segmentation results, or 'auto_segmentation' to commit automatic segmentation results.
- preserve_mode: The mode for preserving already committed objects, in order to prevent overwriting them with a new commit. Supports the modes 'objects', which preserves on the object level and is the default; 'pixels', which preserves on the pixel level; or 'none', which does not preserve committed objects.
- preservation_threshold: The overlap threshold for preserving objects. This is only used if preserve_mode is set to 'objects'.
- commit_path: Select a file path where the committed results and prompts will be saved. This feature is still experimental.
Widget to update global micro_sam settings.
Arguments:
- cache_directory: Select the path for the micro_sam cache directory.
The default cache directory is $HOME/.cache/micro_sam.
Segment object(s) for the current prompts.
Arguments:
- viewer: The napari viewer.
- batched: Choose if you want to segment multiple objects with point prompts.
Segment object for the current prompts.
Arguments:
- viewer: The napari viewer.
Segment object for the current prompts.
Arguments:
- viewer: The napari viewer.
1165class EmbeddingWidget(_WidgetBase): 1166 def __init__(self, parent=None): 1167 super().__init__(parent=parent) 1168 1169 # Create a nested layout for the sections. 1170 # Section 1: Image and Model. 1171 section1_layout = QtWidgets.QHBoxLayout() 1172 section1_layout.addLayout(self._create_image_section()) 1173 section1_layout.addLayout(self._create_model_section()) # Creates the model family widget section. 1174 self.layout().addLayout(section1_layout) 1175 1176 # Section 2: Settings (collapsible). 1177 self.layout().addWidget(self._create_settings_widget()) 1178 1179 # Section 3: The button to trigger the embedding computation. 1180 self.run_button = QtWidgets.QPushButton("Compute Embeddings") 1181 self.run_button.clicked.connect(self._initialize_image) 1182 self.run_button.clicked.connect(self.__call__) 1183 self.run_button.setToolTip(get_tooltip("embedding", "run_button")) 1184 self.layout().addWidget(self.run_button) 1185 1186 def _initialize_image(self): 1187 state = AnnotatorState() 1188 layer = self.image_selection.get_value() 1189 1190 # This is encountered when there is no image layer available / selected. 1191 # In this case, we need not specify other image-level parameters to the state. Hence, we skip them. 1192 # NOTE: On code-level, this happens as the first step when "Compute Embedding" click is triggered. 
1193 if layer is None: 1194 return 1195 1196 image_shape = layer.data.shape 1197 image_scale = tuple(layer.scale) 1198 state.image_shape = image_shape 1199 state.image_scale = image_scale 1200 state.image_name = layer.name 1201 1202 def _create_image_section(self): 1203 image_section = QtWidgets.QVBoxLayout() 1204 image_layer_widget = QtWidgets.QLabel("Image Layer:") 1205 # image_layer_widget.setToolTip(get_tooltip("embedding", "image")) # this adds tooltip to label 1206 image_section.addWidget(image_layer_widget) 1207 1208 # Setting a napari layer in QT, see: 1209 # https://github.com/pyapp-kit/magicgui/blob/main/docs/examples/napari/napari_combine_qt.py 1210 self.image_selection = create_widget(annotation=napari.layers.Image) 1211 self.image_selection.native.setToolTip(get_tooltip("embedding", "image")) 1212 image_section.addWidget(self.image_selection.native) 1213 1214 return image_section 1215 1216 def _update_model(self, state): 1217 _model_type = state.predictor.model_type if self.custom_weights else self.model_type 1218 1219 # Provide a detailed message for the model family and model size per chosen combination. 1220 msg = "Computed embeddings for " 1221 if self.custom_weights: # Whether the user provided a filepath to custom finetuned model weights. 1222 msg += f"the model located at '{os.path.abspath(self.custom_weights)}' " 1223 msg += f"of size '{self._model_size_map[_model_type[4]]}'." 1224 else: 1225 msg += f"the '{self.model_family}' model of size '{self.model_size}'." 1226 1227 show_info(msg) 1228 1229 state = AnnotatorState() 1230 # Update the widget itself. This is necessary because we may have loaded 1231 # some settings from the embedding file and have to reflect them in the widget. 
1232 vutil._sync_embedding_widget( 1233 self, 1234 model_type=_model_type, 1235 save_path=self.embeddings_save_path, 1236 checkpoint_path=self.custom_weights, 1237 device=self.device, 1238 tile_shape=[self.tile_x, self.tile_y], 1239 halo=[self.halo_x, self.halo_y] 1240 ) 1241 1242 # Set the default settings for this model in the autosegment widget if it is part of 1243 # the currently used plugin. 1244 if "autosegment" in state.widgets: 1245 with_decoder = state.decoder is not None 1246 vutil._sync_autosegment_widget( 1247 state.widgets["autosegment"], _model_type, self.custom_weights, update_decoder=with_decoder 1248 ) 1249 # Load the AMG/AIS state if we have a 3d segmentation plugin. 1250 if state.widgets["autosegment"].volumetric and with_decoder: 1251 state.amg_state = vutil._load_is_state(state.embedding_path) 1252 elif state.widgets["autosegment"].volumetric and not with_decoder: 1253 state.amg_state = vutil._load_amg_state(state.embedding_path) 1254 1255 # Set the default settings for this model in the nd-segmentation widget if it is part of 1256 # the currently used plugin. 1257 if "segment_nd" in state.widgets: 1258 vutil._sync_ndsegment_widget(state.widgets["segment_nd"], _model_type, self.custom_weights) 1259 1260 def _create_settings_widget(self): 1261 setting_values = QtWidgets.QWidget() 1262 setting_values.setToolTip(get_tooltip("embedding", "settings")) 1263 setting_values.setLayout(QtWidgets.QVBoxLayout()) 1264 1265 # Add the model size widget section. 1266 layout = self._create_model_size_section() 1267 setting_values.layout().addLayout(layout) 1268 1269 # Create UI for the device. 1270 self.device = "auto" 1271 device_options = ["auto"] + util._available_devices() 1272 1273 self.device_dropdown, layout = self._add_choice_param( 1274 "device", self.device, device_options, tooltip=get_tooltip("embedding", "device") 1275 ) 1276 setting_values.layout().addLayout(layout) 1277 1278 # Create UI for the save path. 
1279 self.embeddings_save_path = None 1280 self.embeddings_save_path_param, layout = self._add_path_param( 1281 "embeddings_save_path", self.embeddings_save_path, "directory", title="embeddings save path:", 1282 tooltip=get_tooltip("embedding", "embeddings_save_path") 1283 ) 1284 setting_values.layout().addLayout(layout) 1285 1286 # Create UI for the custom weights. 1287 self.custom_weights = None 1288 self.custom_weights_param, layout = self._add_path_param( 1289 "custom_weights", self.custom_weights, "file", title="custom weights path:", 1290 tooltip=get_tooltip("embedding", "custom_weights") 1291 ) 1292 setting_values.layout().addLayout(layout) 1293 1294 # Create UI for the tile shape. 1295 self.tile_x, self.tile_y = 0, 0 1296 self.tile_x_param, self.tile_y_param, layout = self._add_shape_param( 1297 ("tile_x", "tile_y"), (self.tile_x, self.tile_y), min_val=0, max_val=2048, step=16, 1298 tooltip=get_tooltip("embedding", "tiling") 1299 ) 1300 setting_values.layout().addLayout(layout) 1301 1302 # Create UI for the halo. 1303 self.halo_x, self.halo_y = 0, 0 1304 self.halo_x_param, self.halo_y_param, layout = self._add_shape_param( 1305 ("halo_x", "halo_y"), (self.halo_x, self.halo_y), min_val=0, max_val=512, 1306 tooltip=get_tooltip("embedding", "halo") 1307 ) 1308 setting_values.layout().addLayout(layout) 1309 1310 # Create UI for the choice of automatic segmentation mode. 
1311 self.automatic_segmentation_mode = "auto" 1312 auto_seg_options = ["auto", "amg", "ais"] 1313 self.automatic_segmentation_mode_dropdown, layout = self._add_choice_param( 1314 "automatic_segmentation_mode", self.automatic_segmentation_mode, auto_seg_options, 1315 title="automatic segmentation mode", tooltip=get_tooltip("embedding", "automatic_segmentation_mode") 1316 ) 1317 setting_values.layout().addLayout(layout) 1318 1319 settings = _make_collapsible(setting_values, title="Embedding Settings") 1320 return settings 1321 1322 def _validate_inputs(self): 1323 """Validates the inputs for the annotation process and returns a dictionary 1324 containing information for message generation, or False if no messages are needed. 1325 1326 This function performs the following checks: 1327 1328 - If an `embeddings_save_path` is provided: 1329 - Validates the image data signature by comparing it with the signature 1330 of the image data in the viewer's selection. 1331 - Checks for existing embeddings at the specified path. 1332 - If existing embeddings are found, it attempts to load parameters 1333 like tile shape, halo, and model type from the Zarr attributes. 1334 - An informational message is generated based on the loaded parameters. 1335 - If loading existing embeddings fails, an error message is generated. 1336 - If no existing embeddings are found, an informational message is generated. 1337 - If no `embeddings_save_path` is provided, the function returns None. 1338 1339 Returns: 1340 bool: True if the computation should be aborted, otherwise False. 1341 """ 1342 1343 # Check if we have an existing input image to compute the embeddings. 1344 image = self.image_selection.get_value() 1345 if image is None: 1346 return _generate_message("error", "No image has been selected.") 1347 1348 # Check if we have an existing embedding path. 
1349 # If yes we check the data signature of these embeddings against the selected image 1350 # and we ask the user if they want to load these embeddings. 1351 if self.embeddings_save_path and os.listdir(self.embeddings_save_path): 1352 try: 1353 f = zarr.open(self.embeddings_save_path, mode="a") 1354 1355 # Validate that the embeddings are complete. 1356 # Note: 'input_size' is the last value set in the attrs of f, 1357 # so we can use it as a proxy to check if the embeddings are fully computed 1358 if "input_size" not in f.attrs: 1359 msg = (f"The embeddings at {self.embeddings_save_path} are incomplete. " 1360 "Specify a different path or remove them.") 1361 return _generate_message("error", msg) 1362 1363 # Validate image data signature. 1364 if "data_signature" in f.attrs: 1365 image = self.image_selection.get_value() 1366 img_signature = util._compute_data_signature(image.data) 1367 if img_signature != f.attrs["data_signature"]: 1368 msg = f"The embeddings don't match with the image: {img_signature} {f.attrs['data_signature']}" 1369 return _generate_message("error", msg) 1370 1371 # Load existing parameters. 1372 self.model_type = f.attrs.get("model_name", f.attrs["model_type"]) 1373 if "tile_shape" in f.attrs and f.attrs["tile_shape"] is not None: 1374 self.tile_x, self.tile_y = f.attrs["tile_shape"] 1375 self.halo_x, self.halo_y = f.attrs["halo"] 1376 val_results = { 1377 "message_type": "info", 1378 "message": (f"Load embeddings for model: {self.model_type} with tile shape: " 1379 f"{self.tile_x}, {self.tile_y} and halo: {self.halo_x}, {self.halo_y}.") 1380 } 1381 else: 1382 self.tile_x, self.tile_y = 0, 0 1383 self.halo_x, self.halo_y = 0, 0 1384 val_results = { 1385 "message_type": "info", 1386 "message": f"Load embeddings for model: {self.model_type}." 
1387 } 1388 1389 return _generate_message(val_results["message_type"], val_results["message"]) 1390 1391 except RuntimeError as e: 1392 val_results = { 1393 "message_type": "error", 1394 "message": f"Failed to load image embeddings: {e}" 1395 } 1396 return _generate_message(val_results["message_type"], val_results["message"]) 1397 1398 # Otherwise we either don't have an embedding path or it is empty. We can proceed in both cases. 1399 return False 1400 1401 def _validate_existing_embeddings(self, state): 1402 if state.image_embeddings is None: 1403 return False 1404 else: 1405 val_results = { 1406 "message_type": "info", 1407 "message": "Embeddings have already been precomputed. Press OK to recompute the embeddings." 1408 } 1409 return _generate_message(val_results["message_type"], val_results["message"]) 1410 1411 def __call__(self, skip_validate=False): 1412 self._validate_model_type_and_custom_weights() 1413 1414 # Validate user inputs. 1415 if not skip_validate and self._validate_inputs(): 1416 return 1417 1418 # Get the image. 1419 image = self.image_selection.get_value() 1420 1421 # Update the image embeddings: 1422 state = AnnotatorState() 1423 if self._validate_existing_embeddings(state): 1424 # Whether embeddings already exist to control existing objects in layers. 1425 state.skip_recomputing_embeddings = True 1426 return 1427 1428 state.skip_recomputing_embeddings = False 1429 # Reset the state. 1430 state.reset_state() 1431 1432 # Get image dimensions. 1433 if image.rgb: 1434 ndim = image.data.ndim - 1 1435 state.image_shape = image.data.shape[:-1] 1436 else: 1437 ndim = image.data.ndim 1438 state.image_shape = image.data.shape 1439 1440 # Set layer scale 1441 state.image_scale = tuple(image.scale) 1442 1443 # Process tile_shape and halo, set other data. 
1444 tile_shape, halo = _process_tiling_inputs(self.tile_x, self.tile_y, self.halo_x, self.halo_y) 1445 save_path = None if self.embeddings_save_path == "" else self.embeddings_save_path 1446 image_data = image.data 1447 1448 # Set up progress bar and signals for using it within a threadworker. 1449 pbar, pbar_signals = _create_pbar_for_threadworker() 1450 1451 # @thread_worker() 1452 def compute_image_embedding(): 1453 1454 def pbar_init(total, description): 1455 pbar_signals.pbar_total.emit(total) 1456 pbar_signals.pbar_description.emit(description) 1457 1458 # Whether to prefer decoder. 1459 # With 'amg', it is set to 'False', else it is 'True' for the default 'auto' and 'ais' mode. 1460 prefer_decoder = True 1461 if self.automatic_segmentation_mode == "amg": 1462 prefer_decoder = False 1463 1464 state.initialize_predictor( 1465 image_data, model_type=self.model_type, save_path=save_path, ndim=ndim, 1466 device=self.device, checkpoint_path=self.custom_weights, tile_shape=tile_shape, halo=halo, 1467 prefer_decoder=prefer_decoder, pbar_init=pbar_init, 1468 pbar_update=lambda update: pbar_signals.pbar_update.emit(update), 1469 ) 1470 pbar_signals.pbar_stop.emit() 1471 1472 compute_image_embedding() 1473 self._update_model(state) 1474 # worker = compute_image_embedding() 1475 # worker.returned.connect(self._update_model) 1476 # worker.start() 1477 # return worker
QWidget(parent: Optional[QWidget] = None, flags: Qt.WindowType = Qt.WindowFlags())
def __init__(self, parent=None):
    """Assemble the embedding widget.

    The layout consists of a row holding the image and model sections, the collapsible
    embedding settings, and the "Compute Embeddings" button.

    Args:
        parent: Optional parent Qt widget, forwarded to the base class.
    """
    super().__init__(parent=parent)

    # Top row: image selection next to the model (family) selection.
    # Each section is built and then added before the next one is built.
    top_row = QtWidgets.QHBoxLayout()
    for build_section in (self._create_image_section, self._create_model_section):
        top_row.addLayout(build_section())
    self.layout().addLayout(top_row)

    # Middle: the collapsible settings.
    settings_widget = self._create_settings_widget()
    self.layout().addWidget(settings_widget)

    # Bottom: the run button. Qt calls slots in connection order, so the image
    # is initialized before the embedding computation (__call__) runs.
    button = QtWidgets.QPushButton("Compute Embeddings")
    self.run_button = button
    for slot in (self._initialize_image, self.__call__):
        button.clicked.connect(slot)
    button.setToolTip(get_tooltip("embedding", "run_button"))
    self.layout().addWidget(button)
class SegmentNDWidget(_WidgetBase):
    """Widget for segmenting one object across all slices of a volume or all frames of a timeseries.

    The object is prompted via the "point_prompts" and "prompts" layers, and the result is
    written to the "current_object" layer.
    """

    def __init__(self, viewer, tracking, parent=None):
        """
        Args:
            viewer: The napari viewer holding the prompt and object layers.
            tracking: If True, track the object across frames; otherwise segment it
                volumetrically across slices.
            parent: The parent Qt widget.
        """
        super().__init__(parent=parent)
        self._viewer = viewer
        self.tracking = tracking

        # Add the settings.
        self.settings = self._create_settings()
        self.layout().addWidget(self.settings)

        # Add the run button.
        button_title = "Segment All Frames [Shift-S]" if self.tracking else "Segment All Slices [Shift-S]"
        self.run_button = QtWidgets.QPushButton(button_title)
        self.run_button.clicked.connect(self.__call__)
        self.layout().addWidget(self.run_button)

    def _create_settings(self):
        """Create the collapsible settings section.

        The section contains projection, IOU threshold and box extension, plus motion
        smoothing in tracking mode.
        """
        setting_values = QtWidgets.QWidget()
        setting_values.setToolTip(get_tooltip("segmentnd", "settings"))
        setting_values.setLayout(QtWidgets.QVBoxLayout())

        # Create the UI for the projection modes.
        self.projection = "single_point"
        self.projection_dropdown, layout = self._add_choice_param(
            "projection", self.projection, PROJECTION_MODES, tooltip=get_tooltip("segmentnd", "projection_dropdown")
        )
        setting_values.layout().addLayout(layout)

        # Create the UI element for the IOU threshold.
        self.iou_threshold = 0.5
        self.iou_threshold_param, layout = self._add_float_param(
            "iou_threshold", self.iou_threshold, tooltip=get_tooltip("segmentnd", "iou_threshold")
        )
        setting_values.layout().addLayout(layout)

        # Create the UI element for the box extension.
        self.box_extension = 0.05
        self.box_extension_param, layout = self._add_float_param(
            "box_extension", self.box_extension, tooltip=get_tooltip("segmentnd", "box_extension")
        )
        setting_values.layout().addLayout(layout)

        # Create the UI element for the motion smoothing (if we have the tracking widget).
        # Note: `self.motion_smoothing` only exists in tracking mode; it is only read by `_run_tracking`.
        if self.tracking:
            self.motion_smoothing = 0.5
            self.motion_smoothing_param, layout = self._add_float_param(
                "motion_smoothing", self.motion_smoothing, tooltip=get_tooltip("segmentnd", "motion_smoothing")
            )
            setting_values.layout().addLayout(layout)

        settings = _make_collapsible(setting_values, title="Segmentation Settings")
        return settings

    def _run_tracking(self):
        """Track the current object through the timeseries and write it to "current_object"."""
        state = AnnotatorState()
        # `pbar` is kept bound for the duration of this call (it is not otherwise used here).
        pbar, pbar_signals = _create_pbar_for_threadworker()

        # @thread_worker
        def tracking_impl():
            shape = state.image_shape

            pbar_signals.pbar_total.emit(shape[0])
            pbar_signals.pbar_description.emit("Track object")

            # Step 1: Segment all slices with prompts.
            seg, slices, _, stop_upper = vutil.segment_slices_with_prompts(
                state.predictor, self._viewer.layers["point_prompts"], self._viewer.layers["prompts"],
                state.image_embeddings, shape, track_id=state.current_track_id,
                update_progress=lambda update: pbar_signals.pbar_update.emit(update),
            )

            # Step 2: Track the object starting from the lowest annotated slice.
            seg, has_division = vutil.track_from_prompts(
                self._viewer.layers["point_prompts"], self._viewer.layers["prompts"], seg,
                state.predictor, slices, state.image_embeddings, stop_upper,
                threshold=self.iou_threshold, projection=self.projection,
                motion_smoothing=self.motion_smoothing,
                box_extension=self.box_extension,
                update_progress=lambda update: pbar_signals.pbar_update.emit(update),
            )

            pbar_signals.pbar_stop.emit()
            return seg, has_division

        def update_segmentation(ret_val):
            seg, has_division = ret_val
            # If a division has occurred and it's the first time it occurred for this track
            # then we need to create the two daughter tracks and update the lineage.
            if has_division and (len(state.lineage[state.current_track_id]) == 0):
                _update_lineage(self._viewer)

            # Clear the old track mask.
            self._viewer.layers["current_object"].data[
                self._viewer.layers["current_object"].data == state.current_track_id
            ] = 0
            # Set the new object mask.
            self._viewer.layers["current_object"].data[seg == 1] = state.current_track_id
            self._viewer.layers["current_object"].refresh()

        ret_val = tracking_impl()
        update_segmentation(ret_val)
        # worker = tracking_impl()
        # worker.returned.connect(update_segmentation)
        # worker.start()
        # return worker

    def _run_volumetric_segmentation(self):
        """Segment the current object through the volume and write it to "current_object".

        Also stores the z-range of the object in `AnnotatorState().z_range`.
        """
        # `pbar` is kept bound for the duration of this call (it is not otherwise used here).
        pbar, pbar_signals = _create_pbar_for_threadworker()

        # @thread_worker
        def volumetric_segmentation_impl():
            state = AnnotatorState()
            shape = state.image_shape

            pbar_signals.pbar_total.emit(shape[0])
            pbar_signals.pbar_description.emit("Segment object")

            # Step 1: Segment all slices with prompts.
            seg, slices, stop_lower, stop_upper = vutil.segment_slices_with_prompts(
                state.predictor, self._viewer.layers["point_prompts"], self._viewer.layers["prompts"],
                state.image_embeddings, shape,
                update_progress=lambda update: pbar_signals.pbar_update.emit(update),
            )

            # Step 2: Segment the rest of the volume based on projecting prompts.
            seg, (z_min, z_max) = segment_mask_in_volume(
                seg, state.predictor, state.image_embeddings, slices,
                stop_lower, stop_upper,
                iou_threshold=self.iou_threshold, projection=self.projection,
                box_extension=self.box_extension,
                update_progress=lambda update: pbar_signals.pbar_update.emit(update),
            )
            pbar_signals.pbar_stop.emit()

            state.z_range = (z_min, z_max)
            return seg

        def update_segmentation(seg):
            self._viewer.layers["current_object"].data = seg
            self._viewer.layers["current_object"].refresh()

        # Use the same "compute, then update" pattern as the other run methods, so the layer
        # update lives in one place (it is also the callback for the disabled worker below).
        seg = volumetric_segmentation_impl()
        update_segmentation(seg)
        # worker = volumetric_segmentation_impl()
        # worker.returned.connect(update_segmentation)
        # worker.start()
        # return worker

    def __call__(self):
        """Validate embeddings and layers, then run tracking or volumetric segmentation."""
        if _validate_embeddings(self._viewer):
            return None
        if _validate_layers(self._viewer):
            return None

        if self.tracking:
            return self._run_tracking()
        else:
            return self._run_volumetric_segmentation()
QWidget(parent: Optional[QWidget] = None, flags: Qt.WindowType = Qt.WindowFlags())
def __init__(self, viewer, tracking, parent=None):
    """Set up the N-dimensional segmentation widget.

    Args:
        viewer: The napari viewer the widget operates on.
        tracking: Selects the frame-tracking mode (True) or the slice mode (False). It controls
            the run-button label and which settings are created.
        parent: Optional parent Qt widget.
    """
    super().__init__(parent=parent)
    self._viewer = viewer
    # Must be set before `_create_settings`, which reads it.
    self.tracking = tracking

    settings_section = self._create_settings()
    self.settings = settings_section
    self.layout().addWidget(settings_section)

    if self.tracking:
        button_title = "Segment All Frames [Shift-S]"
    else:
        button_title = "Segment All Slices [Shift-S]"
    run_button = QtWidgets.QPushButton(button_title)
    self.run_button = run_button
    run_button.clicked.connect(self.__call__)
    self.layout().addWidget(run_button)
class AutoSegmentWidget(_WidgetBase):
    """Widget for running automatic instance segmentation into the "auto_segmentation" layer.

    `with_decoder` selects the decoder-based instance segmentation settings (AIS) or the
    automatic mask generation settings (AMG). For volumetric data a switch chooses between
    segmenting the current slice and segmenting (and merging) the whole volume.
    """

    def __init__(self, viewer, with_decoder, volumetric, parent=None):
        """
        Args:
            viewer: The napari viewer.
            with_decoder: Whether to expose the AIS settings (True) or the AMG settings (False).
            volumetric: Whether the data is a volume. This adds the "Apply to Volume" switch and
                the gap_closing / min_extent settings.
            parent: The parent Qt widget.
        """
        super().__init__(parent)

        self._viewer = viewer
        self.with_decoder = with_decoder
        self.volumetric = volumetric
        self._create_widget()

    def _create_widget(self):
        """Populate the widget's layout with the optional volume switch, settings and run button."""
        # Add the switch for segmenting the slice vs. the volume if we have a volume.
        if self.volumetric:
            self.layout().addWidget(self._create_volumetric_switch())

        # Add the nested settings widget.
        self.settings = self._create_settings()
        self.layout().addWidget(self.settings)

        # Add the run button.
        self.run_button = QtWidgets.QPushButton("Automatic Segmentation")
        self.run_button.clicked.connect(self.__call__)
        self.run_button.setToolTip(get_tooltip("autosegment", "run_button"))
        self.layout().addWidget(self.run_button)

    def _reset_segmentation_mode(self, with_decoder):
        """Rebuild the widget contents if the segmentation mode (AIS vs. AMG) changed."""
        # If we already have the same segmentation mode we don't need to do anything.
        if with_decoder == self.with_decoder:
            return

        # Otherwise we change the value of with_decoder.
        self.with_decoder = with_decoder

        # Then we clear the whole widget.
        layout = self.layout()
        while layout.count():
            child = layout.takeAt(0)
            if child.widget():
                child.widget().deleteLater()

        # And then we reset it.
        self._create_widget()

    def _create_volumetric_switch(self):
        """Create the "Apply to Volume" checkbox; it defaults to segmenting only the current slice."""
        self.apply_to_volume = False
        return self._add_boolean_param(
            "apply_to_volume", self.apply_to_volume, title="Apply to Volume",
            tooltip=get_tooltip("autosegment", "apply_to_volume")
        )

    def _add_common_settings(self, settings):
        """Add the settings shared by AIS and AMG (min object size, plus 3d settings for volumes)."""
        # Create the UI element for min object size.
        self.min_object_size = 100
        self.min_object_size_param, layout = self._add_int_param(
            "min_object_size", self.min_object_size, min_val=0, max_val=int(1e4),
            tooltip=get_tooltip("autosegment", "min_object_size")
        )
        settings.layout().addLayout(layout)

        # Add extra settings for volumetric segmentation: gap_closing and min_extent.
        if self.volumetric:
            self.gap_closing = 2
            self.gap_closing_param, layout = self._add_int_param(
                "gap_closing", self.gap_closing, min_val=0, max_val=10,
                tooltip=get_tooltip("autosegment", "gap_closing")
            )
            settings.layout().addLayout(layout)

            self.min_extent = 2
            self.min_extent_param, layout = self._add_int_param(
                "min_extent", self.min_extent, min_val=0, max_val=10,
                tooltip=get_tooltip("autosegment", "min_extent")
            )
            settings.layout().addLayout(layout)

    def _ais_settings(self):
        """Create the settings widget for decoder-based instance segmentation."""
        settings = QtWidgets.QWidget()
        settings.setLayout(QtWidgets.QVBoxLayout())

        # Create the UI element for center_distance_threshold.
        self.center_distance_thresh = 0.5
        self.center_distance_thresh_param, layout = self._add_float_param(
            "center_distance_thresh", self.center_distance_thresh,
            tooltip=get_tooltip("autosegment", "center_distance_thresh")
        )
        settings.layout().addLayout(layout)

        # Create the UI element for boundary_distance_threshold.
        self.boundary_distance_thresh = 0.5
        self.boundary_distance_thresh_param, layout = self._add_float_param(
            "boundary_distance_thresh", self.boundary_distance_thresh,
            tooltip=get_tooltip("autosegment", "boundary_distance_thresh")
        )
        settings.layout().addLayout(layout)

        # Add min_object_size.
        self._add_common_settings(settings)

        return settings

    def _amg_settings(self):
        """Create the settings widget for automatic mask generation."""
        settings = QtWidgets.QWidget()
        settings.setLayout(QtWidgets.QVBoxLayout())

        # Create the UI element for pred_iou_thresh.
        self.pred_iou_thresh = 0.88
        self.pred_iou_thresh_param, layout = self._add_float_param(
            "pred_iou_thresh", self.pred_iou_thresh,
            tooltip=get_tooltip("autosegment", "pred_iou_thresh")
        )
        settings.layout().addLayout(layout)

        # Create the UI element for stability score thresh.
        self.stability_score_thresh = 0.95
        self.stability_score_thresh_param, layout = self._add_float_param(
            "stability_score_thresh", self.stability_score_thresh,
            tooltip=get_tooltip("autosegment", "stability_score_thresh")
        )
        settings.layout().addLayout(layout)

        # Create the UI element for box nms thresh.
        self.box_nms_thresh = 0.7
        self.box_nms_thresh_param, layout = self._add_float_param(
            "box_nms_thresh", self.box_nms_thresh,
            tooltip=get_tooltip("autosegment", "box_nms_thresh")
        )
        settings.layout().addLayout(layout)

        # Add min_object_size.
        self._add_common_settings(settings)

        return settings

    def _create_settings(self):
        """Wrap the mode-specific settings in a collapsible section."""
        setting_values = self._ais_settings() if self.with_decoder else self._amg_settings()
        settings = _make_collapsible(setting_values, title="Automatic Segmentation Settings")
        return settings

    def _empty_segmentation_warning(self):
        """Report that the segmentation is empty, with hints on which settings to relax."""
        # Sentences are joined with a space; plain concatenation previously produced "objects.Setting".
        parts = [
            "The automatic segmentation result does not contain any objects.",
            "Setting a smaller value for 'min_object_size' may help.",
        ]
        if not self.with_decoder:
            parts.append("Setting smaller values for 'pred_iou_thresh' and 'stability_score_thresh' may also help.")
        msg = " ".join(parts)
        val_results = {"message_type": "error", "message": msg}
        return _generate_message(val_results["message_type"], val_results["message"])

    def _run_segmentation_2d(self, kwargs, i=None):
        """Segment a single image, or slice `i` of a volume, into "auto_segmentation".

        Args:
            kwargs: Keyword arguments forwarded to `_instance_segmentation_impl`.
            i: Index of the slice to segment, or None to segment the whole (2d) image.
        """
        pbar, pbar_signals = _create_pbar_for_threadworker()

        # @thread_worker
        def seg_impl():
            def pbar_init(total, description):
                pbar_signals.pbar_total.emit(total)
                pbar_signals.pbar_description.emit(description)

            seg = _instance_segmentation_impl(
                self.min_object_size, i=i, pbar_init=pbar_init,
                pbar_update=lambda update: pbar_signals.pbar_update.emit(update),
                **kwargs
            )
            pbar_signals.pbar_stop.emit()
            return seg

        def update_segmentation(seg):
            is_empty = seg.max() == 0
            if is_empty:
                self._empty_segmentation_warning()

            if i is None:
                self._viewer.layers["auto_segmentation"].data = seg
            else:
                self._viewer.layers["auto_segmentation"].data[i] = seg
            self._viewer.layers["auto_segmentation"].refresh()

            # Validate all layers.
            _validate_layers(self._viewer, automatic_segmentation=True)

        seg = seg_impl()
        update_segmentation(seg)
        # worker = seg_impl()
        # worker.returned.connect(update_segmentation)
        # worker.start()
        # return worker

    # We refuse to run 3D segmentation with the AMG unless we have a GPU or all embeddings
    # are precomputed. Otherwise this would take too long.
    def _allow_segment_3d(self):
        """Return whether volumetric segmentation may run with the current mode and device.

        It is always allowed with the decoder. With AMG on a "cpu" or "mps" device, it is only
        allowed if the AMG state appears to be precomputed for every slice.
        """
        if self.with_decoder:
            return True
        state = AnnotatorState()
        predictor = state.predictor
        if str(predictor.device) == "cpu" or str(predictor.device) == "mps":
            n_slices = self._viewer.layers["auto_segmentation"].data.shape[0]
            # NOTE(review): `>` rather than `>=` presumably accounts for one non-slice entry
            # stored in `amg_state`; confirm against the state loading code.
            embeddings_are_precomputed = (state.amg_state is not None) and (len(state.amg_state) > n_slices)
            if not embeddings_are_precomputed:
                return False
        return True

    def _run_segmentation_3d(self, kwargs):
        """Segment every slice and merge the per-slice objects into 3d objects.

        Args:
            kwargs: Keyword arguments forwarded to `_instance_segmentation_impl` for each slice.

        Returns:
            The result of `_generate_message` if 3d segmentation is not allowed, otherwise None.
        """
        allow_segment_3d = self._allow_segment_3d()
        if not allow_segment_3d:
            val_results = {
                "message_type": "error",
                "message": "Volumetric segmentation with AMG is only supported if you have a GPU."
            }
            return _generate_message(val_results["message_type"], val_results["message"])

        pbar, pbar_signals = _create_pbar_for_threadworker()

        # @thread_worker
        def seg_impl():
            segmentation = np.zeros_like(self._viewer.layers["auto_segmentation"].data)
            offset = 0

            def pbar_init(total, description):
                pbar_signals.pbar_total.emit(total)
                pbar_signals.pbar_description.emit(description)

            pbar_init(segmentation.shape[0], "Segment volume")

            # Further optimization: parallelize if state is precomputed for all slices
            for i in range(segmentation.shape[0]):
                seg = _instance_segmentation_impl(self.min_object_size, i=i, **kwargs)
                seg_max = seg.max()
                # Shift the ids of non-empty slices so they are unique across the volume.
                if seg_max != 0:
                    seg[seg != 0] += offset
                    offset = seg_max + offset
                    segmentation[i] = seg
                # Advance the progress for every slice; empty slices previously skipped this,
                # so the bar never reached its total.
                pbar_signals.pbar_update.emit(1)

            pbar_signals.pbar_reset.emit()
            segmentation = merge_instance_segmentation_3d(
                segmentation, beta=0.5, gap_closing=self.gap_closing, min_z_extent=self.min_extent,
                verbose=True, pbar_init=pbar_init,
                # Forward the reported increment, like the other progress callbacks in this file.
                pbar_update=lambda update: pbar_signals.pbar_update.emit(update),
            )
            pbar_signals.pbar_stop.emit()
            return segmentation

        def update_segmentation(segmentation):
            is_empty = segmentation.max() == 0
            if is_empty:
                self._empty_segmentation_warning()
            self._viewer.layers["auto_segmentation"].data = segmentation
            self._viewer.layers["auto_segmentation"].refresh()

        seg = seg_impl()
        update_segmentation(seg)
        # worker = seg_impl()
        # worker.returned.connect(update_segmentation)
        # worker.start()
        # return worker

    def __call__(self):
        """Collect the mode-specific parameters and run 2d or 3d automatic segmentation."""
        if _validate_embeddings(self._viewer):
            return None

        if self.with_decoder:
            kwargs = {
                "center_distance_threshold": self.center_distance_thresh,
                "boundary_distance_threshold": self.boundary_distance_thresh,
                "min_size": self.min_object_size,
            }
        else:
            kwargs = {
                "pred_iou_thresh": self.pred_iou_thresh,
                "stability_score_thresh": self.stability_score_thresh,
                "box_nms_thresh": self.box_nms_thresh,
            }
        if self.volumetric and self.apply_to_volume:
            worker = self._run_segmentation_3d(kwargs)
        elif self.volumetric and not self.apply_to_volume:
            # Segment only the slice currently shown in the viewer.
            i = int(self._viewer.dims.point[0])
            worker = self._run_segmentation_2d(kwargs, i=i)
        else:
            worker = self._run_segmentation_2d(kwargs)
        _select_layer(self._viewer, "auto_segmentation")
        return worker
QWidget(parent: Optional[QWidget] = None, flags: Qt.WindowType = Qt.WindowFlags())
class AutoTrackWidget(AutoSegmentWidget):
    """Automatic segmentation widget for timeseries.

    With the "Track Timeseries" switch on, each frame is segmented and the frames are linked
    with `track_across_frames`. The resulting lineages are stored in the annotator state.
    """

    def _create_tracking_switch(self):
        """Create the "Track Timeseries" checkbox; it defaults to segmenting only the current frame."""
        self.apply_to_volume = False
        return self._add_boolean_param(
            "apply_to_volume", self.apply_to_volume, title="Track Timeseries",
            tooltip=get_tooltip("autotrack", "run_tracking")
        )

    def _create_widget(self):
        """Populate the layout with the tracking switch, settings and run button."""
        # Add the switch for segmenting the slice vs. tracking the timeseries.
        self.layout().addWidget(self._create_tracking_switch())

        # Add the nested settings widget.
        self.settings = self._create_settings()
        self.layout().addWidget(self.settings)

        # Add the run button.
        self.run_button = QtWidgets.QPushButton("Automatic Tracking")
        self.run_button.clicked.connect(self.__call__)
        self.run_button.setToolTip(get_tooltip("autotrack", "run_button"))
        self.layout().addWidget(self.run_button)

    def _run_segmentation_3d(self, kwargs):
        """Segment every frame and link the objects over time.

        Args:
            kwargs: Keyword arguments forwarded to `_instance_segmentation_impl` for each frame.

        Returns:
            The result of `_generate_message` if tracking cannot run, otherwise None.
        """
        allow_segment_3d = self._allow_segment_3d()
        if not allow_segment_3d:
            return _generate_message("error", "Tracking with AMG is only supported if you have a GPU.")

        state = AnnotatorState()
        # Refuse to overwrite lineages that were committed from interactive tracking.
        if len(state.committed_lineages) > 0:
            return _generate_message(
                "error",
                "Automatic tracking can only be called if you haven't committed results from interactive tracking yet."
            )
        pbar, pbar_signals = _create_pbar_for_threadworker()

        # @thread_worker
        def seg_impl():
            image_name = state.get_image_name(self._viewer)
            timeseries = self._viewer.layers[image_name].data
            segmentation = np.zeros_like(self._viewer.layers["auto_segmentation"].data)
            offset = 0

            def pbar_init(total, description):
                pbar_signals.pbar_total.emit(total)
                pbar_signals.pbar_description.emit(description)

            pbar_init(segmentation.shape[0], "Run tracking")

            # Further optimization: parallelize if state is precomputed for all slices
            for i in range(segmentation.shape[0]):
                seg = _instance_segmentation_impl(self.min_object_size, i=i, **kwargs)
                seg_max = seg.max()
                # Shift the ids of non-empty frames so they are unique across the timeseries.
                if seg_max != 0:
                    seg[seg != 0] += offset
                    offset = seg_max + offset
                    segmentation[i] = seg
                # Advance the progress for every frame; empty frames previously skipped this.
                pbar_signals.pbar_update.emit(1)

            pbar_signals.pbar_reset.emit()
            segmentation, lineages = track_across_frames(
                timeseries, segmentation,
                verbose=True, pbar_init=pbar_init,
                # Forward the reported increment, like the other progress callbacks in this file.
                pbar_update=lambda update: pbar_signals.pbar_update.emit(update),
            )
            pbar_signals.pbar_stop.emit()
            return (segmentation, lineages)

        def update_segmentation(result):
            segmentation, lineages = result
            is_empty = segmentation.max() == 0
            if is_empty:
                self._empty_segmentation_warning()

            state = AnnotatorState()
            state.lineage = lineages

            self._viewer.layers["auto_segmentation"].data = segmentation
            self._viewer.layers["auto_segmentation"].refresh()

        result = seg_impl()
        update_segmentation(result)
        # worker = seg_impl()
        # worker.returned.connect(update_segmentation)
        # worker.start()
        # return worker
QWidget(parent: Optional[QWidget] = None, flags: Qt.WindowType = Qt.WindowFlags())