1from typing import Optional, List
2
3import numpy as np
4
5import napari
6from qtpy import QtWidgets
7from magicgui.widgets import Widget, Container, FunctionGui
8
9from . import util as vutil
10from . import _widgets as widgets
11from ._state import AnnotatorState
12
13
class _AnnotatorBase(QtWidgets.QScrollArea):
    """Base class for micro_sam annotation plugins.

    Implements the logic for the 2d, 3d and tracking annotator.
    The annotators differ in their data dimensionality and the widgets.
    Child classes have to implement `_get_widgets`.
    """

    def _require_layers(self, layer_choices: Optional[List[str]] = None):
        """Ensure that all layers used by the annotator exist in the viewer.

        Layers that already exist are left untouched. Missing label layers are created filled
        with zeros, using the image shape from the annotator state if an image has been set and
        the dummy shape `self._shape` otherwise.

        Args:
            layer_choices: Names of layers that are expected to exist at this point (used by the
                'commit' call button). If one of these layers is missing, a validation window is
                shown before the layer is re-created.
        """
        # Check whether the image is initialized already. And use the image shape and scale for the layers.
        state = AnnotatorState()
        shape = self._shape if state.image_shape is None else state.image_shape
        image_scale = state.image_scale

        # Add the label layers for the current object, the automatic segmentation and the committed segmentation.
        # Before adding new layers, we always check whether a layer with this name already exists or not.
        for layer_name in ("current_object", "auto_segmentation", "committed_objects"):
            if layer_name in self._viewer.layers:
                continue
            if layer_choices and layer_name in layer_choices:  # Check at 'commit' call button.
                widgets._validation_window_for_missing_layer(layer_name)
            # Every layer gets its own zero buffer: napari's Labels layer paints into the array it is
            # given (no copy), so sharing one array would make edits in one layer leak into the others.
            layer = self._viewer.add_labels(data=np.zeros(shape, dtype="uint32"), name=layer_name)
            if layer_name == "committed_objects":
                # Randomize colors so it is easy to see when object committed.
                layer.new_colormap()
            if image_scale is not None:
                layer.scale = image_scale

        # Add the point layer for point prompts.
        self._point_labels = ["positive", "negative"]
        if "point_prompts" in self._viewer.layers:
            self._point_prompt_layer = self._viewer.layers["point_prompts"]
        else:
            self._point_prompt_layer = self._viewer.add_points(
                name="point_prompts",
                property_choices={"label": self._point_labels},
                border_color="label",
                border_color_cycle=vutil.LABEL_COLOR_CYCLE,
                symbol="o",
                face_color="transparent",
                border_width=0.5,
                size=12,
                ndim=self._ndim,
            )
            self._point_prompt_layer.border_color_mode = "cycle"

        if "prompts" not in self._viewer.layers:
            # Add the shape layer for box and other shape prompts.
            self._viewer.add_shapes(
                face_color="transparent", edge_color="green", edge_width=4, name="prompts", ndim=self._ndim,
            )

    # Child classes have to implement this function and create a dictionary with the widgets.
    def _get_widgets(self):
        """Return a dictionary mapping widget names to the plugin-specific widgets.

        The keybindings in `_create_keybindings` look up the keys "segment", "commit" and "clear",
        and optionally "segment_nd".

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError("The child classes of _AnnotatorBase have to implement _get_widgets.")

    def _create_widgets(self):
        """Create the shared embedding and prompt widgets and collect all widgets in `self._widgets`."""
        # Create the embedding widget and connect all events related to it.
        self._embedding_widget = widgets.EmbeddingWidget()
        # Connect events for the image selection box.
        self._viewer.layers.events.inserted.connect(self._embedding_widget.image_selection.reset_choices)
        self._viewer.layers.events.removed.connect(self._embedding_widget.image_selection.reset_choices)
        # Connect the run button with the function to update the image.
        self._embedding_widget.run_button.clicked.connect(self._update_image)

        # Create the prompt widget. (The same for all plugins.)
        self._prompt_widget = widgets.create_prompt_menu(self._point_prompt_layer, self._point_labels)

        # Create the dictionary for the widgets and get the widgets of the child plugin.
        self._widgets = {"embeddings": self._embedding_widget, "prompts": self._prompt_widget}
        self._widgets.update(self._get_widgets())

    def _create_keybindings(self):
        """Bind the keyboard shortcuts shared by all annotators to the viewer and the prompt layers."""
        @self._viewer.bind_key("s", overwrite=True)
        def _segment(viewer):
            self._widgets["segment"](viewer)

        # Note: we also need to over-write the keybindings for specific layers.
        # See https://github.com/napari/napari/issues/7302 for details.
        # Here, we need to over-write the 's' keybinding for both of the prompt layers.
        prompt_layer = self._viewer.layers["prompts"]
        point_prompt_layer = self._viewer.layers["point_prompts"]

        @prompt_layer.bind_key("s", overwrite=True)
        def _segment_prompts(event):
            self._widgets["segment"](self._viewer)

        @point_prompt_layer.bind_key("s", overwrite=True)
        def _segment_point_prompts(event):
            self._widgets["segment"](self._viewer)

        @self._viewer.bind_key("c", overwrite=True)
        def _commit(viewer):
            self._widgets["commit"](viewer)

        @self._viewer.bind_key("t", overwrite=True)
        def _toggle_label(event=None):
            vutil.toggle_label(self._point_prompt_layer)

        @self._viewer.bind_key("Shift-C", overwrite=True)
        def _clear_annotations(viewer):
            self._widgets["clear"](viewer)

        # Only annotators that provide a "segment_nd" widget get the Shift-S shortcut.
        if "segment_nd" in self._widgets:
            @self._viewer.bind_key("Shift-S", overwrite=True)
            def _seg_nd(viewer):
                self._widgets["segment_nd"]()

    # We could implement a better way of initializing the segmentation result,
    # so that instead of just passing a numpy array an existing layer from the napari
    # viewer can be chosen.
    # See https://github.com/computational-cell-analytics/micro-sam/issues/335
    def __init__(self, viewer: "napari.viewer.Viewer", ndim: int) -> None:
        """Create the annotator GUI.

        Args:
            viewer: The napari viewer.
            ndim: The number of spatial dimensions of the image data (2 or 3).
        """
        super().__init__()
        self._viewer = viewer
        self._annotator_widget = QtWidgets.QWidget()
        self._annotator_widget.setLayout(QtWidgets.QVBoxLayout())

        # Add the layers for prompts and segmented objects.
        # Initialize with a dummy shape, which is reset to the correct shape once an image is set.
        self._ndim = ndim
        self._shape = (256, 256) if ndim == 2 else (16, 256, 256)
        self._require_layers()

        # Create all the widgets and add them to the layout.
        self._create_widgets()
        for widget in self._widgets.values():
            widget_frame = QtWidgets.QGroupBox()
            widget_layout = QtWidgets.QVBoxLayout()
            if isinstance(widget, (Container, FunctionGui, Widget)):
                # This is a magicgui type and we need to get the native qt widget.
                widget_layout.addWidget(widget.native)
            else:
                # This is a qt type and we add the widget directly.
                widget_layout.addWidget(widget)
            widget_frame.setLayout(widget_layout)
            self._annotator_widget.layout().addWidget(widget_frame)

        # Add the widgets to the state.
        AnnotatorState().widgets = self._widgets

        # Add the key bindings in common between all annotators.
        self._create_keybindings()

        # Add the widget to the scroll area.
        self.setWidgetResizable(True)  # Allow widget to resize within scroll area.
        self.setWidget(self._annotator_widget)

    def _update_image(self, segmentation_result=None):
        """Reset all annotator layers to match the image currently set in the annotator state.

        Args:
            segmentation_result: Optional segmentation used to initialize the committed objects.
                It must have the same shape as the image. `False` is treated like `None`. This is
                presumably because this method is connected to the run button's `clicked` signal,
                which passes its checked state.

        Raises:
            RuntimeError: If the dimensionality of the image does not match the annotator.
            ValueError: If `segmentation_result` does not have the shape of the image.
        """
        state = AnnotatorState()

        # Whether embeddings already exist and avoid clearing objects in layers.
        if state.skip_recomputing_embeddings:
            return

        # This is encountered when there is no image layer available / selected.
        # In this case, we need not update the image shape or check for changes.
        # NOTE: On code-level, this happens when '__init__' method is called by '_AnnotatorBase',
        # where one of the first steps is to '_create_widgets', which reaches here.
        if state.image_shape is None:
            return

        # Update the image shape if it has changed.
        if state.image_shape != self._shape:
            if len(state.image_shape) != self._ndim:
                raise RuntimeError(
                    f"The dim of the annotator {self._ndim} does not match the image data of shape {state.image_shape}."
                )
            self._shape = state.image_shape

        # Before we reset the layers, we ensure all expected layers exist.
        self._require_layers()

        # Update the image scale.
        scale = state.image_scale

        # Reset all layers.
        self._viewer.layers["current_object"].data = np.zeros(self._shape, dtype="uint32")
        self._viewer.layers["current_object"].scale = scale
        self._viewer.layers["auto_segmentation"].data = np.zeros(self._shape, dtype="uint32")
        self._viewer.layers["auto_segmentation"].scale = scale

        if segmentation_result is None or segmentation_result is False:
            self._viewer.layers["committed_objects"].data = np.zeros(self._shape, dtype="uint32")
        else:
            # Explicit check instead of 'assert' (which is stripped under -O); compare as tuples
            # so the check does not depend on whether the stored shape is a tuple or a list.
            if tuple(segmentation_result.shape) != tuple(self._shape):
                raise ValueError(
                    f"The segmentation result of shape {segmentation_result.shape} does not match "
                    f"the image shape {self._shape}."
                )
            self._viewer.layers["committed_objects"].data = segmentation_result
        self._viewer.layers["committed_objects"].scale = scale

        self._viewer.layers["point_prompts"].scale = scale
        self._viewer.layers["prompts"].scale = scale

        vutil.clear_annotations(self._viewer, clear_segmentations=False)