micro_sam.sam_annotator.object_classifier

import os
from joblib import dump
from multiprocessing import cpu_count
from pathlib import Path
from typing import List, Optional, Tuple, Union

import imageio.v3 as imageio
import napari
import numpy as np
import torch

from magicgui import magic_factory, magicgui
from magicgui.widgets import Widget, Container, FunctionGui, create_widget
from qtpy import QtWidgets

from skimage.measure import regionprops_table
from sklearn.ensemble import RandomForestClassifier

from .. import util
from ..object_classification import compute_object_features, project_prediction_to_segmentation
from ._state import AnnotatorState
from . import _widgets as widgets
from .util import _sync_embedding_widget

#
# Utility functionality.
# Some of this could be refactored into general-purpose functionality that can also
# be used for inference with the trained classifier.
#


def _accumulate_labels(segmentation, annotations):

    def majority_label(mask, annotation):
        ids, counts = np.unique(annotation[mask], return_counts=True)
        if len(ids) == 1 and ids[0] == 0:
            return 0
        if ids[0] == 0:
            ids, counts = ids[1:], counts[1:]
        return ids[np.argmax(counts)]

    all_features = regionprops_table(
        segmentation, intensity_image=annotations, properties=("label",),
        extra_properties=[majority_label],
    )
    return all_features["majority_label"].astype("int")
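
# A minimal sketch of what `_accumulate_labels` computes, on hypothetical toy data:
# object 1 is mostly painted with class 3 and object 2 is unannotated, so the
# result is `array([3, 0])`.
#
#   segmentation = np.array([[1, 1, 2], [1, 2, 2]])
#   annotations = np.array([[3, 3, 0], [0, 0, 0]])
#   _accumulate_labels(segmentation, annotations)  # -> array([3, 0])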


def _train_rf(features, labels, previous_features=None, previous_labels=None, **rf_kwargs):
    assert len(features) == len(labels)
    valid = labels != 0
    X, y = features[valid], labels[valid]

    if previous_features is not None:
        assert previous_labels is not None and len(previous_features) == len(previous_labels)
        X = np.concatenate([previous_features, X], axis=0)
        y = np.concatenate([previous_labels, y], axis=0)

    rf = RandomForestClassifier(**rf_kwargs)
    rf.fit(X, y)
    return rf
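
# A minimal sketch of the expected inputs (the shapes are illustrative): label 0
# marks unannotated objects, which are dropped before fitting.
#
#   features = np.random.rand(10, 256)  # one feature vector per object
#   labels = np.array([1, 2, 0, 1, 0, 2, 1, 0, 0, 2])  # 0 = unannotated
#   rf = _train_rf(features, labels, n_estimators=200, max_depth=10)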


# TODO do we add a shortcut?
@magic_factory(call_button="Train and predict")
def _train_and_predict_rf_widget(viewer: "napari.viewer.Viewer") -> None:
    # Get the object features and the annotations.
    state = AnnotatorState()
    state.annotator._require_layers()
    annotations = viewer.layers["annotations"].data
    segmentation = state.segmentation_selection.get_value().data

    if state.object_features is None:
        if widgets._validate_embeddings(viewer):
            return None
        image_embeddings = state.image_embeddings
        seg_ids, features = compute_object_features(image_embeddings, segmentation)
        state.seg_ids = seg_ids
        state.object_features = features
    else:
        features, seg_ids = state.object_features, state.seg_ids

    previous_features, previous_labels = state.previous_features, state.previous_labels
    labels = _accumulate_labels(segmentation, annotations)
    if (labels == 0).all() and (previous_labels is None):
        return widgets._generate_message("error", "You have not provided any annotations.")

    # Run RF training and store it in the state.
    rf = _train_rf(
        features, labels, previous_features=previous_features, previous_labels=previous_labels,
        n_estimators=200, max_depth=10, n_jobs=cpu_count(),
    )
    state.object_rf = rf

    # Run and set the prediction.
    pred = rf.predict(features)
    prediction_data = project_prediction_to_segmentation(segmentation, pred, seg_ids)
    viewer.layers["prediction"].data = prediction_data

    state.annotator._refresh_label_widget()
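
# The flow of the widget above, run headlessly (a sketch, assuming `features`, `seg_ids`,
# `segmentation` and `annotations` are already available):
#
#   labels = _accumulate_labels(segmentation, annotations)
#   rf = _train_rf(features, labels, n_estimators=200, max_depth=10, n_jobs=cpu_count())
#   pred = rf.predict(features)
#   prediction = project_prediction_to_segmentation(segmentation, pred, seg_ids)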


@magic_factory(call_button="Export Classifier")
def _create_export_rf_widget(export_path: Optional[Path] = None) -> None:
    state = AnnotatorState()
    rf = state.object_rf
    if rf is None:
        return widgets._generate_message("error", "You have not run training yet.")
    if export_path is None or export_path == "":
        return widgets._generate_message("error", "You have to provide an export path.")
    # Do we add an extension? .joblib?
    dump(rf, export_path)
    # TODO show an info message about the export
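
# The exported classifier can be loaded again for inference with joblib (a sketch,
# the path is illustrative):
#
#   from joblib import load
#   rf = load("rf.joblib")
#   pred = rf.predict(features)  # features as returned by compute_object_features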
114
#
# Object classifier implementation.
#


# TODO add a gui element that shows the current label ids, how many objects are labeled, and that
# enables naming them so that the user can keep track of what has been labeled
class ObjectClassifier(QtWidgets.QScrollArea):

    def _require_layers(self, layer_choices: Optional[List[str]] = None):
        # Check whether the image is already initialized and use its shape and scale for the layers.
        state = AnnotatorState()
        shape = self._shape if state.image_shape is None else state.image_shape

        # Add the label layers for the annotations and the classifier predictions.
        dummy_data = np.zeros(shape, dtype="uint32")
        image_scale = state.image_scale

        # Before adding new layers, we always check whether a layer with this name already exists or not.
        if "annotations" not in self._viewer.layers:
            if layer_choices and "annotations" in layer_choices:
                widgets._validation_window_for_missing_layer("annotations")
            annotation_layer = self._viewer.add_labels(data=dummy_data, name="annotations")
            if image_scale is not None:
                self._viewer.layers["annotations"].scale = image_scale
            # Reduce the brush size and set "paint" as the default mode.
            annotation_layer.brush_size = 3
            annotation_layer.mode = "paint"

        if "prediction" not in self._viewer.layers:
            if layer_choices and "prediction" in layer_choices:
                widgets._validation_window_for_missing_layer("prediction")
            self._viewer.add_labels(data=dummy_data, name="prediction")
            if image_scale is not None:
                self._viewer.layers["prediction"].scale = image_scale

    def _create_segmentation_layer_section(self):
        segmentation_selection = QtWidgets.QVBoxLayout()
        segmentation_layer_widget = QtWidgets.QLabel("Segmentation:")
        segmentation_selection.addWidget(segmentation_layer_widget)
        self.segmentation_selection = create_widget(annotation=napari.layers.Labels)
        state = AnnotatorState()
        state.segmentation_selection = self.segmentation_selection
        segmentation_selection.addWidget(self.segmentation_selection.native)
        return segmentation_selection

    def _create_label_widget(self):
        self._label_form = QtWidgets.QFormLayout()
        scroll_area = QtWidgets.QScrollArea()
        inner = QtWidgets.QWidget()
        inner.setLayout(self._label_form)
        scroll_area.setWidget(inner)
        scroll_area.setWidgetResizable(True)

        layout = QtWidgets.QVBoxLayout()
        layout.addWidget(QtWidgets.QLabel("Object label names:"))
        layout.addWidget(scroll_area)

        return layout

    def _refresh_label_widget(self):
        state = AnnotatorState()

        # Get the current label ids.
        ids = np.unique(self._viewer.layers["annotations"].data)[1:]
        if state.previous_labels is not None:
            ids = np.union1d(ids, np.unique(state.previous_labels))

        # Add new rows.
        for lbl in ids:
            if lbl in self._label_names:
                continue
            line = QtWidgets.QLineEdit(self._label_names.get(lbl, ""))
            self._label_names[lbl] = ""
            self._label_form.addRow(f"ID {lbl}", line)
            line.textChanged.connect(lambda txt, lbl=lbl: self._label_names.__setitem__(lbl, txt))

        # Remove rows whose label has vanished.
        for row in reversed(range(self._label_form.rowCount())):
            lbl_text = self._label_form.itemAt(row, QtWidgets.QFormLayout.LabelRole).widget().text()
            lbl_id = int(lbl_text.split()[1])
            if lbl_id not in ids:
                # Remove the label and field widgets completely.
                w_label = self._label_form.itemAt(row, QtWidgets.QFormLayout.LabelRole).widget()
                w_edit = self._label_form.itemAt(row, QtWidgets.QFormLayout.FieldRole).widget()
                self._label_form.removeRow(row)
                w_label.deleteLater()
                w_edit.deleteLater()
                self._label_names.pop(lbl_id, None)

    def _create_widgets(self):
        # Create the embedding widget and connect all events related to it.
        self._embedding_widget = widgets.EmbeddingWidget()
        # Connect events for the image selection box.
        self._viewer.layers.events.inserted.connect(self._embedding_widget.image_selection.reset_choices)
        self._viewer.layers.events.removed.connect(self._embedding_widget.image_selection.reset_choices)
        # Connect the run button with the function to update the image.
        self._embedding_widget.run_button.clicked.connect(self._update_image)

        # Create the widget for training and prediction of the classifier.
        self._train_and_predict_widget = _train_and_predict_rf_widget()

        # Create the widget for segmentation selection.
        self._seg_selection_widget = self._create_segmentation_layer_section()

        # Create the widget for displaying the current label state.
        self._label_widget = self._create_label_widget()

        # Create the widget for exporting the RF.
        self._export_rf_widget = _create_export_rf_widget()

        self._widgets = {
            "embeddings": self._embedding_widget,
            "segmentation_selection": self._seg_selection_widget,
            "train_and_predict": self._train_and_predict_widget,
            "label_widget": self._label_widget,
            "export_rf": self._export_rf_widget,
        }

    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
        """Create the GUI for the object classifier.

        Args:
            viewer: The napari viewer.
        """
        super().__init__()
        self._viewer = viewer
        self._annotator_widget = QtWidgets.QWidget()
        self._annotator_widget.setLayout(QtWidgets.QVBoxLayout())

        # Add the annotation and prediction layers.
        # Initialize with a dummy shape, which is reset to the correct shape once an image is set.
        self._shape = (256, 256)
        self._require_layers()
        self._ndim = len(self._shape)

        # Create all the widgets and add them to the layout.
        self._label_names = {}  # The names for the object labels.
        self._create_widgets()

        # We could refactor this.
        for widget_name, widget in self._widgets.items():
            widget_frame = QtWidgets.QGroupBox()
            widget_layout = QtWidgets.QVBoxLayout()
            if isinstance(widget, (Container, FunctionGui, Widget)):
                # This is a magicgui type and we need to get the native qt widget.
                widget_layout.addWidget(widget.native)
            elif isinstance(widget, QtWidgets.QLayout):
                widget_layout.addLayout(widget)
            else:
                # This is a qt type and we add the widget directly.
                widget_layout.addWidget(widget)
            widget_frame.setLayout(widget_layout)
            self._annotator_widget.layout().addWidget(widget_frame)

        # Connect the label layer and the refresh function.
        self._refresh_label_widget()

        # Set the expected annotator class to the state.
        state = AnnotatorState()
        state.annotator = self

        # Add the widgets to the state.
        state.widgets = self._widgets

        # Add the widget to the scroll area.
        self.setWidgetResizable(True)  # Allow the widget to resize within the scroll area.
        self.setWidget(self._annotator_widget)

    def _update_image(self, segmentation_result=None):
        state = AnnotatorState()

        # If the embeddings already exist, avoid clearing the objects in the layers.
        if state.skip_recomputing_embeddings:
            return

        if state.image_shape is None:
            return

        # Update the dimension and image shape if they have changed.
        if state.image_shape != self._shape:
            self._ndim = len(state.image_shape)
            self._shape = state.image_shape

        # Before we reset the layers, we ensure all expected layers exist.
        self._require_layers()

        # Update the image scale.
        scale = state.image_scale

        # Reset all layers.
        self._viewer.layers["annotations"].data = np.zeros(self._shape, dtype="uint32")
        self._viewer.layers["annotations"].scale = scale
        self._viewer.layers["prediction"].data = np.zeros(self._shape, dtype="uint32")
        self._viewer.layers["prediction"].scale = scale


def object_classifier(
    image: np.ndarray,
    segmentation: np.ndarray,
    embedding_path: Optional[Union[str, util.ImageEmbeddings]] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    return_viewer: bool = False,
    viewer: Optional["napari.viewer.Viewer"] = None,
    checkpoint_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
    ndim: Optional[int] = None,
) -> Optional["napari.viewer.Viewer"]:
    """Start the object classifier for a given image and segmentation.

    Args:
        image: The image data.
        segmentation: The segmentation data.
        embedding_path: Filepath where to save the embeddings
            or the precomputed image embeddings computed by `precompute_image_embeddings`.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
            By default, does not return the napari viewer.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        device: The computational device to use for the SAM model.
            By default, automatically chooses the best available device.
        ndim: The dimensionality of the data. If not given, it will be derived from the data.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
    """
    if ndim is None:
        ndim = image.ndim - 1 if image.shape[-1] == 3 and image.ndim in (3, 4) else image.ndim

    state = AnnotatorState()
    state.image_shape = image.shape[:ndim]

    state.initialize_predictor(
        image, model_type=model_type, save_path=embedding_path,
        halo=halo, tile_shape=tile_shape, precompute_amg_state=False,
        ndim=ndim, checkpoint_path=checkpoint_path, device=device,
        skip_load=False, use_cli=True,
    )

    if viewer is None:
        viewer = napari.Viewer()

    viewer.add_image(image, name="image")
    viewer.add_labels(segmentation, name="segmentation")

    annotator = ObjectClassifier(viewer)

    # Trigger the layer update of the annotator so that the layers have the correct shape.
    annotator._update_image()

    # Add the annotator widget to the viewer and sync the widgets.
    viewer.window.add_dock_widget(annotator)
    _sync_embedding_widget(
        widget=state.widgets["embeddings"],
        model_type=model_type if checkpoint_path is None else state.predictor.model_type,
        save_path=embedding_path,
        checkpoint_path=checkpoint_path,
        device=device,
        tile_shape=tile_shape,
        halo=halo,
    )

    if return_viewer:
        return viewer

    napari.run()
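
# A minimal usage sketch for `object_classifier` (the file paths are illustrative):
#
#   image = imageio.imread("image.tif")
#   segmentation = imageio.imread("segmentation.tif")
#   object_classifier(image, segmentation, embedding_path="embeddings.zarr")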


def image_series_object_classifier(
    images: List[np.ndarray],
    segmentations: List[np.ndarray],
    output_folder: str,
    embedding_paths: Optional[List[Union[str, util.ImageEmbeddings]]] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    checkpoint_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
    ndim: Optional[int] = None,
) -> None:
    """Start the object classifier for a list of images and segmentations.

    This function will save all features and labels for annotated objects,
    to enable training a random forest on multiple images.

    Args:
        images: The input images.
        segmentations: The input segmentations.
        output_folder: The folder where the segmentation results, the trained random forest,
            and the features and labels aggregated during training will be saved.
        embedding_paths: Filepaths where to save the embeddings
            or the precomputed image embeddings computed by `precompute_image_embeddings`.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        device: The computational device to use for the SAM model.
            By default, automatically chooses the best available device.
        ndim: The dimensionality of the data. If not given, it will be derived from the data.
    """
    # TODO precompute the embeddings if not computed, can re-use 'precompute' from the image series annotator.
    # TODO support file paths as inputs
    # TODO option to skip segmented
    if len(images) != len(segmentations):
        raise ValueError(
            f"Expect the same number of images and segmentations, got {len(images)}, {len(segmentations)}."
        )

    end_msg = "You have annotated the last image. Do you wish to close napari?"

    # Initialize the object classifier on the first image / segmentation.
    viewer = object_classifier(
        image=images[0], segmentation=segmentations[0],
        embedding_path=None if embedding_paths is None else embedding_paths[0],
        model_type=model_type, tile_shape=tile_shape, halo=halo,
        return_viewer=True, checkpoint_path=checkpoint_path,
        device=device, ndim=ndim,
    )

    os.makedirs(output_folder, exist_ok=True)
    next_image_id = 0

    def _save_prediction(image, pred, image_id):
        fname = f"{Path(image).stem}_prediction.tif" if isinstance(image, str) else f"prediction_{image_id}.tif"
        save_path = os.path.join(output_folder, fname)
        imageio.imwrite(save_path, pred, compression="zlib")

    # TODO handle cases where the rf for the image was not trained: show a message and enable continuing.
    # Add functionality for going to the next image.
    @magicgui(call_button="Next Image [N]")
    def next_image(*args):
        nonlocal next_image_id

        # Get the state and the current segmentation (note that next_image_id has not yet been incremented).
        state = AnnotatorState()
        segmentation = segmentations[next_image_id]

        # Keep track of the previous features and labels.
        labels = _accumulate_labels(segmentation, viewer.layers["annotations"].data)
        valid = labels != 0
        if valid.sum() > 0:
            features, labels = state.object_features[valid], labels[valid]
            if state.previous_features is None:
                state.previous_features, state.previous_labels = features, labels
            else:
                state.previous_features = np.concatenate([state.previous_features, features], axis=0)
                state.previous_labels = np.concatenate([state.previous_labels, labels], axis=0)
            # Save the accumulated features and labels.
            np.save(os.path.join(output_folder, "features.npy"), state.previous_features)
            np.save(os.path.join(output_folder, "labels.npy"), state.previous_labels)

        # Save the current prediction and RF.
        _save_prediction(images[next_image_id], viewer.layers["prediction"].data, next_image_id)
        dump(state.object_rf, os.path.join(output_folder, "rf.joblib"))

        # Go to the next image.
        next_image_id += 1

        # Check if we are done.
        if next_image_id == len(images):
            # Inform the user via dialog.
            abort = widgets._generate_message("info", end_msg)
            if not abort:
                viewer.close()
            return

        # Get the next image, segmentation and embedding path.
        image = images[next_image_id]
        segmentation = segmentations[next_image_id]
        embedding_path = None if embedding_paths is None else embedding_paths[next_image_id]

        # Set the new image in the viewer, state and annotator.
        viewer.layers["image"].data = image
        viewer.layers["segmentation"].data = segmentation

        state.initialize_predictor(
            image, model_type=model_type, ndim=ndim,
            save_path=embedding_path,
            tile_shape=tile_shape, halo=halo,
            predictor=state.predictor, device=device,
        )
        state.image_shape = image.shape if image.ndim == ndim else image.shape[:-1]
        state.annotator._update_image()

        # Clear the object features and seg-ids from the state.
        state.object_features = None
        state.seg_ids = None

    viewer.window.add_dock_widget(next_image)

    @viewer.bind_key("n", overwrite=True)
    def _next_image(viewer):
        next_image(viewer)

    napari.run()
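
# A minimal usage sketch for the series classifier (the file paths are illustrative);
# features, labels, predictions and the trained forest are written to `output_folder`
# as the user advances through the series:
#
#   images = [imageio.imread(p) for p in ("im0.tif", "im1.tif")]
#   segmentations = [imageio.imread(p) for p in ("seg0.tif", "seg1.tif")]
#   image_series_object_classifier(images, segmentations, output_folder="results")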


# TODO: folder annotator
# TODO: main function
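
# A possible shape for the missing main function: a sketch only, assuming a CLI that
# loads the image and segmentation from file paths (all argument names are illustrative).


def main():
    import argparse

    parser = argparse.ArgumentParser(description="Start the micro_sam object classifier.")
    parser.add_argument("-i", "--image", required=True, help="Path to the image data.")
    parser.add_argument("-s", "--segmentation", required=True, help="Path to the segmentation data.")
    parser.add_argument("-e", "--embedding_path", default=None, help="Where to save the embeddings.")
    parser.add_argument("-m", "--model_type", default=util._DEFAULT_MODEL, help="The SAM model type.")
    args = parser.parse_args()

    image = imageio.imread(args.image)
    segmentation = imageio.imread(args.segmentation)
    object_classifier(
        image, segmentation, embedding_path=args.embedding_path, model_type=args.model_type,
    )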