micro_sam.sam_annotator._annotator

  1from typing import Optional, List
  2
  3import numpy as np
  4
  5import napari
  6from qtpy import QtWidgets
  7from magicgui.widgets import Widget, Container, FunctionGui
  8
  9from . import util as vutil
 10from . import _widgets as widgets
 11from ._state import AnnotatorState
 12
 13
 14class _AnnotatorBase(QtWidgets.QScrollArea):
 15    """Base class for micro_sam annotation plugins.
 16
 17    Implements the logic for the 2d, 3d and tracking annotator.
 18    The annotators differ in their data dimensionality and the widgets.
 19    """
 20
 21    def _require_layers(self, layer_choices: Optional[List[str]] = None):
 22
 23        # Check whether the image is initialized already. And use the image shape and scale for the layers.
 24        state = AnnotatorState()
 25        shape = self._shape if state.image_shape is None else state.image_shape
 26
 27        # Add the label layers for the current object, the automatic segmentation and the committed segmentation.
 28        dummy_data = np.zeros(shape, dtype="uint32")
 29        image_scale = state.image_scale
 30
 31        # Before adding new layers, we always check whether a layer with this name already exists or not.
 32        if "current_object" not in self._viewer.layers:
 33            if layer_choices and "current_object" in layer_choices:  # Check at 'commit' call button.
 34                widgets._validation_window_for_missing_layer("current_object")
 35            self._viewer.add_labels(data=dummy_data, name="current_object")
 36            if image_scale is not None:
 37                self._viewer.layers["current_object"].scale = image_scale
 38
 39        if "auto_segmentation" not in self._viewer.layers:
 40            if layer_choices and "auto_segmentation" in layer_choices:  # Check at 'commit' call button.
 41                widgets._validation_window_for_missing_layer("auto_segmentation")
 42            self._viewer.add_labels(data=dummy_data, name="auto_segmentation")
 43            if image_scale is not None:
 44                self._viewer.layers["auto_segmentation"].scale = image_scale
 45
 46        if "committed_objects" not in self._viewer.layers:
 47            if layer_choices and "committed_objects" in layer_choices:  # Check at 'commit' call button.
 48                widgets._validation_window_for_missing_layer("committed_objects")
 49            self._viewer.add_labels(data=dummy_data, name="committed_objects")
 50            # Randomize colors so it is easy to see when object committed.
 51            self._viewer.layers["committed_objects"].new_colormap()
 52            if image_scale is not None:
 53                self._viewer.layers["committed_objects"].scale = image_scale
 54
 55        # Add the point layer for point prompts.
 56        self._point_labels = ["positive", "negative"]
 57        if "point_prompts" in self._viewer.layers:
 58            self._point_prompt_layer = self._viewer.layers["point_prompts"]
 59        else:
 60            self._point_prompt_layer = self._viewer.add_points(
 61                name="point_prompts",
 62                property_choices={"label": self._point_labels},
 63                border_color="label",
 64                border_color_cycle=vutil.LABEL_COLOR_CYCLE,
 65                symbol="o",
 66                face_color="transparent",
 67                border_width=0.5,
 68                size=12,
 69                ndim=self._ndim,
 70            )
 71            self._point_prompt_layer.border_color_mode = "cycle"
 72
 73        if "prompts" not in self._viewer.layers:
 74            # Add the shape layer for box and other shape prompts.
 75            self._viewer.add_shapes(
 76                face_color="transparent", edge_color="green", edge_width=4, name="prompts", ndim=self._ndim,
 77            )
 78
 79    # Child classes have to implement this function and create a dictionary with the widgets.
 80    def _get_widgets(self):
 81        raise NotImplementedError("The child classes of _AnnotatorBase have to implement _get_widgets.")
 82
 83    def _create_widgets(self):
 84        # Create the embedding widget and connect all events related to it.
 85        self._embedding_widget = widgets.EmbeddingWidget()
 86        # Connect events for the image selection box.
 87        self._viewer.layers.events.inserted.connect(self._embedding_widget.image_selection.reset_choices)
 88        self._viewer.layers.events.removed.connect(self._embedding_widget.image_selection.reset_choices)
 89        # Connect the run button with the function to update the image.
 90        self._embedding_widget.run_button.clicked.connect(self._update_image)
 91
 92        # Create the prompt widget. (The same for all plugins.)
 93        self._prompt_widget = widgets.create_prompt_menu(self._point_prompt_layer, self._point_labels)
 94
 95        # Create the dictionary for the widgets and get the widgets of the child plugin.
 96        self._widgets = {"embeddings": self._embedding_widget, "prompts": self._prompt_widget}
 97        self._widgets.update(self._get_widgets())
 98
 99    def _create_keybindings(self):
100        @self._viewer.bind_key("s", overwrite=True)
101        def _segment(viewer):
102            self._widgets["segment"](viewer)
103
104        # Note: we also need to over-write the keybindings for specific layers.
105        # See https://github.com/napari/napari/issues/7302 for details.
106        # Here, we need to over-write the 's' keybinding for both of the prompt layers.
107        prompt_layer = self._viewer.layers["prompts"]
108        point_prompt_layer = self._viewer.layers["point_prompts"]
109
110        @prompt_layer.bind_key("s", overwrite=True)
111        def _segment_prompts(event):
112            self._widgets["segment"](self._viewer)
113
114        @point_prompt_layer.bind_key("s", overwrite=True)
115        def _segment_point_prompts(event):
116            self._widgets["segment"](self._viewer)
117
118        @self._viewer.bind_key("c", overwrite=True)
119        def _commit(viewer):
120            self._widgets["commit"](viewer)
121
122        @self._viewer.bind_key("t", overwrite=True)
123        def _toggle_label(event=None):
124            vutil.toggle_label(self._point_prompt_layer)
125
126        @self._viewer.bind_key("Shift-C", overwrite=True)
127        def _clear_annotations(viewer):
128            self._widgets["clear"](viewer)
129
130        if "segment_nd" in self._widgets:
131            @self._viewer.bind_key("Shift-S", overwrite=True)
132            def _seg_nd(viewer):
133                self._widgets["segment_nd"]()
134
135    # We could implement a better way of initializing the segmentation result,
136    # so that instead of just passing a numpy array an existing layer from the napari
137    # viewer can be chosen.
138    # See https://github.com/computational-cell-analytics/micro-sam/issues/335
139    def __init__(self, viewer: "napari.viewer.Viewer", ndim: int) -> None:
140        """Create the annotator GUI.
141
142        Args:
143            viewer: The napari viewer.
144            ndim: The number of spatial dimension of the image data (2 or 3).
145        """
146        super().__init__()
147        self._viewer = viewer
148        self._annotator_widget = QtWidgets.QWidget()
149        self._annotator_widget.setLayout(QtWidgets.QVBoxLayout())
150
151        # Add the layers for prompts and segmented obejcts.
152        # Initialize with a dummy shape, which is reset to the correct shape once an image is set.
153        self._ndim = ndim
154        self._shape = (256, 256) if ndim == 2 else (16, 256, 256)
155        self._require_layers()
156
157        # Create all the widgets and add them to the layout.
158        self._create_widgets()
159        for widget in self._widgets.values():
160            widget_frame = QtWidgets.QGroupBox()
161            widget_layout = QtWidgets.QVBoxLayout()
162            if isinstance(widget, (Container, FunctionGui, Widget)):
163                # This is a magicgui type and we need to get the native qt widget.
164                widget_layout.addWidget(widget.native)
165            else:
166                # This is a qt type and we add the widget directly.
167                widget_layout.addWidget(widget)
168            widget_frame.setLayout(widget_layout)
169            self._annotator_widget.layout().addWidget(widget_frame)
170
171        # Add the widgets to the state.
172        AnnotatorState().widgets = self._widgets
173
174        # Add the key bindings in common between all annotators.
175        self._create_keybindings()
176
177        # Add the widget to the scroll area.
178        self.setWidgetResizable(True)  # Allow widget to resize within scroll area.
179        self.setWidget(self._annotator_widget)
180
181    def _update_image(self, segmentation_result=None):
182        state = AnnotatorState()
183
184        # Whether embeddings already exist and avoid clearing objects in layers.
185        if state.skip_recomputing_embeddings:
186            return
187
188        # This is encountered when there is no image layer available / selected.
189        # In this case, we need not update the image shape or check for changes.
190        # NOTE: On code-level, this happens when '__init__' method is called by '_AnnotatorBase',
191        #       where one of the first steps is to '_create_widgets', which reaches here.
192        if state.image_shape is None:
193            return
194
195        # Update the image shape if it has changed.
196        if state.image_shape != self._shape:
197            if len(state.image_shape) != self._ndim:
198                raise RuntimeError(
199                    f"The dim of the annotator {self._ndim} does not match the image data of shape {state.image_shape}."
200                )
201            self._shape = state.image_shape
202
203        # Before we reset the layers, we ensure all expected layers exist.
204        self._require_layers()
205
206        # Update the image scale.
207        scale = state.image_scale
208
209        # Reset all layers.
210        self._viewer.layers["current_object"].data = np.zeros(self._shape, dtype="uint32")
211        self._viewer.layers["current_object"].scale = scale
212        self._viewer.layers["auto_segmentation"].data = np.zeros(self._shape, dtype="uint32")
213        self._viewer.layers["auto_segmentation"].scale = scale
214
215        if segmentation_result is None or segmentation_result is False:
216            self._viewer.layers["committed_objects"].data = np.zeros(self._shape, dtype="uint32")
217        else:
218            assert segmentation_result.shape == self._shape
219            self._viewer.layers["committed_objects"].data = segmentation_result
220        self._viewer.layers["committed_objects"].scale = scale
221
222        self._viewer.layers["point_prompts"].scale = scale
223        self._viewer.layers["prompts"].scale = scale
224
225        vutil.clear_annotations(self._viewer, clear_segmentations=False)