micro_sam.sam_annotator.object_classifier
import os
from joblib import dump
from multiprocessing import cpu_count
from pathlib import Path
from typing import List, Optional, Tuple, Union

import imageio.v3 as imageio
import napari
import numpy as np
import torch

from magicgui import magic_factory, magicgui
from magicgui.widgets import Widget, Container, FunctionGui, create_widget
from qtpy import QtWidgets

from skimage.measure import regionprops_table
from sklearn.ensemble import RandomForestClassifier

from .. import util
from ..object_classification import compute_object_features, project_prediction_to_segmentation
from ._state import AnnotatorState
from . import _widgets as widgets
from .util import _sync_embedding_widget

#
# Utility functionality.
# Some of this could be refactored to general purpose functionality that can also
# be used for inference with the trained classifier.
#


def _accumulate_labels(segmentation, annotations):
    """Assign each segmented object the most frequent non-zero annotation value inside it.

    Args:
        segmentation: The object segmentation (label image).
        annotations: The user annotations, same shape as `segmentation`. Value 0 means "not annotated".

    Returns:
        An integer array with one entry per object reported by `regionprops_table`.
        The entry is 0 if the object does not overlap any annotation.
    """

    def majority_label(mask, annotation):
        ids, counts = np.unique(annotation[mask], return_counts=True)
        # The object is covered only by background: it is not annotated.
        if len(ids) == 1 and ids[0] == 0:
            return 0
        # Ignore the background so that partially annotated objects get the annotation id.
        if ids[0] == 0:
            ids, counts = ids[1:], counts[1:]
        return ids[np.argmax(counts)]

    all_features = regionprops_table(
        segmentation, intensity_image=annotations, properties=("label",),
        extra_properties=[majority_label],
    )
    return all_features["majority_label"].astype("int")


def _train_rf(features, labels, previous_features=None, previous_labels=None, **rf_kwargs):
    """Train a random forest on the annotated objects, optionally including earlier training data.

    Args:
        features: Per-object feature matrix (one row per object).
        labels: Per-object labels, aligned with `features`. Objects with label 0 are ignored.
        previous_features: Optional features accumulated from earlier images.
        previous_labels: Labels for `previous_features`; required if `previous_features` is given.
        rf_kwargs: Keyword arguments passed to `RandomForestClassifier`.

    Returns:
        The fitted random forest.

    Raises:
        ValueError: If features and labels do not have matching lengths.
    """
    # Use explicit checks rather than `assert`, which is stripped when running with `-O`.
    if len(features) != len(labels):
        raise ValueError(f"Got {len(features)} feature rows but {len(labels)} labels.")
    valid = labels != 0
    X, y = features[valid], labels[valid]

    if previous_features is not None:
        if previous_labels is None or len(previous_features) != len(previous_labels):
            raise ValueError(
                "previous_features and previous_labels must both be given and have the same length."
            )
        X = np.concatenate([previous_features, X], axis=0)
        y = np.concatenate([previous_labels, y], axis=0)

    rf = RandomForestClassifier(**rf_kwargs)
    rf.fit(X, y)
    return rf


# TODO do we add a shortcut?
@magic_factory(call_button="Train and predict")
def _train_and_predict_rf_widget(viewer: "napari.viewer.Viewer") -> None:
    """Train the random forest on the current annotations and show its prediction.

    Object features are computed from the image embeddings on first use and cached in the
    annotator state. Labels accumulated from earlier images (if any) are included in training.
    """
    # Get the object features and the annotations.
    state = AnnotatorState()
    state.annotator._require_layers()
    annotations = viewer.layers["annotations"].data
    segmentation = state.segmentation_selection.get_value().data

    if state.object_features is None:
        if widgets._validate_embeddings(viewer):
            return None
        image_embeddings = state.image_embeddings
        seg_ids, features = compute_object_features(image_embeddings, segmentation)
        state.seg_ids = seg_ids
        state.object_features = features
    else:
        features, seg_ids = state.object_features, state.seg_ids

    previous_features, previous_labels = state.previous_features, state.previous_labels
    labels = _accumulate_labels(segmentation, annotations)
    if (labels == 0).all() and (previous_labels is None):
        return widgets._generate_message("error", "You have not provided any annotations.")

    # Run RF training and store it in the state.
    rf = _train_rf(
        features, labels, previous_features=previous_features, previous_labels=previous_labels,
        n_estimators=200, max_depth=10, n_jobs=cpu_count(),
    )
    state.object_rf = rf

    # Run and set the prediction.
    pred = rf.predict(features)
    prediction_data = project_prediction_to_segmentation(segmentation, pred, seg_ids)
    viewer.layers["prediction"].data = prediction_data

    state.annotator._refresh_label_widget()


@magic_factory(call_button="Export Classifier")
def _create_export_rf_widget(export_path: Optional[Path] = None) -> None:
    """Export the trained random forest to `export_path` with joblib."""
    state = AnnotatorState()
    rf = state.object_rf
    if rf is None:
        return widgets._generate_message("error", "You have not run training yet.")
    # The path may arrive as a `Path`, and `Path("")` renders as ".", so compare the string form.
    if export_path is None or str(export_path) in ("", "."):
        return widgets._generate_message("error", "You have to provide an export path.")
    # Do we add an extension? .joblib?
    dump(rf, export_path)
    # TODO show an info message about the export

#
# Object classifier implementation.
#


# TODO add a gui element that shows the current label ids, how many objects are labeled, and that
# enables naming them so that the user can keep track of what has been labeled
class ObjectClassifier(QtWidgets.QScrollArea):
    """Napari widget for training a random forest object classifier.

    It bundles the embedding widget, the segmentation layer selection, the train-and-predict widget,
    a form for naming the label ids and the classifier export widget. It uses the "annotations" and
    "prediction" label layers of the viewer.
    """

    def _require_layers(self, layer_choices: Optional[List[str]] = None):
        """Add the "annotations" and "prediction" label layers if they are missing from the viewer."""
        # Check whether the image is initialized already. And use the image shape and scale for the layers.
        state = AnnotatorState()
        shape = self._shape if state.image_shape is None else state.image_shape

        # Add the label layers for the annotations and the classifier prediction.
        dummy_data = np.zeros(shape, dtype="uint32")
        image_scale = state.image_scale

        # Before adding new layers, we always check whether a layer with this name already exists or not.
        if "annotations" not in self._viewer.layers:
            if layer_choices and "annotations" in layer_choices:
                widgets._validation_window_for_missing_layer("annotations")
            annotation_layer = self._viewer.add_labels(data=dummy_data, name="annotations")
            if image_scale is not None:
                self._viewer.layers["annotations"].scale = image_scale
            # Reduce the brush size and set the default mode to "paint" brush mode.
            annotation_layer.brush_size = 3
            annotation_layer.mode = "paint"

        if "prediction" not in self._viewer.layers:
            if layer_choices and "prediction" in layer_choices:
                widgets._validation_window_for_missing_layer("prediction")
            self._viewer.add_labels(data=dummy_data, name="prediction")
            if image_scale is not None:
                self._viewer.layers["prediction"].scale = image_scale

    def _create_segmentation_layer_section(self):
        """Create the layout with the selection box for the segmentation (labels) layer."""
        segmentation_selection = QtWidgets.QVBoxLayout()
        segmentation_layer_widget = QtWidgets.QLabel("Segmentation:")
        segmentation_selection.addWidget(segmentation_layer_widget)
        self.segmentation_selection = create_widget(annotation=napari.layers.Labels)
        state = AnnotatorState()
        state.segmentation_selection = self.segmentation_selection
        segmentation_selection.addWidget(self.segmentation_selection.native)
        return segmentation_selection

    def _create_label_widget(self):
        """Create the scrollable form that holds one name field per label id."""
        self._label_form = QtWidgets.QFormLayout()
        scroll_area = QtWidgets.QScrollArea()
        inner = QtWidgets.QWidget()
        inner.setLayout(self._label_form)
        scroll_area.setWidget(inner)
        scroll_area.setWidgetResizable(True)

        layout = QtWidgets.QVBoxLayout()
        layout.addWidget(QtWidgets.QLabel("Object label names:"))
        layout.addWidget(scroll_area)

        return layout

    def _refresh_label_widget(self):
        """Synchronize the rows of the label-name form with the label ids currently in use.

        The ids in use are the non-zero values of the "annotations" layer plus the ids in
        `state.previous_labels`. A row with an editable name is added for each new id. Rows whose
        id is no longer in use are removed together with their stored name.
        """
        state = AnnotatorState()

        # Get the current label ids, ignoring the background value 0.
        ids = np.unique(self._viewer.layers["annotations"].data)
        ids = ids[ids != 0]
        if state.previous_labels is not None:
            ids = np.union1d(ids, np.unique(state.previous_labels))
        current_ids = {int(lbl) for lbl in ids}

        # Add new rows.
        for lbl in sorted(current_ids):
            if lbl in self._label_names:
                continue
            line = QtWidgets.QLineEdit("")
            self._label_names[lbl] = ""
            self._label_form.addRow(f"ID {lbl}", line)
            # Bind lbl as a default argument so that each callback updates its own id.
            line.textChanged.connect(lambda txt, lbl=lbl: self._label_names.__setitem__(lbl, txt))

        # Remove rows whose label vanished. Iterate in reverse so that removing a row
        # does not shift the indices of rows that are still to be checked.
        for row in reversed(range(self._label_form.rowCount())):
            label_item = self._label_form.itemAt(row, QtWidgets.QFormLayout.LabelRole)
            lbl_id = int(label_item.widget().text().split()[1])
            if lbl_id not in current_ids:
                # removeRow deletes the row's label and field widgets itself,
                # so they must not be deleted a second time.
                self._label_form.removeRow(row)
                self._label_names.pop(lbl_id, None)

    def _create_widgets(self):
        """Create all sub-widgets and store them in `self._widgets`, keyed by name."""
        # Create the embedding widget and connect all events related to it.
        self._embedding_widget = widgets.EmbeddingWidget()
        # Connect events for the image selection box.
        self._viewer.layers.events.inserted.connect(self._embedding_widget.image_selection.reset_choices)
        self._viewer.layers.events.removed.connect(self._embedding_widget.image_selection.reset_choices)
        # Connect the run button with the function to update the image.
        self._embedding_widget.run_button.clicked.connect(self._update_image)

        # Create the widget for training and prediction of the classifier.
        self._train_and_predict_widget = _train_and_predict_rf_widget()

        # Create the widget for segmentation selection.
        self._seg_selection_widget = self._create_segmentation_layer_section()

        # Create the widget for displaying the current label state.
        self._label_widget = self._create_label_widget()

        # Create the widget for exporting the RF.
        self._export_rf_widget = _create_export_rf_widget()

        self._widgets = {
            "embeddings": self._embedding_widget,
            "segmentation_selection": self._seg_selection_widget,
            "train_and_predict": self._train_and_predict_widget,
            "label_widget": self._label_widget,
            "export_rf": self._export_rf_widget,
        }

    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
        """Create the GUI for the object classifier.

        Args:
            viewer: The napari viewer.
        """
        super().__init__()
        self._viewer = viewer
        self._annotator_widget = QtWidgets.QWidget()
        self._annotator_widget.setLayout(QtWidgets.QVBoxLayout())

        # Add the layers for the annotations and the prediction.
        # Initialize with a dummy shape, which is reset to the correct shape once an image is set.
        self._shape = (256, 256)
        self._require_layers()
        self._ndim = len(self._shape)

        # Create all the widgets and add them to the layout.
        self._label_names = {}  # The names for the object labels.
        self._create_widgets()

        # We could refactor this.
        for widget in self._widgets.values():
            widget_frame = QtWidgets.QGroupBox()
            widget_layout = QtWidgets.QVBoxLayout()
            if isinstance(widget, (Container, FunctionGui, Widget)):
                # This is a magicgui type and we need to get the native qt widget.
                widget_layout.addWidget(widget.native)
            elif isinstance(widget, QtWidgets.QLayout):
                widget_layout.addLayout(widget)
            else:
                # This is a qt type and we add the widget directly.
                widget_layout.addWidget(widget)
            widget_frame.setLayout(widget_layout)
            self._annotator_widget.layout().addWidget(widget_frame)

        # Populate the label-name form with the label ids present at startup.
        self._refresh_label_widget()

        # Set the expected annotator class to the state.
        state = AnnotatorState()
        state.annotator = self

        # Add the widgets to the state.
        state.widgets = self._widgets

        # Add the widget to the scroll area.
        self.setWidgetResizable(True)  # Allow widget to resize within scroll area.
        self.setWidget(self._annotator_widget)

    def _update_image(self, segmentation_result=None):
        """Reset the annotation and prediction layers to the shape and scale of the current image.

        `segmentation_result` is not used by this method.
        """
        state = AnnotatorState()

        # Whether embeddings already exist and avoid clearing objects in layers.
        if state.skip_recomputing_embeddings:
            return

        if state.image_shape is None:
            return

        # Update the dimension and image shape if it has changed.
        if state.image_shape != self._shape:
            self._ndim = len(state.image_shape)
            self._shape = state.image_shape

        # Before we reset the layers, we ensure all expected layers exist.
        self._require_layers()

        # Update the image scale.
        scale = state.image_scale

        # Reset all layers.
        self._viewer.layers["annotations"].data = np.zeros(self._shape, dtype="uint32")
        self._viewer.layers["annotations"].scale = scale
        self._viewer.layers["prediction"].data = np.zeros(self._shape, dtype="uint32")
        self._viewer.layers["prediction"].scale = scale


def _derive_ndim(image: np.ndarray, ndim: Optional[int] = None) -> int:
    """Return `ndim` if given, otherwise the number of spatial dimensions of `image`.

    An image with 3 or 4 axes whose last axis has size 3 is treated as having a trailing RGB axis.
    """
    if ndim is not None:
        return ndim
    return image.ndim - 1 if image.shape[-1] == 3 and image.ndim in (3, 4) else image.ndim


def object_classifier(
    image: np.ndarray,
    segmentation: np.ndarray,
    embedding_path: Optional[Union[str, util.ImageEmbeddings]] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    return_viewer: bool = False,
    viewer: Optional["napari.viewer.Viewer"] = None,
    checkpoint_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
    ndim: Optional[int] = None,
) -> Optional["napari.viewer.Viewer"]:
    """Start the object classifier for a given image and segmentation.

    Args:
        image: The image data.
        segmentation: The segmentation data.
        embedding_path: Filepath where to save the embeddings
            or the precomputed image embeddings computed by `precompute_image_embeddings`.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
            By default, does not return the napari viewer.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        device: The computational device to use for the SAM model.
            By default, automatically chooses the best available device.
        ndim: The dimensionality of the data. If not given will be derived from the data.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
    """
    ndim = _derive_ndim(image, ndim)

    state = AnnotatorState()
    state.image_shape = image.shape[:ndim]

    state.initialize_predictor(
        image, model_type=model_type, save_path=embedding_path,
        halo=halo, tile_shape=tile_shape, precompute_amg_state=False,
        ndim=ndim, checkpoint_path=checkpoint_path, device=device,
        skip_load=False, use_cli=True,
    )

    if viewer is None:
        viewer = napari.Viewer()

    viewer.add_image(image, name="image")
    viewer.add_labels(segmentation, name="segmentation")

    annotator = ObjectClassifier(viewer)

    # Trigger layer update of the annotator so that layers have the correct shape.
    annotator._update_image()

    # Add the annotator widget to the viewer and sync widgets.
    viewer.window.add_dock_widget(annotator)
    _sync_embedding_widget(
        widget=state.widgets["embeddings"],
        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
        save_path=embedding_path,
        checkpoint_path=checkpoint_path,
        device=device,
        tile_shape=tile_shape,
        halo=halo,
    )

    if return_viewer:
        return viewer

    napari.run()


def image_series_object_classifier(
    images: List[np.ndarray],
    segmentations: List[np.ndarray],
    output_folder: str,
    embedding_paths: Optional[List[Union[str, util.ImageEmbeddings]]] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    checkpoint_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
    ndim: Optional[int] = None,
) -> None:
    """Start the object classifier for a list of images and segmentations.

    This function saves the features and labels of all annotated objects,
    to enable training a random forest on multiple images.

    Args:
        images: The input images.
        segmentations: The input segmentations.
        output_folder: The folder where segmentation results, trained random forest
            and the features, labels aggregated during training will be saved.
        embedding_paths: Filepaths where to save the embeddings
            or the precomputed image embeddings computed by `precompute_image_embeddings`.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        device: The computational device to use for the SAM model.
            By default, automatically chooses the best available device.
        ndim: The dimensionality of the data. If not given will be derived from the data.
    """
    # TODO precompute the embeddings if not computed, can re-use 'precompute' from image series annotator.
    # TODO support file paths as inputs
    # TODO option to skip segmented
    if len(images) != len(segmentations):
        raise ValueError(
            f"Expect the same number of images and segmentations, got {len(images)}, {len(segmentations)}."
        )

    end_msg = "You have annotated the last image. Do you wish to close napari?"

    # Initialize the object classifier on the first image / segmentation.
    viewer = object_classifier(
        image=images[0], segmentation=segmentations[0],
        embedding_path=None if embedding_paths is None else embedding_paths[0],
        model_type=model_type, tile_shape=tile_shape, halo=halo,
        return_viewer=True, checkpoint_path=checkpoint_path,
        device=device, ndim=ndim,
    )

    os.makedirs(output_folder, exist_ok=True)
    next_image_id = 0

    def _save_prediction(image, pred, image_id):
        fname = f"{Path(image).stem}_prediction.tif" if isinstance(image, str) else f"prediction_{image_id}.tif"
        save_path = os.path.join(output_folder, fname)
        imageio.imwrite(save_path, pred, compression="zlib")

    # TODO enable continuing to the next image without training on the current one.
    # Add functionality for going to the next image.
    @magicgui(call_button="Next Image [N]")
    def next_image(*args):
        nonlocal next_image_id

        # All images were already processed (e.g. "N" pressed again after the last image).
        if next_image_id >= len(images):
            return

        # Get the state and the current segmentation (note that next image id has not yet been increased)
        state = AnnotatorState()
        segmentation = segmentations[next_image_id]

        # Keep track of the previous features and labels.
        labels = _accumulate_labels(segmentation, viewer.layers["annotations"].data)
        valid = labels != 0
        if valid.sum() > 0:
            # Features are only computed by training; without them the annotations cannot be stored.
            if state.object_features is None:
                return widgets._generate_message(
                    "error", "You have to run 'Train and predict' before going to the next image."
                )
            features, labels = state.object_features[valid], labels[valid]
            if state.previous_features is None:
                state.previous_features, state.previous_labels = features, labels
            else:
                state.previous_features = np.concatenate([state.previous_features, features], axis=0)
                state.previous_labels = np.concatenate([state.previous_labels, labels], axis=0)
            # Save the accumulated features and labels.
            np.save(os.path.join(output_folder, "features.npy"), state.previous_features)
            np.save(os.path.join(output_folder, "labels.npy"), state.previous_labels)

        # Save the current prediction and RF.
        _save_prediction(images[next_image_id], viewer.layers["prediction"].data, next_image_id)
        # Do not write `None` to rf.joblib when no classifier has been trained yet.
        if state.object_rf is not None:
            dump(state.object_rf, os.path.join(output_folder, "rf.joblib"))

        # Go to the next image.
        next_image_id += 1

        # Check if we are done.
        if next_image_id == len(images):
            # Inform the user via dialog.
            abort = widgets._generate_message("info", end_msg)
            if not abort:
                viewer.close()
            return

        # Get the next image, segmentation and embedding_path.
        image = images[next_image_id]
        segmentation = segmentations[next_image_id]
        embedding_path = None if embedding_paths is None else embedding_paths[next_image_id]
        # Derive the dimensionality as in `object_classifier`, so that a grayscale image keeps all of its axes.
        image_ndim = _derive_ndim(image, ndim)

        # Set the new image in the viewer, state and annotator.
        viewer.layers["image"].data = image
        viewer.layers["segmentation"].data = segmentation

        state.initialize_predictor(
            image, model_type=model_type, ndim=image_ndim,
            save_path=embedding_path,
            tile_shape=tile_shape, halo=halo,
            predictor=state.predictor, device=device,
        )
        state.image_shape = image.shape[:image_ndim]
        state.annotator._update_image()

        # Clear the object features and seg-ids from the state.
        state.object_features = None
        state.seg_ids = None

        # The annotations layer was reset; keep the label-name form in sync with the accumulated labels.
        state.annotator._refresh_label_widget()

    viewer.window.add_dock_widget(next_image)

    @viewer.bind_key("n", overwrite=True)
    def _next_image(viewer):
        next_image(viewer)

    napari.run()


# TODO: folder annotator
# TODO: main function
class
ObjectClassifier(PyQt5.QtWidgets.QScrollArea):
class ObjectClassifier(QtWidgets.QScrollArea):
    """Napari widget for training a random forest object classifier.

    It bundles the embedding widget, the segmentation layer selection, the train-and-predict widget,
    a form for naming the label ids and the classifier export widget. It uses the "annotations" and
    "prediction" label layers of the viewer.
    """

    def _require_layers(self, layer_choices: Optional[List[str]] = None):
        """Add the "annotations" and "prediction" label layers if they are missing from the viewer."""
        # Check whether the image is initialized already. And use the image shape and scale for the layers.
        state = AnnotatorState()
        shape = self._shape if state.image_shape is None else state.image_shape

        # Add the label layers for the annotations and the classifier prediction.
        dummy_data = np.zeros(shape, dtype="uint32")
        image_scale = state.image_scale

        # Before adding new layers, we always check whether a layer with this name already exists or not.
        if "annotations" not in self._viewer.layers:
            if layer_choices and "annotations" in layer_choices:
                widgets._validation_window_for_missing_layer("annotations")
            annotation_layer = self._viewer.add_labels(data=dummy_data, name="annotations")
            if image_scale is not None:
                self._viewer.layers["annotations"].scale = image_scale
            # Reduce the brush size and set the default mode to "paint" brush mode.
            annotation_layer.brush_size = 3
            annotation_layer.mode = "paint"

        if "prediction" not in self._viewer.layers:
            if layer_choices and "prediction" in layer_choices:
                widgets._validation_window_for_missing_layer("prediction")
            self._viewer.add_labels(data=dummy_data, name="prediction")
            if image_scale is not None:
                self._viewer.layers["prediction"].scale = image_scale

    def _create_segmentation_layer_section(self):
        """Create the layout with the selection box for the segmentation (labels) layer."""
        segmentation_selection = QtWidgets.QVBoxLayout()
        segmentation_layer_widget = QtWidgets.QLabel("Segmentation:")
        segmentation_selection.addWidget(segmentation_layer_widget)
        self.segmentation_selection = create_widget(annotation=napari.layers.Labels)
        state = AnnotatorState()
        state.segmentation_selection = self.segmentation_selection
        segmentation_selection.addWidget(self.segmentation_selection.native)
        return segmentation_selection

    def _create_label_widget(self):
        """Create the scrollable form that holds one name field per label id."""
        self._label_form = QtWidgets.QFormLayout()
        scroll_area = QtWidgets.QScrollArea()
        inner = QtWidgets.QWidget()
        inner.setLayout(self._label_form)
        scroll_area.setWidget(inner)
        scroll_area.setWidgetResizable(True)

        layout = QtWidgets.QVBoxLayout()
        layout.addWidget(QtWidgets.QLabel("Object label names:"))
        layout.addWidget(scroll_area)

        return layout

    def _refresh_label_widget(self):
        """Synchronize the rows of the label-name form with the label ids currently in use.

        The ids in use are the non-zero values of the "annotations" layer plus the ids in
        `state.previous_labels`. A row with an editable name is added for each new id. Rows whose
        id is no longer in use are removed together with their stored name.
        """
        state = AnnotatorState()

        # Get the current label ids, ignoring the background value 0.
        ids = np.unique(self._viewer.layers["annotations"].data)
        ids = ids[ids != 0]
        if state.previous_labels is not None:
            ids = np.union1d(ids, np.unique(state.previous_labels))
        current_ids = {int(lbl) for lbl in ids}

        # Add new rows.
        for lbl in sorted(current_ids):
            if lbl in self._label_names:
                continue
            line = QtWidgets.QLineEdit("")
            self._label_names[lbl] = ""
            self._label_form.addRow(f"ID {lbl}", line)
            # Bind lbl as a default argument so that each callback updates its own id.
            line.textChanged.connect(lambda txt, lbl=lbl: self._label_names.__setitem__(lbl, txt))

        # Remove rows whose label vanished. Iterate in reverse so that removing a row
        # does not shift the indices of rows that are still to be checked.
        for row in reversed(range(self._label_form.rowCount())):
            label_item = self._label_form.itemAt(row, QtWidgets.QFormLayout.LabelRole)
            lbl_id = int(label_item.widget().text().split()[1])
            if lbl_id not in current_ids:
                # removeRow deletes the row's label and field widgets itself,
                # so they must not be deleted a second time.
                self._label_form.removeRow(row)
                self._label_names.pop(lbl_id, None)

    def _create_widgets(self):
        """Create all sub-widgets and store them in `self._widgets`, keyed by name."""
        # Create the embedding widget and connect all events related to it.
        self._embedding_widget = widgets.EmbeddingWidget()
        # Connect events for the image selection box.
        self._viewer.layers.events.inserted.connect(self._embedding_widget.image_selection.reset_choices)
        self._viewer.layers.events.removed.connect(self._embedding_widget.image_selection.reset_choices)
        # Connect the run button with the function to update the image.
        self._embedding_widget.run_button.clicked.connect(self._update_image)

        # Create the widget for training and prediction of the classifier.
        self._train_and_predict_widget = _train_and_predict_rf_widget()

        # Create the widget for segmentation selection.
        self._seg_selection_widget = self._create_segmentation_layer_section()

        # Create the widget for displaying the current label state.
        self._label_widget = self._create_label_widget()

        # Create the widget for exporting the RF.
        self._export_rf_widget = _create_export_rf_widget()

        self._widgets = {
            "embeddings": self._embedding_widget,
            "segmentation_selection": self._seg_selection_widget,
            "train_and_predict": self._train_and_predict_widget,
            "label_widget": self._label_widget,
            "export_rf": self._export_rf_widget,
        }

    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
        """Create the GUI for the object classifier.

        Args:
            viewer: The napari viewer.
        """
        super().__init__()
        self._viewer = viewer
        self._annotator_widget = QtWidgets.QWidget()
        self._annotator_widget.setLayout(QtWidgets.QVBoxLayout())

        # Add the layers for the annotations and the prediction.
        # Initialize with a dummy shape, which is reset to the correct shape once an image is set.
        self._shape = (256, 256)
        self._require_layers()
        self._ndim = len(self._shape)

        # Create all the widgets and add them to the layout.
        self._label_names = {}  # The names for the object labels.
        self._create_widgets()

        # We could refactor this.
        for widget in self._widgets.values():
            widget_frame = QtWidgets.QGroupBox()
            widget_layout = QtWidgets.QVBoxLayout()
            if isinstance(widget, (Container, FunctionGui, Widget)):
                # This is a magicgui type and we need to get the native qt widget.
                widget_layout.addWidget(widget.native)
            elif isinstance(widget, QtWidgets.QLayout):
                widget_layout.addLayout(widget)
            else:
                # This is a qt type and we add the widget directly.
                widget_layout.addWidget(widget)
            widget_frame.setLayout(widget_layout)
            self._annotator_widget.layout().addWidget(widget_frame)

        # Populate the label-name form with the label ids present at startup.
        self._refresh_label_widget()

        # Set the expected annotator class to the state.
        state = AnnotatorState()
        state.annotator = self

        # Add the widgets to the state.
        state.widgets = self._widgets

        # Add the widget to the scroll area.
        self.setWidgetResizable(True)  # Allow widget to resize within scroll area.
        self.setWidget(self._annotator_widget)

    def _update_image(self, segmentation_result=None):
        """Reset the annotation and prediction layers to the shape and scale of the current image.

        `segmentation_result` is not used by this method.
        """
        state = AnnotatorState()

        # Whether embeddings already exist and avoid clearing objects in layers.
        if state.skip_recomputing_embeddings:
            return

        if state.image_shape is None:
            return

        # Update the dimension and image shape if it has changed.
        if state.image_shape != self._shape:
            self._ndim = len(state.image_shape)
            self._shape = state.image_shape

        # Before we reset the layers, we ensure all expected layers exist.
        self._require_layers()

        # Update the image scale.
        scale = state.image_scale

        # Reset all layers.
        self._viewer.layers["annotations"].data = np.zeros(self._shape, dtype="uint32")
        self._viewer.layers["annotations"].scale = scale
        self._viewer.layers["prediction"].data = np.zeros(self._shape, dtype="uint32")
        self._viewer.layers["prediction"].scale = scale
QScrollArea(parent: Optional[QWidget] = None)
ObjectClassifier(viewer: napari.viewer.Viewer)
def __init__(self, viewer: "napari.viewer.Viewer") -> None:
    """Create the GUI for the object classifier.

    Args:
        viewer: The napari viewer.
    """
    super().__init__()
    self._viewer = viewer
    self._annotator_widget = QtWidgets.QWidget()
    self._annotator_widget.setLayout(QtWidgets.QVBoxLayout())

    # Add the layers for the annotations and the prediction.
    # Initialize with a dummy shape, which is reset to the correct shape once an image is set.
    self._shape = (256, 256)
    self._require_layers()
    self._ndim = len(self._shape)

    # Create all the widgets and add them to the layout.
    self._label_names = {}  # The names for the object labels.
    self._create_widgets()

    # We could refactor this.
    for widget in self._widgets.values():
        widget_frame = QtWidgets.QGroupBox()
        widget_layout = QtWidgets.QVBoxLayout()
        if isinstance(widget, (Container, FunctionGui, Widget)):
            # This is a magicgui type and we need to get the native qt widget.
            widget_layout.addWidget(widget.native)
        elif isinstance(widget, QtWidgets.QLayout):
            widget_layout.addLayout(widget)
        else:
            # This is a qt type and we add the widget directly.
            widget_layout.addWidget(widget)
        widget_frame.setLayout(widget_layout)
        self._annotator_widget.layout().addWidget(widget_frame)

    # Populate the label-name form with the label ids present at startup.
    self._refresh_label_widget()

    # Set the expected annotator class to the state.
    state = AnnotatorState()
    state.annotator = self

    # Add the widgets to the state.
    state.widgets = self._widgets

    # Add the widget to the scroll area.
    self.setWidgetResizable(True)  # Allow widget to resize within scroll area.
    self.setWidget(self._annotator_widget)
Create the GUI for the object classifier.
Arguments:
- viewer: The napari viewer.
Inherited Members
- PyQt5.QtWidgets.QScrollArea
- alignment
- ensureVisible
- ensureWidgetVisible
- event
- eventFilter
- focusNextPrevChild
- resizeEvent
- scrollContentsBy
- setAlignment
- setWidget
- setWidgetResizable
- sizeHint
- takeWidget
- viewportSizeHint
- widget
- widgetResizable
- PyQt5.QtWidgets.QAbstractScrollArea
- SizeAdjustPolicy
- addScrollBarWidget
- contextMenuEvent
- cornerWidget
- dragEnterEvent
- dragLeaveEvent
- dragMoveEvent
- dropEvent
- horizontalScrollBar
- horizontalScrollBarPolicy
- keyPressEvent
- maximumViewportSize
- minimumSizeHint
- mouseDoubleClickEvent
- mouseMoveEvent
- mousePressEvent
- mouseReleaseEvent
- paintEvent
- scrollBarWidgets
- setCornerWidget
- setHorizontalScrollBar
- setHorizontalScrollBarPolicy
- setSizeAdjustPolicy
- setVerticalScrollBar
- setVerticalScrollBarPolicy
- setViewport
- setViewportMargins
- setupViewport
- sizeAdjustPolicy
- verticalScrollBar
- verticalScrollBarPolicy
- viewport
- viewportEvent
- viewportMargins
- wheelEvent
- AdjustIgnored
- AdjustToContents
- AdjustToContentsOnFirstShow
- PyQt5.QtWidgets.QFrame
- Shadow
- Shape
- StyleMask
- changeEvent
- drawFrame
- frameRect
- frameShadow
- frameShape
- frameStyle
- frameWidth
- initStyleOption
- lineWidth
- midLineWidth
- setFrameRect
- setFrameShadow
- setFrameShape
- setFrameStyle
- setLineWidth
- setMidLineWidth
- Box
- HLine
- NoFrame
- Panel
- Plain
- Raised
- Shadow_Mask
- Shape_Mask
- StyledPanel
- Sunken
- VLine
- WinPanel
- PyQt5.QtWidgets.QWidget
- RenderFlag
- RenderFlags
- acceptDrops
- accessibleDescription
- accessibleName
- actionEvent
- actions
- activateWindow
- addAction
- addActions
- adjustSize
- autoFillBackground
- backgroundRole
- baseSize
- childAt
- childrenRect
- childrenRegion
- clearFocus
- clearMask
- close
- closeEvent
- contentsMargins
- contentsRect
- contextMenuPolicy
- create
- createWindowContainer
- cursor
- destroy
- devType
- effectiveWinId
- ensurePolished
- enterEvent
- find
- focusInEvent
- focusNextChild
- focusOutEvent
- focusPolicy
- focusPreviousChild
- focusProxy
- focusWidget
- font
- fontInfo
- fontMetrics
- foregroundRole
- frameGeometry
- frameSize
- geometry
- getContentsMargins
- grab
- grabGesture
- grabKeyboard
- grabMouse
- grabShortcut
- graphicsEffect
- graphicsProxyWidget
- hasFocus
- hasHeightForWidth
- hasMouseTracking
- hasTabletTracking
- height
- heightForWidth
- hide
- hideEvent
- initPainter
- inputMethodEvent
- inputMethodHints
- inputMethodQuery
- insertAction
- insertActions
- isActiveWindow
- isAncestorOf
- isEnabled
- isEnabledTo
- isFullScreen
- isHidden
- isLeftToRight
- isMaximized
- isMinimized
- isModal
- isRightToLeft
- isVisible
- isVisibleTo
- isWindow
- isWindowModified
- keyReleaseEvent
- keyboardGrabber
- layout
- layoutDirection
- leaveEvent
- locale
- lower
- mapFrom
- mapFromGlobal
- mapFromParent
- mapTo
- mapToGlobal
- mapToParent
- mask
- maximumHeight
- maximumSize
- maximumWidth
- metric
- minimumHeight
- minimumSize
- minimumWidth
- mouseGrabber
- move
- moveEvent
- nativeEvent
- nativeParentWidget
- nextInFocusChain
- normalGeometry
- overrideWindowFlags
- overrideWindowState
- paintEngine
- palette
- parentWidget
- pos
- previousInFocusChain
- raise_
- rect
- releaseKeyboard
- releaseMouse
- releaseShortcut
- removeAction
- render
- repaint
- resize
- restoreGeometry
- saveGeometry
- screen
- scroll
- setAcceptDrops
- setAccessibleDescription
- setAccessibleName
- setAttribute
- setAutoFillBackground
- setBackgroundRole
- setBaseSize
- setContentsMargins
- setContextMenuPolicy
- setCursor
- setDisabled
- setEnabled
- setFixedHeight
- setFixedSize
- setFixedWidth
- setFocus
- setFocusPolicy
- setFocusProxy
- setFont
- setForegroundRole
- setGeometry
- setGraphicsEffect
- setHidden
- setInputMethodHints
- setLayout
- setLayoutDirection
- setLocale
- setMask
- setMaximumHeight
- setMaximumSize
- setMaximumWidth
- setMinimumHeight
- setMinimumSize
- setMinimumWidth
- setMouseTracking
- setPalette
- setParent
- setShortcutAutoRepeat
- setShortcutEnabled
- setSizeIncrement
- setSizePolicy
- setStatusTip
- setStyle
- setStyleSheet
- setTabOrder
- setTabletTracking
- setToolTip
- setToolTipDuration
- setUpdatesEnabled
- setVisible
- setWhatsThis
- setWindowFilePath
- setWindowFlag
- setWindowFlags
- setWindowIcon
- setWindowIconText
- setWindowModality
- setWindowModified
- setWindowOpacity
- setWindowRole
- setWindowState
- setWindowTitle
- show
- showEvent
- showFullScreen
- showMaximized
- showMinimized
- showNormal
- size
- sizeIncrement
- sizePolicy
- stackUnder
- statusTip
- style
- styleSheet
- tabletEvent
- testAttribute
- toolTip
- toolTipDuration
- underMouse
- ungrabGesture
- unsetCursor
- unsetLayoutDirection
- unsetLocale
- update
- updateGeometry
- updateMicroFocus
- updatesEnabled
- visibleRegion
- whatsThis
- width
- winId
- window
- windowFilePath
- windowFlags
- windowHandle
- windowIcon
- windowIconText
- windowModality
- windowOpacity
- windowRole
- windowState
- windowTitle
- windowType
- x
- y
- DrawChildren
- DrawWindowBackground
- IgnoreMask
- windowIconTextChanged
- windowIconChanged
- windowTitleChanged
- customContextMenuRequested
- PyQt5.QtCore.QObject
- blockSignals
- childEvent
- children
- connectNotify
- customEvent
- deleteLater
- disconnect
- disconnectNotify
- dumpObjectInfo
- dumpObjectTree
- dynamicPropertyNames
- findChild
- findChildren
- inherits
- installEventFilter
- isSignalConnected
- isWidgetType
- isWindowType
- killTimer
- metaObject
- moveToThread
- objectName
- parent
- property
- pyqtConfigure
- receivers
- removeEventFilter
- sender
- senderSignalIndex
- setObjectName
- setProperty
- signalsBlocked
- startTimer
- thread
- timerEvent
- tr
- staticMetaObject
- objectNameChanged
- destroyed
- PyQt5.QtGui.QPaintDevice
- PaintDeviceMetric
- colorCount
- depth
- devicePixelRatio
- devicePixelRatioF
- devicePixelRatioFScale
- heightMM
- logicalDpiX
- logicalDpiY
- paintingActive
- physicalDpiX
- physicalDpiY
- widthMM
- PdmDepth
- PdmDevicePixelRatio
- PdmDevicePixelRatioScaled
- PdmDpiX
- PdmDpiY
- PdmHeight
- PdmHeightMM
- PdmNumColors
- PdmPhysicalDpiX
- PdmPhysicalDpiY
- PdmWidth
- PdmWidthMM
def
object_classifier( image: numpy.ndarray, segmentation: numpy.ndarray, embedding_path: Union[str, Dict[str, Any], NoneType] = None, model_type: str = 'vit_b_lm', tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, return_viewer: bool = False, viewer: Optional[napari.viewer.Viewer] = None, checkpoint_path: Optional[str] = None, device: Union[str, torch.device, NoneType] = None, ndim: Optional[int] = None) -> Optional[napari.viewer.Viewer]:
def object_classifier(
    image: np.ndarray,
    segmentation: np.ndarray,
    embedding_path: Optional[Union[str, util.ImageEmbeddings]] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    return_viewer: bool = False,
    viewer: Optional["napari.viewer.Viewer"] = None,
    checkpoint_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
    ndim: Optional[int] = None,
) -> Optional["napari.viewer.Viewer"]:
    """Start the object classifier for a given image and segmentation.

    Args:
        image: The image data.
        segmentation: The segmentation data.
        embedding_path: Filepath where to save the embeddings
            or the precomputed image embeddings computed by `precompute_image_embeddings`.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
            By default, does not return the napari viewer.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        device: The computational device to use for the SAM model.
            By default, automatically chooses the best available device.
        ndim: The dimensionality of the data. If not given will be derived from the data:
            a trailing axis of size 3 on a 3d or 4d image is treated as RGB channels.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
    """
    # Derive the spatial dimensionality; a trailing axis of size 3 on a 3d / 4d array is taken as RGB.
    if ndim is None:
        ndim = image.ndim - 1 if image.shape[-1] == 3 and image.ndim in (3, 4) else image.ndim

    state = AnnotatorState()
    state.image_shape = image.shape[:ndim]

    state.initialize_predictor(
        image, model_type=model_type, save_path=embedding_path,
        halo=halo, tile_shape=tile_shape, precompute_amg_state=False,
        ndim=ndim, checkpoint_path=checkpoint_path, device=device,
        skip_load=False, use_cli=True,
    )

    if viewer is None:
        viewer = napari.Viewer()

    viewer.add_image(image, name="image")
    viewer.add_labels(segmentation, name="segmentation")

    annotator = ObjectClassifier(viewer)

    # Trigger a layer update of the annotator so that its layers get the shape of the current image.
    annotator._update_image()

    # Add the annotator widget to the viewer and sync the embedding widget with the settings used here.
    viewer.window.add_dock_widget(annotator)
    _sync_embedding_widget(
        widget=state.widgets["embeddings"],
        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
        save_path=embedding_path,
        checkpoint_path=checkpoint_path,
        device=device,
        tile_shape=tile_shape,
        halo=halo,
    )

    if return_viewer:
        return viewer

    napari.run()
Start the object classifier for a given image and segmentation.
Arguments:
- image: The image data.
- segmentation: The segmentation data.
- embedding_path: Filepath where to save the embeddings
or the precomputed image embeddings computed by
precompute_image_embeddings.
- model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
- tile_shape: Shape of tiles for tiled embedding prediction.
If None then the whole image is passed to Segment Anything.
- halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
- return_viewer: Whether to return the napari viewer to further modify it before starting the tool. By default, does not return the napari viewer.
- viewer: The viewer to which the Segment Anything functionality should be added. This enables using a pre-initialized viewer.
- checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
- device: The computational device to use for the SAM model. By default, automatically chooses the best available device.
- ndim: The dimensionality of the data. If not given will be derived from the data.
Returns:
The napari viewer, only returned if return_viewer=True.
def
image_series_object_classifier( images: List[numpy.ndarray], segmentations: List[numpy.ndarray], output_folder: str, embedding_paths: Optional[List[Union[str, Dict[str, Any]]]] = None, model_type: str = 'vit_b_lm', tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, checkpoint_path: Optional[str] = None, device: Union[str, torch.device, NoneType] = None, ndim: Optional[int] = None) -> None:
def image_series_object_classifier(
    images: List[np.ndarray],
    segmentations: List[np.ndarray],
    output_folder: str,
    embedding_paths: Optional[List[Union[str, util.ImageEmbeddings]]] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    checkpoint_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
    ndim: Optional[int] = None,
) -> None:
    """Start the object classifier for a list of images and segmentations.

    This function will save all the features and labels for annotated objects,
    to enable training a random forest on multiple images.

    Args:
        images: The input images.
        segmentations: The input segmentations.
        output_folder: The folder where segmentation results, trained random forest
            and the features, labels aggregated during training will be saved.
        embedding_paths: Filepaths where to save the embeddings
            or the precomputed image embeddings computed by `precompute_image_embeddings`.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        device: The computational device to use for the SAM model.
            By default, automatically chooses the best available device.
        ndim: The dimensionality of the data. If not given will be derived separately for each image.

    Raises:
        ValueError: If no images are given, or if the number of segmentations
            (or of embedding paths, when given) does not match the number of images.
    """
    # TODO precompute the embeddings if not computed, can re-use 'precompute' from image series annotator.
    # TODO support file paths as inputs
    # TODO option to skip segmented
    if len(images) == 0:
        raise ValueError("Expect at least one image, got none.")
    if len(images) != len(segmentations):
        raise ValueError(
            f"Expect the same number of images and segmentations, got {len(images)}, {len(segmentations)}."
        )
    if embedding_paths is not None and len(embedding_paths) != len(images):
        raise ValueError(
            f"Expect the same number of images and embedding paths, got {len(images)}, {len(embedding_paths)}."
        )

    end_msg = "You have annotated the last image. Do you wish to close napari?"

    def _spatial_ndim(image):
        # Use the explicit `ndim` if given, otherwise apply the same rule as `object_classifier`:
        # a trailing axis of size 3 on a 3d / 4d array is treated as RGB channels.
        if ndim is not None:
            return ndim
        return image.ndim - 1 if image.shape[-1] == 3 and image.ndim in (3, 4) else image.ndim

    # Initialize the object classifier on the first image / segmentation.
    viewer = object_classifier(
        image=images[0], segmentation=segmentations[0],
        embedding_path=None if embedding_paths is None else embedding_paths[0],
        model_type=model_type, tile_shape=tile_shape, halo=halo,
        return_viewer=True, checkpoint_path=checkpoint_path,
        device=device, ndim=ndim,
    )

    os.makedirs(output_folder, exist_ok=True)
    next_image_id = 0

    def _save_prediction(image, pred, image_id):
        # Name the output after the input file if a path was given, otherwise after the image index.
        fname = f"{Path(image).stem}_prediction.tif" if isinstance(image, str) else f"prediction_{image_id}.tif"
        save_path = os.path.join(output_folder, fname)
        imageio.imwrite(save_path, pred, compression="zlib")

    def _ask_to_close():
        # Inform the user via dialog that the series is done; close the viewer if `abort` is falsy.
        abort = widgets._generate_message("info", end_msg)
        if not abort:
            viewer.close()

    # TODO handle cases where rf for the image was not trained, raise a message, enable contnuing
    # Add functionality for going to the next image.
    @magicgui(call_button="Next Image [N]")
    def next_image(*args):
        nonlocal next_image_id

        # All images were already processed and napari was kept open: there is no current image
        # left to save or load (indexing further would raise an IndexError), so repeat the dialog.
        if next_image_id >= len(images):
            _ask_to_close()
            return

        # Get the state and the current segmentation (note that next image id has not yet been increased)
        state = AnnotatorState()
        segmentation = segmentations[next_image_id]

        # Keep track of the previous features and labels.
        labels = _accumulate_labels(segmentation, viewer.layers["annotations"].data)
        valid = labels != 0
        if valid.any():
            features, labels = state.object_features[valid], labels[valid]
            if state.previous_features is None:
                state.previous_features, state.previous_labels = features, labels
            else:
                state.previous_features = np.concatenate([state.previous_features, features], axis=0)
                state.previous_labels = np.concatenate([state.previous_labels, labels], axis=0)
            # Save the accumulated features and labels.
            np.save(os.path.join(output_folder, "features.npy"), state.previous_features)
            np.save(os.path.join(output_folder, "labels.npy"), state.previous_labels)

        # Save the current prediction and RF.
        _save_prediction(images[next_image_id], viewer.layers["prediction"].data, next_image_id)
        dump(state.object_rf, os.path.join(output_folder, "rf.joblib"))

        # Go to the next image.
        next_image_id += 1

        # Check if we are done.
        if next_image_id == len(images):
            _ask_to_close()
            return

        # Get the next image, segmentation and embedding_path.
        image = images[next_image_id]
        segmentation = segmentations[next_image_id]
        embedding_path = None if embedding_paths is None else embedding_paths[next_image_id]
        # Derive the dimensionality per image: the outer `ndim` may be None (the default).
        image_ndim = _spatial_ndim(image)

        # Set the new image in the viewer, state and annotator.
        viewer.layers["image"].data = image
        viewer.layers["segmentation"].data = segmentation

        state.initialize_predictor(
            image, model_type=model_type, ndim=image_ndim,
            save_path=embedding_path,
            tile_shape=tile_shape, halo=halo,
            predictor=state.predictor, device=device,
        )
        state.image_shape = image.shape[:image_ndim]
        state.annotator._update_image()

        # Clear the object features and seg-ids from the state.
        state.object_features = None
        state.seg_ids = None

    viewer.window.add_dock_widget(next_image)

    @viewer.bind_key("n", overwrite=True)
    def _next_image(viewer):
        next_image(viewer)

    napari.run()
Start the object classifier for a list of images and segmentations.
This function will save all the features and labels for annotated objects, to enable training a random forest on multiple images.
Arguments:
- images: The input images.
- segmentations: The input segmentations.
- output_folder: The folder where segmentation results, trained random forest and the features, labels aggregated during training will be saved.
- embedding_paths: Filepaths where to save the embeddings
or the precomputed image embeddings computed by
precompute_image_embeddings.
- model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
- tile_shape: Shape of tiles for tiled embedding prediction.
If None then the whole image is passed to Segment Anything.
- halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
- checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
- device: The computational device to use for the SAM model. By default, automatically chooses the best available device.
- ndim: The dimensionality of the data. If not given will be derived from the data.