micro_sam.sam_annotator.object_classifier

  1import os
  2from joblib import dump
  3from multiprocessing import cpu_count
  4from pathlib import Path
  5from typing import List, Optional, Tuple, Union
  6
  7import imageio.v3 as imageio
  8import napari
  9import numpy as np
 10import torch
 11
 12from magicgui import magic_factory, magicgui
 13from magicgui.widgets import Widget, Container, FunctionGui, create_widget
 14from qtpy import QtWidgets
 15
 16from skimage.measure import regionprops_table
 17from sklearn.ensemble import RandomForestClassifier
 18
 19from .. import util
 20from ..object_classification import compute_object_features, project_prediction_to_segmentation
 21from ._state import AnnotatorState
 22from . import _widgets as widgets
 23from .util import _sync_embedding_widget
 24
 25#
 26# Utility functionality.
 27# Some of this could be refactored to general purpose functionality that can also
 28# be used for inference with the trained classifier.
 29#
 30
 31
 32def _accumulate_labels(segmentation, annotations):
 33
 34    def majority_label(mask, annotation):
 35        ids, counts = np.unique(annotation[mask], return_counts=True)
 36        if len(ids) == 1 and ids[0] == 0:
 37            return 0
 38        if ids[0] == 0:
 39            ids, counts = ids[1:], counts[1:]
 40        return ids[np.argmax(counts)]
 41
 42    all_features = regionprops_table(
 43        segmentation, intensity_image=annotations, properties=("label",),
 44        extra_properties=[majority_label],
 45    )
 46    return all_features["majority_label"].astype("int")
 47
 48
 49def _train_rf(features, labels, previous_features=None, previous_labels=None, **rf_kwargs):
 50    assert len(features) == len(labels)
 51    valid = labels != 0
 52    X, y = features[valid], labels[valid]
 53
 54    if previous_features is not None:
 55        assert previous_labels is not None and len(previous_features) == len(previous_labels)
 56        X = np.concatenate([previous_features, X], axis=0)
 57        y = np.concatenate([previous_labels, y], axis=0)
 58
 59    rf = RandomForestClassifier(**rf_kwargs)
 60    rf.fit(X, y)
 61    return rf
 62
 63
 64# TODO do we add a shortcut?
 65@magic_factory(call_button="Train and predict")
 66def _train_and_predict_rf_widget(viewer: "napari.viewer.Viewer") -> None:
 67    # Get the object features and the annotations.
 68    state = AnnotatorState()
 69    state.annotator._require_layers()
 70    annotations = viewer.layers["annotations"].data
 71    segmentation = state.segmentation_selection.get_value().data
 72
 73    if state.object_features is None:
 74        if widgets._validate_embeddings(viewer):
 75            return None
 76        image_embeddings = state.image_embeddings
 77        seg_ids, features = compute_object_features(image_embeddings, segmentation)
 78        state.seg_ids = seg_ids
 79        state.object_features = features
 80    else:
 81        features, seg_ids = state.object_features, state.seg_ids
 82
 83    previous_features, previous_labels = state.previous_features, state.previous_labels
 84    labels = _accumulate_labels(segmentation, annotations)
 85    if (labels == 0).all() and (previous_labels is None):
 86        return widgets._generate_message("error", "You have not provided any annotations.")
 87
 88    # Run RF training and store it in the state.
 89    rf = _train_rf(
 90        features, labels, previous_features=previous_features, previous_labels=previous_labels,
 91        n_estimators=200, max_depth=10, n_jobs=cpu_count(),
 92    )
 93    state.object_rf = rf
 94
 95    # Run and set the prediction.
 96    pred = rf.predict(features)
 97    prediction_data = project_prediction_to_segmentation(segmentation, pred, seg_ids)
 98    viewer.layers["prediction"].data = prediction_data
 99
100    state.annotator._refresh_label_widget()
101
102
@magic_factory(call_button="Export Classifier")
def _create_export_rf_widget(export_path: Optional[Path] = None) -> None:
    """Export the trained random forest from the annotator state with joblib.

    Shows an error message instead if no classifier has been trained yet or no export path was given.

    Args:
        export_path: The file path to which the classifier is written.
    """
    state = AnnotatorState()
    rf = state.object_rf
    if rf is None:
        return widgets._generate_message("error", "You have not run training yet.")
    # A Path never compares equal to the str "", and an empty entry given as a Path is Path("") == Path("."),
    # so compare the string form to catch a missing path before handing it to dump.
    if export_path is None or str(export_path) in ("", "."):
        return widgets._generate_message("error", "You have to provide an export path.")
    # TODO: decide whether to enforce an extension (e.g. ".joblib") and show an info message after the export.
    dump(rf, export_path)
114
115#
116# Object classifier implementation.
117#
118
119
120# TODO add a gui element that shows the current label ids, how many objects are labeled, and that
121# enables naming them so that the user can keep track of what has been labeled
class ObjectClassifier(QtWidgets.QScrollArea):
    """Widget for classifying segmented objects with a random forest trained on user annotations.

    The widget manages two label layers in the viewer: "annotations", in which the user paints class ids
    onto segmented objects, and "prediction", in which the classifier output is shown.
    """

    def _require_layers(self, layer_choices: Optional[List[str]] = None):
        """Make sure that the "annotations" and "prediction" label layers exist in the viewer.

        Missing layers are added filled with zeros, in the image shape from the annotator state
        (or the dummy shape if no image has been set yet), and with the image scale if one is set.

        Args:
            layer_choices: Names of layers that are expected to exist already. If one of them is missing,
                a validation window is shown before the layer is added again.
        """
        state = AnnotatorState()
        shape = self._shape if state.image_shape is None else state.image_shape
        image_scale = state.image_scale

        # Before adding new layers, we always check whether a layer with this name already exists or not.
        # Each layer gets its own zero array: napari may paint into the given array in place, so a shared
        # array could make painted annotations show up in the prediction layer as well.
        if "annotations" not in self._viewer.layers:
            if layer_choices and "annotations" in layer_choices:
                widgets._validation_window_for_missing_layer("annotations")
            annotation_layer = self._viewer.add_labels(data=np.zeros(shape, dtype="uint32"), name="annotations")
            if image_scale is not None:
                annotation_layer.scale = image_scale
            # Reduce the brush size and set the default mode to "paint" brush mode.
            annotation_layer.brush_size = 3
            annotation_layer.mode = "paint"

        if "prediction" not in self._viewer.layers:
            if layer_choices and "prediction" in layer_choices:
                widgets._validation_window_for_missing_layer("prediction")
            prediction_layer = self._viewer.add_labels(data=np.zeros(shape, dtype="uint32"), name="prediction")
            if image_scale is not None:
                prediction_layer.scale = image_scale

    def _create_segmentation_layer_section(self):
        """Create the layout with the combo box for choosing the segmentation layer.

        The combo box is also registered in the annotator state, from where the training widget reads it.

        Returns:
            The layout with a title label and the combo box.
        """
        segmentation_selection = QtWidgets.QVBoxLayout()
        segmentation_selection.addWidget(QtWidgets.QLabel("Segmentation:"))
        self.segmentation_selection = create_widget(annotation=napari.layers.Labels)
        state = AnnotatorState()
        state.segmentation_selection = self.segmentation_selection
        segmentation_selection.addWidget(self.segmentation_selection.native)
        return segmentation_selection

    def _create_label_widget(self):
        """Create the scrollable form that shows one name field per object label id.

        Returns:
            The layout with a title label and the scroll area holding the form.
        """
        self._label_form = QtWidgets.QFormLayout()
        scroll_area = QtWidgets.QScrollArea()
        inner = QtWidgets.QWidget()
        inner.setLayout(self._label_form)
        scroll_area.setWidget(inner)
        scroll_area.setWidgetResizable(True)

        layout = QtWidgets.QVBoxLayout()
        layout.addWidget(QtWidgets.QLabel("Object label names:"))
        layout.addWidget(scroll_area)
        return layout

    def _refresh_label_widget(self):
        """Synchronize the rows of the label-name form with the label ids that are in use.

        The ids in use are the non-zero values of the "annotations" layer plus the labels collected
        from previous images (`state.previous_labels`). A row with a text field is added for every new id.
        Rows whose id is no longer in use are removed together with the stored name.
        """
        state = AnnotatorState()

        ids = np.unique(self._viewer.layers["annotations"].data)
        # Remove the background explicitly; dropping the first entry would discard a real label
        # if no background pixel is left.
        ids = ids[ids != 0]
        if state.previous_labels is not None:
            ids = np.union1d(ids, np.unique(state.previous_labels))

        # Add new rows.
        for lbl in ids:
            if lbl in self._label_names:
                continue
            self._label_names[lbl] = ""
            line = QtWidgets.QLineEdit("")
            self._label_form.addRow(f"ID {lbl}", line)
            # Bind lbl as a default argument, otherwise every callback would write to the last id.
            line.textChanged.connect(lambda txt, lbl=lbl: self._label_names.__setitem__(lbl, txt))

        # Remove rows whose label vanished; iterate backwards so removals don't shift pending rows.
        active_ids = set(ids.tolist())
        for row in reversed(range(self._label_form.rowCount())):
            lbl_text = self._label_form.itemAt(row, QtWidgets.QFormLayout.LabelRole).widget().text()
            lbl_id = int(lbl_text.split()[1])
            if lbl_id in active_ids:
                continue
            # removeRow also deletes the label and field widgets of the row, so no extra deleteLater is needed.
            self._label_form.removeRow(row)
            self._label_names.pop(lbl_id, None)

    def _create_widgets(self):
        """Create all sub-widgets and collect them in `self._widgets` in display order."""
        # Create the embedding widget and keep its image choices in sync with the viewer's layers.
        self._embedding_widget = widgets.EmbeddingWidget()
        self._viewer.layers.events.inserted.connect(self._embedding_widget.image_selection.reset_choices)
        self._viewer.layers.events.removed.connect(self._embedding_widget.image_selection.reset_choices)
        # Connect the run button with the function to update the image.
        self._embedding_widget.run_button.clicked.connect(self._update_image)

        # Create the widget for training and prediction of the classifier.
        self._train_and_predict_widget = _train_and_predict_rf_widget()

        # Create the widget for segmentation selection; keep its choices in sync like the image choices above.
        self._seg_selection_widget = self._create_segmentation_layer_section()
        self._viewer.layers.events.inserted.connect(self.segmentation_selection.reset_choices)
        self._viewer.layers.events.removed.connect(self.segmentation_selection.reset_choices)

        # Create the widget for displaying the current label state.
        self._label_widget = self._create_label_widget()

        # Create the widget for exporting the RF.
        self._export_rf_widget = _create_export_rf_widget()

        self._widgets = {
            "embeddings": self._embedding_widget,
            "segmentation_selection": self._seg_selection_widget,
            "train_and_predict": self._train_and_predict_widget,
            "label_widget": self._label_widget,
            "export_rf": self._export_rf_widget,
        }

    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
        """Create the GUI for the object classifier.

        Adds the "annotations" and "prediction" layers to the viewer, builds the sub-widgets,
        places each of them in its own group box and registers the annotator in the state.

        Args:
            viewer: The napari viewer.
        """
        super().__init__()
        self._viewer = viewer
        self._annotator_widget = QtWidgets.QWidget()
        self._annotator_widget.setLayout(QtWidgets.QVBoxLayout())

        # Add the layers for annotations and predictions.
        # Initialize with a dummy shape, which is reset to the correct shape once an image is set.
        self._shape = (256, 256)
        self._require_layers()
        self._ndim = len(self._shape)

        self._label_names = {}  # Maps each object label id to the name typed in by the user.
        self._create_widgets()

        # Wrap every sub-widget in its own group box.
        for widget in self._widgets.values():
            widget_frame = QtWidgets.QGroupBox()
            widget_layout = QtWidgets.QVBoxLayout()
            if isinstance(widget, (Container, FunctionGui, Widget)):
                # This is a magicgui type and we need to get the native qt widget.
                widget_layout.addWidget(widget.native)
            elif isinstance(widget, QtWidgets.QLayout):
                widget_layout.addLayout(widget)
            else:
                # This is a qt type and we add the widget directly.
                widget_layout.addWidget(widget)
            widget_frame.setLayout(widget_layout)
            self._annotator_widget.layout().addWidget(widget_frame)

        # Show a name field for every label id that is already annotated.
        self._refresh_label_widget()

        # Register the annotator and its widgets in the state.
        state = AnnotatorState()
        state.annotator = self
        state.widgets = self._widgets

        # Add the widget to the scroll area.
        self.setWidgetResizable(True)  # Allow widget to resize within scroll area.
        self.setWidget(self._annotator_widget)

    def _update_image(self, segmentation_result=None):
        """Reset the annotation and prediction layers for the image stored in the annotator state.

        Nothing happens if `state.skip_recomputing_embeddings` is set or no image has been set yet.
        Otherwise the layers are re-created if missing, filled with zeros in the image shape, given the
        image scale, and the label-name form is refreshed so that it no longer lists cleared labels.

        Args:
            segmentation_result: Unused. This method is connected to the embedding widget's run button,
                so it may receive the signal's argument here.
        """
        state = AnnotatorState()

        # Whether embeddings already exist and avoid clearing objects in layers.
        if state.skip_recomputing_embeddings:
            return

        if state.image_shape is None:
            return

        # Update the dimension and image shape if it has changed.
        if state.image_shape != self._shape:
            self._ndim = len(state.image_shape)
            self._shape = state.image_shape

        # Before we reset the layers, we ensure all expected layers exist.
        self._require_layers()

        scale = state.image_scale

        # Reset all layers.
        self._viewer.layers["annotations"].data = np.zeros(self._shape, dtype="uint32")
        self._viewer.layers["annotations"].scale = scale
        self._viewer.layers["prediction"].data = np.zeros(self._shape, dtype="uint32")
        self._viewer.layers["prediction"].scale = scale

        # The annotations were cleared, so drop rows for labels that are no longer in use.
        self._refresh_label_widget()
310
311
def object_classifier(
    image: np.ndarray,
    segmentation: np.ndarray,
    embedding_path: Optional[Union[str, util.ImageEmbeddings]] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    return_viewer: bool = False,
    viewer: Optional["napari.viewer.Viewer"] = None,
    checkpoint_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
    ndim: Optional[int] = None,
) -> Optional["napari.viewer.Viewer"]:
    """Start the object classifier for a given image and segmentation.

    Args:
        image: The image data.
        segmentation: The segmentation data.
        embedding_path: Filepath where to save the embeddings
            or the precomputed image embeddings computed by `precompute_image_embeddings`.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
            By default, does not return the napari viewer.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        device: The computational device to use for the SAM model.
            By default, automatically chooses the best available device.
        ndim: The dimensionality of the data. If not given will be derived from the data.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
    """
    if ndim is None:
        # A trailing axis of size 3 in 3d / 4d data is treated as RGB channels.
        has_rgb_axis = image.shape[-1] == 3 and image.ndim in (3, 4)
        ndim = image.ndim - 1 if has_rgb_axis else image.ndim

    state = AnnotatorState()
    state.image_shape = image.shape[:ndim]

    state.initialize_predictor(
        image,
        model_type=model_type,
        save_path=embedding_path,
        halo=halo,
        tile_shape=tile_shape,
        precompute_amg_state=False,
        ndim=ndim,
        checkpoint_path=checkpoint_path,
        device=device,
        skip_load=False,
        use_cli=True,
    )

    if viewer is None:
        viewer = napari.Viewer()

    viewer.add_image(image, name="image")
    viewer.add_labels(segmentation, name="segmentation")

    classifier_widget = ObjectClassifier(viewer)
    # Give the annotation and prediction layers the shape of the image.
    classifier_widget._update_image()

    # Dock the annotator and fill the embedding widget with the settings used above.
    viewer.window.add_dock_widget(classifier_widget)
    embedding_model_type = model_type if checkpoint_path is None else state.predictor.model_type
    _sync_embedding_widget(
        widget=state.widgets["embeddings"],
        model_type=embedding_model_type,
        save_path=embedding_path,
        checkpoint_path=checkpoint_path,
        device=device,
        tile_shape=tile_shape,
        halo=halo,
    )

    if not return_viewer:
        napari.run()
        return None
    return viewer
390
391
def image_series_object_classifier(
    images: List[np.ndarray],
    segmentations: List[np.ndarray],
    output_folder: str,
    embedding_paths: Optional[List[Union[str, util.ImageEmbeddings]]] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    checkpoint_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
    ndim: Optional[int] = None,
) -> None:
    """Start the object classifier for a list of images and segmentations.

    The images are annotated one after the other; "Next Image [N]" moves on to the next one.
    This function will save the features and labels of all annotated objects,
    to enable training a random forest on multiple images.

    Args:
        images: The input images.
        segmentations: The input segmentations.
        output_folder: The folder where the predictions, the trained random forest ("rf.joblib")
            and the accumulated features and labels ("features.npy", "labels.npy") will be saved.
        embedding_paths: Filepaths where to save the embeddings
            or the precomputed image embeddings computed by `precompute_image_embeddings`.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        device: The computational device to use for the SAM model.
            By default, automatically chooses the best available device.
        ndim: The dimensionality of the data. If not given will be derived from the data.

    Raises:
        ValueError: If the number of images and segmentations differs.
    """
    # TODO precompute the embeddings if not computed, can re-use 'precompute' from image series annotator.
    # TODO support file paths as inputs
    # TODO option to skip segmented
    if len(images) != len(segmentations):
        raise ValueError(
            f"Expect the same number of images and segmentations, got {len(images)}, {len(segmentations)}."
        )

    end_msg = "You have annotated the last image. Do you wish to close napari?"

    # Initialize the object classifier on the first image / segmentation.
    viewer = object_classifier(
        image=images[0], segmentation=segmentations[0],
        embedding_path=None if embedding_paths is None else embedding_paths[0],
        model_type=model_type, tile_shape=tile_shape, halo=halo,
        return_viewer=True, checkpoint_path=checkpoint_path,
        device=device, ndim=ndim,
    )

    os.makedirs(output_folder, exist_ok=True)
    next_image_id = 0

    def _save_prediction(image, pred, image_id):
        # Name the file after the input file if the image is given as a path, otherwise after its index.
        fname = f"{Path(image).stem}_prediction.tif" if isinstance(image, str) else f"prediction_{image_id}.tif"
        save_path = os.path.join(output_folder, fname)
        imageio.imwrite(save_path, pred, compression="zlib")

    def _ask_to_close():
        # Tell the user that all images are done and close napari unless they decline.
        abort = widgets._generate_message("info", end_msg)
        if not abort:
            viewer.close()

    def _store_annotated_objects(state, segmentation):
        # Append features and labels of the objects annotated on the current image to the
        # accumulated training data in the state, and save the accumulated data.
        labels = _accumulate_labels(segmentation, viewer.layers["annotations"].data)
        valid = labels != 0
        if not valid.any():
            return
        # The features are otherwise only computed when training; compute them here if the user
        # annotated objects on this image without training.
        if state.object_features is None:
            state.seg_ids, state.object_features = compute_object_features(state.image_embeddings, segmentation)
        features, labels = state.object_features[valid], labels[valid]
        if state.previous_features is None:
            state.previous_features, state.previous_labels = features, labels
        else:
            state.previous_features = np.concatenate([state.previous_features, features], axis=0)
            state.previous_labels = np.concatenate([state.previous_labels, labels], axis=0)
        np.save(os.path.join(output_folder, "features.npy"), state.previous_features)
        np.save(os.path.join(output_folder, "labels.npy"), state.previous_labels)

    # TODO: inform the user when moving on without having trained the classifier on the current image.
    # Add functionality for going to the next image.
    @magicgui(call_button="Next Image [N]")
    def next_image(*args):
        nonlocal next_image_id

        # All images were handled already and the user kept napari open: ask again instead of
        # indexing past the end of the image list.
        if next_image_id >= len(images):
            _ask_to_close()
            return

        # Get the state and the current segmentation (note that next image id has not yet been increased)
        state = AnnotatorState()
        segmentation = segmentations[next_image_id]

        # Keep track of the previous features and labels.
        _store_annotated_objects(state, segmentation)

        # Save the current prediction and, once it has been trained, the RF.
        _save_prediction(images[next_image_id], viewer.layers["prediction"].data, next_image_id)
        if state.object_rf is not None:
            dump(state.object_rf, os.path.join(output_folder, "rf.joblib"))

        # Go to the next image.
        next_image_id += 1

        # Check if we are done.
        if next_image_id == len(images):
            _ask_to_close()
            return

        # Get the next image, segmentation and embedding_path.
        image = images[next_image_id]
        segmentation = segmentations[next_image_id]
        embedding_path = None if embedding_paths is None else embedding_paths[next_image_id]

        # Set the new image in the viewer, state and annotator.
        viewer.layers["image"].data = image
        viewer.layers["segmentation"].data = segmentation

        state.initialize_predictor(
            image, model_type=model_type, ndim=ndim,
            save_path=embedding_path,
            tile_shape=tile_shape, halo=halo,
            predictor=state.predictor, device=device,
        )
        state.image_shape = image.shape if image.ndim == ndim else image.shape[:-1]
        state.annotator._update_image()

        # Clear the object features and seg-ids from the state.
        state.object_features = None
        state.seg_ids = None

    viewer.window.add_dock_widget(next_image)

    @viewer.bind_key("n", overwrite=True)
    def _next_image(viewer):
        next_image(viewer)

    napari.run()
521
522
523# TODO: folder annotator
524# TODO: main function
class ObjectClassifier(PyQt6.QtWidgets.QScrollArea):
class ObjectClassifier(QtWidgets.QScrollArea):
    """Widget for classifying segmented objects with a random forest trained on user annotations.

    The widget manages two label layers in the viewer: "annotations", in which the user paints class ids
    onto segmented objects, and "prediction", in which the classifier output is shown.
    """

    def _require_layers(self, layer_choices: Optional[List[str]] = None):
        """Make sure that the "annotations" and "prediction" label layers exist in the viewer.

        Missing layers are added filled with zeros, in the image shape from the annotator state
        (or the dummy shape if no image has been set yet), and with the image scale if one is set.

        Args:
            layer_choices: Names of layers that are expected to exist already. If one of them is missing,
                a validation window is shown before the layer is added again.
        """
        state = AnnotatorState()
        shape = self._shape if state.image_shape is None else state.image_shape
        image_scale = state.image_scale

        # Before adding new layers, we always check whether a layer with this name already exists or not.
        # Each layer gets its own zero array: napari may paint into the given array in place, so a shared
        # array could make painted annotations show up in the prediction layer as well.
        if "annotations" not in self._viewer.layers:
            if layer_choices and "annotations" in layer_choices:
                widgets._validation_window_for_missing_layer("annotations")
            annotation_layer = self._viewer.add_labels(data=np.zeros(shape, dtype="uint32"), name="annotations")
            if image_scale is not None:
                annotation_layer.scale = image_scale
            # Reduce the brush size and set the default mode to "paint" brush mode.
            annotation_layer.brush_size = 3
            annotation_layer.mode = "paint"

        if "prediction" not in self._viewer.layers:
            if layer_choices and "prediction" in layer_choices:
                widgets._validation_window_for_missing_layer("prediction")
            prediction_layer = self._viewer.add_labels(data=np.zeros(shape, dtype="uint32"), name="prediction")
            if image_scale is not None:
                prediction_layer.scale = image_scale

    def _create_segmentation_layer_section(self):
        """Create the layout with the combo box for choosing the segmentation layer.

        The combo box is also registered in the annotator state, from where the training widget reads it.

        Returns:
            The layout with a title label and the combo box.
        """
        segmentation_selection = QtWidgets.QVBoxLayout()
        segmentation_selection.addWidget(QtWidgets.QLabel("Segmentation:"))
        self.segmentation_selection = create_widget(annotation=napari.layers.Labels)
        state = AnnotatorState()
        state.segmentation_selection = self.segmentation_selection
        segmentation_selection.addWidget(self.segmentation_selection.native)
        return segmentation_selection

    def _create_label_widget(self):
        """Create the scrollable form that shows one name field per object label id.

        Returns:
            The layout with a title label and the scroll area holding the form.
        """
        self._label_form = QtWidgets.QFormLayout()
        scroll_area = QtWidgets.QScrollArea()
        inner = QtWidgets.QWidget()
        inner.setLayout(self._label_form)
        scroll_area.setWidget(inner)
        scroll_area.setWidgetResizable(True)

        layout = QtWidgets.QVBoxLayout()
        layout.addWidget(QtWidgets.QLabel("Object label names:"))
        layout.addWidget(scroll_area)
        return layout

    def _refresh_label_widget(self):
        """Synchronize the rows of the label-name form with the label ids that are in use.

        The ids in use are the non-zero values of the "annotations" layer plus the labels collected
        from previous images (`state.previous_labels`). A row with a text field is added for every new id.
        Rows whose id is no longer in use are removed together with the stored name.
        """
        state = AnnotatorState()

        ids = np.unique(self._viewer.layers["annotations"].data)
        # Remove the background explicitly; dropping the first entry would discard a real label
        # if no background pixel is left.
        ids = ids[ids != 0]
        if state.previous_labels is not None:
            ids = np.union1d(ids, np.unique(state.previous_labels))

        # Add new rows.
        for lbl in ids:
            if lbl in self._label_names:
                continue
            self._label_names[lbl] = ""
            line = QtWidgets.QLineEdit("")
            self._label_form.addRow(f"ID {lbl}", line)
            # Bind lbl as a default argument, otherwise every callback would write to the last id.
            line.textChanged.connect(lambda txt, lbl=lbl: self._label_names.__setitem__(lbl, txt))

        # Remove rows whose label vanished; iterate backwards so removals don't shift pending rows.
        active_ids = set(ids.tolist())
        for row in reversed(range(self._label_form.rowCount())):
            lbl_text = self._label_form.itemAt(row, QtWidgets.QFormLayout.LabelRole).widget().text()
            lbl_id = int(lbl_text.split()[1])
            if lbl_id in active_ids:
                continue
            # removeRow also deletes the label and field widgets of the row, so no extra deleteLater is needed.
            self._label_form.removeRow(row)
            self._label_names.pop(lbl_id, None)

    def _create_widgets(self):
        """Create all sub-widgets and collect them in `self._widgets` in display order."""
        # Create the embedding widget and keep its image choices in sync with the viewer's layers.
        self._embedding_widget = widgets.EmbeddingWidget()
        self._viewer.layers.events.inserted.connect(self._embedding_widget.image_selection.reset_choices)
        self._viewer.layers.events.removed.connect(self._embedding_widget.image_selection.reset_choices)
        # Connect the run button with the function to update the image.
        self._embedding_widget.run_button.clicked.connect(self._update_image)

        # Create the widget for training and prediction of the classifier.
        self._train_and_predict_widget = _train_and_predict_rf_widget()

        # Create the widget for segmentation selection; keep its choices in sync like the image choices above.
        self._seg_selection_widget = self._create_segmentation_layer_section()
        self._viewer.layers.events.inserted.connect(self.segmentation_selection.reset_choices)
        self._viewer.layers.events.removed.connect(self.segmentation_selection.reset_choices)

        # Create the widget for displaying the current label state.
        self._label_widget = self._create_label_widget()

        # Create the widget for exporting the RF.
        self._export_rf_widget = _create_export_rf_widget()

        self._widgets = {
            "embeddings": self._embedding_widget,
            "segmentation_selection": self._seg_selection_widget,
            "train_and_predict": self._train_and_predict_widget,
            "label_widget": self._label_widget,
            "export_rf": self._export_rf_widget,
        }

    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
        """Create the GUI for the object classifier.

        Adds the "annotations" and "prediction" layers to the viewer, builds the sub-widgets,
        places each of them in its own group box and registers the annotator in the state.

        Args:
            viewer: The napari viewer.
        """
        super().__init__()
        self._viewer = viewer
        self._annotator_widget = QtWidgets.QWidget()
        self._annotator_widget.setLayout(QtWidgets.QVBoxLayout())

        # Add the layers for annotations and predictions.
        # Initialize with a dummy shape, which is reset to the correct shape once an image is set.
        self._shape = (256, 256)
        self._require_layers()
        self._ndim = len(self._shape)

        self._label_names = {}  # Maps each object label id to the name typed in by the user.
        self._create_widgets()

        # Wrap every sub-widget in its own group box.
        for widget in self._widgets.values():
            widget_frame = QtWidgets.QGroupBox()
            widget_layout = QtWidgets.QVBoxLayout()
            if isinstance(widget, (Container, FunctionGui, Widget)):
                # This is a magicgui type and we need to get the native qt widget.
                widget_layout.addWidget(widget.native)
            elif isinstance(widget, QtWidgets.QLayout):
                widget_layout.addLayout(widget)
            else:
                # This is a qt type and we add the widget directly.
                widget_layout.addWidget(widget)
            widget_frame.setLayout(widget_layout)
            self._annotator_widget.layout().addWidget(widget_frame)

        # Show a name field for every label id that is already annotated.
        self._refresh_label_widget()

        # Register the annotator and its widgets in the state.
        state = AnnotatorState()
        state.annotator = self
        state.widgets = self._widgets

        # Add the widget to the scroll area.
        self.setWidgetResizable(True)  # Allow widget to resize within scroll area.
        self.setWidget(self._annotator_widget)

    def _update_image(self, segmentation_result=None):
        """Reset the annotation and prediction layers for the image stored in the annotator state.

        Nothing happens if `state.skip_recomputing_embeddings` is set or no image has been set yet.
        Otherwise the layers are re-created if missing, filled with zeros in the image shape, given the
        image scale, and the label-name form is refreshed so that it no longer lists cleared labels.

        Args:
            segmentation_result: Unused. This method is connected to the embedding widget's run button,
                so it may receive the signal's argument here.
        """
        state = AnnotatorState()

        # Whether embeddings already exist and avoid clearing objects in layers.
        if state.skip_recomputing_embeddings:
            return

        if state.image_shape is None:
            return

        # Update the dimension and image shape if it has changed.
        if state.image_shape != self._shape:
            self._ndim = len(state.image_shape)
            self._shape = state.image_shape

        # Before we reset the layers, we ensure all expected layers exist.
        self._require_layers()

        scale = state.image_scale

        # Reset all layers.
        self._viewer.layers["annotations"].data = np.zeros(self._shape, dtype="uint32")
        self._viewer.layers["annotations"].scale = scale
        self._viewer.layers["prediction"].data = np.zeros(self._shape, dtype="uint32")
        self._viewer.layers["prediction"].scale = scale

        # The annotations were cleared, so drop rows for labels that are no longer in use.
        self._refresh_label_widget()

QScrollArea(parent: Optional[QWidget] = None)

ObjectClassifier(viewer: napari.viewer.Viewer)
235    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
236        """Create the GUI for the object classifier.
237
238        Args:
239            viewer: The napari viewer.
240        """
241        super().__init__()
242        self._viewer = viewer
243        self._annotator_widget = QtWidgets.QWidget()
244        self._annotator_widget.setLayout(QtWidgets.QVBoxLayout())
245
246        # Add the layers for prompts and segmented obejcts.
247        # Initialize with a dummy shape, which is reset to the correct shape once an image is set.
248        self._shape = (256, 256)
249        self._require_layers()
250        self._ndim = len(self._shape)
251
252        # Create all the widgets and add them to the layout.
253        self._label_names = {}  # The names for the object labels.
254        self._create_widgets()
255
256        # We could refactor this.
257        for widget_name, widget in self._widgets.items():
258            widget_frame = QtWidgets.QGroupBox()
259            widget_layout = QtWidgets.QVBoxLayout()
260            if isinstance(widget, (Container, FunctionGui, Widget)):
261                # This is a magicgui type and we need to get the native qt widget.
262                widget_layout.addWidget(widget.native)
263            elif isinstance(widget, QtWidgets.QLayout):
264                widget_layout.addLayout(widget)
265            else:
266                # This is a qt type and we add the widget directly.
267                widget_layout.addWidget(widget)
268            widget_frame.setLayout(widget_layout)
269            self._annotator_widget.layout().addWidget(widget_frame)
270
271        # Connect the label layer and the refresh function.
272        self._refresh_label_widget()
273
274        # Set the expected annotator class to the state.
275        state = AnnotatorState()
276        state.annotator = self
277
278        # Add the widgets to the state.
279        state.widgets = self._widgets
280
281        # Add the widget to the scroll area.
282        self.setWidgetResizable(True)  # Allow widget to resize within scroll area.
283        self.setWidget(self._annotator_widget)

Create the GUI for the object classifier.

Arguments:
  • viewer: The napari viewer.
def object_classifier( image: numpy.ndarray, segmentation: numpy.ndarray, embedding_path: Union[str, Dict[str, Any], NoneType] = None, model_type: str = 'vit_b_lm', tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, return_viewer: bool = False, viewer: Optional[napari.viewer.Viewer] = None, checkpoint_path: Optional[str] = None, device: Union[str, torch.device, NoneType] = None, ndim: Optional[int] = None) -> Optional[napari.viewer.Viewer]:
def object_classifier(
    image: np.ndarray,
    segmentation: np.ndarray,
    embedding_path: Optional[Union[str, util.ImageEmbeddings]] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    return_viewer: bool = False,
    viewer: Optional["napari.viewer.Viewer"] = None,
    checkpoint_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
    ndim: Optional[int] = None,
) -> Optional["napari.viewer.Viewer"]:
    """Start the object classifier for a given image and segmentation.

    The image and segmentation are added to a napari viewer, together with the
    object classifier widget, after the SAM predictor has been initialized for the image.

    Args:
        image: The image data.
        segmentation: The segmentation data.
        embedding_path: Filepath where to save the embeddings
            or the precomputed image embeddings computed by `precompute_image_embeddings`.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
            By default, does not return the napari viewer.
        viewer: The viewer to which the Segment Anything functionality should be added.
            This enables using a pre-initialized viewer.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        device: The computational device to use for the SAM model.
            By default, automatically chooses the best available device.
        ndim: The dimensionality of the data. If not given, it will be derived from the data.

    Returns:
        The napari viewer, only returned if `return_viewer=True`.
    """
    # Derive the spatial dimensionality if it was not passed explicitly:
    # a trailing axis of size 3 on a 3D or 4D array is treated as RGB channels.
    if ndim is None:
        has_rgb_channels = image.shape[-1] == 3 and image.ndim in (3, 4)
        if has_rgb_channels:
            ndim = image.ndim - 1
        else:
            ndim = image.ndim

    annotator_state = AnnotatorState()
    annotator_state.image_shape = image.shape[:ndim]

    # Load the model and compute (or load) the image embeddings.
    annotator_state.initialize_predictor(
        image,
        model_type=model_type,
        save_path=embedding_path,
        halo=halo,
        tile_shape=tile_shape,
        precompute_amg_state=False,
        ndim=ndim,
        checkpoint_path=checkpoint_path,
        device=device,
        skip_load=False,
        use_cli=True,
    )

    napari_viewer = napari.Viewer() if viewer is None else viewer
    napari_viewer.add_image(image, name="image")
    napari_viewer.add_labels(segmentation, name="segmentation")

    # Constructing the widget also registers it (and its sub-widgets) in the annotator state.
    classifier_widget = ObjectClassifier(napari_viewer)

    # The widget's layers are created with a dummy shape; update them to match the image.
    classifier_widget._update_image()

    napari_viewer.window.add_dock_widget(classifier_widget)

    # For a custom checkpoint, report the model type of the loaded predictor in the widget.
    if checkpoint_path is None:
        synced_model_type = model_type
    else:
        synced_model_type = annotator_state.predictor.model_type
    _sync_embedding_widget(
        widget=annotator_state.widgets["embeddings"],
        model_type=synced_model_type,
        save_path=embedding_path,
        checkpoint_path=checkpoint_path,
        device=device,
        tile_shape=tile_shape,
        halo=halo,
    )

    if not return_viewer:
        napari.run()
        return None
    return napari_viewer

Start the object classifier for a given image and segmentation.

Arguments:
  • image: The image data.
  • segmentation: The segmentation data.
  • embedding_path: Filepath where to save the embeddings or the precomputed image embeddings computed by precompute_image_embeddings.
  • model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
  • tile_shape: Shape of tiles for tiled embedding prediction. If None then the whole image is passed to Segment Anything.
  • halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
  • return_viewer: Whether to return the napari viewer to further modify it before starting the tool. By default, does not return the napari viewer.
  • viewer: The viewer to which the Segment Anything functionality should be added. This enables using a pre-initialized viewer.
  • checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
  • device: The computational device to use for the SAM model. By default, automatically chooses the best available device.
  • ndim: The dimensionality of the data. If not given, it will be derived from the data.
Returns:

The napari viewer, only returned if return_viewer=True.

def image_series_object_classifier( images: List[numpy.ndarray], segmentations: List[numpy.ndarray], output_folder: str, embedding_paths: Optional[List[Union[str, Dict[str, Any]]]] = None, model_type: str = 'vit_b_lm', tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, checkpoint_path: Optional[str] = None, device: Union[str, torch.device, NoneType] = None, ndim: Optional[int] = None) -> None:
def image_series_object_classifier(
    images: List[np.ndarray],
    segmentations: List[np.ndarray],
    output_folder: str,
    embedding_paths: Optional[List[Union[str, util.ImageEmbeddings]]] = None,
    model_type: str = util._DEFAULT_MODEL,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    checkpoint_path: Optional[str] = None,
    device: Optional[Union[str, torch.device]] = None,
    ndim: Optional[int] = None,
) -> None:
    """Start the object classifier for a list of images and segmentations.

    This function will save all features and labels for annotated objects,
    to enable training a random forest on multiple images.

    Args:
        images: The input images.
        segmentations: The input segmentations.
        output_folder: The folder where segmentation results, trained random forest
            and the features, labels aggregated during training will be saved.
        embedding_paths: Filepaths where to save the embeddings
            or the precomputed image embeddings computed by `precompute_image_embeddings`.
            If given, there must be one entry per image.
        model_type: The Segment Anything model to use. For details on the available models check out
            https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
        tile_shape: Shape of tiles for tiled embedding prediction.
            If `None` then the whole image is passed to Segment Anything.
        halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
        checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
        device: The computational device to use for the SAM model.
            By default, automatically chooses the best available device.
        ndim: The dimensionality of the data. If not given, it will be derived from each image.

    Raises:
        ValueError: If no images are given, or if the number of segmentations or
            embedding paths does not match the number of images.
    """
    # TODO precompute the embeddings if not computed, can re-use 'precompute' from image series annotator.
    # TODO support file paths as inputs
    # TODO option to skip segmented
    if len(images) == 0:
        raise ValueError("Expect at least one image, got an empty list.")
    if len(images) != len(segmentations):
        raise ValueError(
            f"Expect the same number of images and segmentations, got {len(images)}, {len(segmentations)}."
        )
    if embedding_paths is not None and len(embedding_paths) != len(images):
        raise ValueError(
            f"Expect the same number of images and embedding paths, got {len(images)}, {len(embedding_paths)}."
        )

    end_msg = "You have annotated the last image. Do you wish to close napari?"

    def _resolve_ndim(image):
        # Use the explicit ndim if given. Otherwise derive it the same way `object_classifier` does:
        # a trailing axis of size 3 on a 3D or 4D array is treated as RGB channels.
        if ndim is not None:
            return ndim
        return image.ndim - 1 if image.shape[-1] == 3 and image.ndim in (3, 4) else image.ndim

    # Initialize the object classifier on the first image / segmentation.
    viewer = object_classifier(
        image=images[0], segmentation=segmentations[0],
        embedding_path=None if embedding_paths is None else embedding_paths[0],
        model_type=model_type, tile_shape=tile_shape, halo=halo,
        return_viewer=True, checkpoint_path=checkpoint_path,
        device=device, ndim=ndim,
    )

    os.makedirs(output_folder, exist_ok=True)
    next_image_id = 0

    def _save_prediction(image, pred, image_id):
        # Name the output after the input file if a path was given, otherwise after the image index.
        fname = f"{Path(image).stem}_prediction.tif" if isinstance(image, str) else f"prediction_{image_id}.tif"
        save_path = os.path.join(output_folder, fname)
        imageio.imwrite(save_path, pred, compression="zlib")

    def _offer_close():
        # Ask the user whether to close napari; the viewer is only closed if the dialog does not report an abort.
        abort = widgets._generate_message("info", end_msg)
        if not abort:
            viewer.close()

    # TODO handle cases where rf for the image was not trained, raise a message, enable continuing
    # Add functionality for going to the next image.
    @magicgui(call_button="Next Image [N]")
    def next_image(*args):
        nonlocal next_image_id

        # All images were already processed but napari was kept open: offer to close again
        # instead of indexing past the end of the image list.
        if next_image_id >= len(images):
            _offer_close()
            return

        # Get the state and the current segmentation (note that next image id has not yet been increased)
        state = AnnotatorState()
        segmentation = segmentations[next_image_id]

        # Keep track of the previous features and labels.
        labels = _accumulate_labels(segmentation, viewer.layers["annotations"].data)
        valid = labels != 0
        if valid.sum() > 0:
            # NOTE(review): requires object features for this image; `state.object_features` is reset
            # to None when switching images, so this fails if no features were computed (see TODO above).
            features, labels = state.object_features[valid], labels[valid]
            if state.previous_features is None:
                state.previous_features, state.previous_labels = features, labels
            else:
                state.previous_features = np.concatenate([state.previous_features, features], axis=0)
                state.previous_labels = np.concatenate([state.previous_labels, labels], axis=0)
            # Save the accumulated features and labels.
            np.save(os.path.join(output_folder, "features.npy"), state.previous_features)
            np.save(os.path.join(output_folder, "labels.npy"), state.previous_labels)

        # Save the current prediction and RF.
        _save_prediction(images[next_image_id], viewer.layers["prediction"].data, next_image_id)
        dump(state.object_rf, os.path.join(output_folder, "rf.joblib"))

        # Go to the next image.
        next_image_id += 1

        # Check if we are done and, if so, inform the user via dialog.
        if next_image_id == len(images):
            _offer_close()
            return

        # Get the next image, segmentation and embedding_path.
        image = images[next_image_id]
        segmentation = segmentations[next_image_id]
        embedding_path = None if embedding_paths is None else embedding_paths[next_image_id]

        # Set the new image in the viewer, state and annotator.
        viewer.layers["image"].data = image
        viewer.layers["segmentation"].data = segmentation

        # Resolve ndim per image so that a default of None is never forwarded to the predictor,
        # and so that the spatial shape is computed consistently with `object_classifier`.
        image_ndim = _resolve_ndim(image)
        state.initialize_predictor(
            image, model_type=model_type, ndim=image_ndim,
            save_path=embedding_path,
            tile_shape=tile_shape, halo=halo,
            predictor=state.predictor, device=device,
        )
        state.image_shape = image.shape[:image_ndim]
        state.annotator._update_image()

        # Clear the object features and seg-ids from the state.
        state.object_features = None
        state.seg_ids = None

    viewer.window.add_dock_widget(next_image)

    @viewer.bind_key("n", overwrite=True)
    def _next_image(viewer):
        next_image(viewer)

    napari.run()

Start the object classifier for a list of images and segmentations.

This function will save all features and labels for annotated objects to enable training a random forest on multiple images.

Arguments:
  • images: The input images.
  • segmentations: The input segmentations.
  • output_folder: The folder where segmentation results, trained random forest and the features, labels aggregated during training will be saved.
  • embedding_paths: Filepaths where to save the embeddings or the precomputed image embeddings computed by precompute_image_embeddings.
  • model_type: The Segment Anything model to use. For details on the available models check out https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models.
  • tile_shape: Shape of tiles for tiled embedding prediction. If None then the whole image is passed to Segment Anything.
  • halo: Shape of the overlap between tiles, which is needed to segment objects on tile borders.
  • checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
  • device: The computational device to use for the SAM model. By default, automatically chooses the best available device.
  • ndim: The dimensionality of the data. If not given, it will be derived from the data.