synapse_net.tools.segmentation_widget

  1import copy
  2import re
  3from typing import Optional, Union
  4
  5import napari
  6import numpy as np
  7import torch
  8
  9from napari.utils.notifications import show_info
 10from qtpy.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QComboBox
 11
 12from .base_widget import BaseWidget
 13from ..inference.inference import _get_model_registry, get_model, run_segmentation, compute_scale_from_voxel_size
 14from ..inference.util import get_default_tiling, get_device
 15
 16
 17def _load_custom_model(model_path: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
 18    model_path = _clean_filepath(model_path)
 19    if device is None:
 20        device = get_device(device)
 21    try:
 22        model = torch.load(model_path, map_location=torch.device(device), weights_only=False)
 23    except Exception as e:
 24        print(e)
 25        print("model path", model_path)
 26        return None
 27    return model
 28
 29
 30def _available_devices():
 31    available_devices = []
 32    for i in ["cuda", "mps", "cpu"]:
 33        try:
 34            device = get_device(i)
 35        except RuntimeError:
 36            pass
 37        else:
 38            available_devices.append(device)
 39    return available_devices
 40
 41
 42def _get_current_tiling(tiling: dict, default_tiling: dict, model_type: str):
 43    # get tiling values from qt objects
 44    for k, v in tiling.items():
 45        for k2, v2 in v.items():
 46            if isinstance(v2, int):
 47                continue
 48            elif hasattr(v2, "value"):  # If it's a QSpinBox, extract the value
 49                tiling[k][k2] = v2.value()
 50            else:
 51                raise TypeError(f"Unexpected type for tiling value: {type(v2)} at {k}/{k2}")
 52    # check if user inputs tiling/halo or not
 53    if default_tiling == tiling:
 54        if "2d" in model_type:
 55            # if its a 2d model expand x,y and set z to 1
 56            tiling = {
 57                "tile": {"x": 512, "y": 512, "z": 1},
 58                "halo": {"x": 64, "y": 64, "z": 1},
 59            }
 60    else:
 61        show_info(f"Using custom tiling: {tiling}")
 62    if "2d" in model_type:
 63        # if its a 2d model set z to 1
 64        tiling["tile"]["z"] = 1
 65        tiling["halo"]["z"] = 0
 66        show_info(f"Using tiling: {tiling}")
 67    return tiling
 68
 69
 70def _clean_filepath(filepath):
 71    """Cleans a given filepath by:
 72    - Removing newline characters (\n)
 73    - Removing escape sequences
 74    - Stripping the 'file://' prefix if present
 75
 76    Args:
 77        filepath (str): The original filepath
 78
 79    Returns:
 80        str: The cleaned filepath
 81    """
 82    # Remove 'file://' prefix if present
 83    if filepath.startswith("file://"):
 84        filepath = filepath[7:]
 85
 86    # Remove escape sequences and newlines
 87    filepath = re.sub(r'\\.', '', filepath)
 88    filepath = filepath.replace('\n', '').replace('\r', '')
 89
 90    return filepath
 91
 92
 93class SegmentationWidget(BaseWidget):
 94    def __init__(self):
 95        super().__init__()
 96
 97        self.viewer = napari.current_viewer()
 98        layout = QVBoxLayout()
 99        self.tiling = {}
100
101        # Create the image selection dropdown.
102        self.image_selector_name = "Image data"
103        self.image_selector_widget = self._create_layer_selector(self.image_selector_name, layer_type="Image")
104
105        # Create buttons and widgets.
106        self.predict_button = QPushButton("Run Segmentation")
107        self.predict_button.clicked.connect(self.on_predict)
108        self.model_selector_widget = self.load_model_widget()
109        self.settings = self._create_settings_widget()
110
111        # Add the widgets to the layout.
112        layout.addWidget(self.image_selector_widget)
113        layout.addWidget(self.model_selector_widget)
114        layout.addWidget(self.settings)
115        layout.addWidget(self.predict_button)
116
117        self.setLayout(layout)
118
119    def load_model_widget(self):
120        model_widget = QWidget()
121        title_label = QLabel("Select Model:")
122
123        # Exclude the models that are only offered through the CLI and not in the plugin.
124        model_list = set(_get_model_registry().urls.keys())
125        # These are the models exlcuded due to their specificity and to keep the menu simple.
126        # TODO: we should at some point update the logic here, to make it easier to support further models
127        # without cluttering the UI.
128        excluded_models = ["vesicles_2d_maus"]
129        model_list = [name for name in model_list if name not in excluded_models]
130
131        models = ["- choose -"] + model_list
132        self.model_selector = QComboBox()
133        self.model_selector.addItems(models)
134        # Create a layout and add the title label and combo box
135        layout = QVBoxLayout()
136        layout.addWidget(title_label)
137        layout.addWidget(self.model_selector)
138
139        # Set layout on the model widget
140        model_widget.setLayout(layout)
141        return model_widget
142
143    def on_predict(self):
144        # Get the model and postprocessing settings.
145        model_type = self.model_selector.currentText()
146        custom_model_path = self.checkpoint_param.text()
147        if model_type == "- choose -" and custom_model_path is None:
148            show_info("INFO: Please choose a model.")
149            return
150
151        device = get_device(self.device_dropdown.currentText())
152
153        # Load the model. Override if user chose custom model
154        if custom_model_path:
155            model = _load_custom_model(custom_model_path, device)
156            if model:
157                show_info(f"INFO: Using custom model from path: {custom_model_path}")
158                model_type = "custom"
159            else:
160                show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}")
161                return
162        else:
163            model = get_model(model_type, device)
164
165        # Get the image data.
166        image = self._get_layer_selector_data(self.image_selector_name)
167        if image is None:
168            show_info("INFO: Please choose an image.")
169            return
170
171        # Get the current tiling.
172        self.tiling = _get_current_tiling(self.tiling, self.default_tiling, model_type)
173
174        # Get the voxel size.
175        metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
176        voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False)
177
178        # Determine the scaling based on the voxel size.
179        scale = None
180        if voxel_size:
181            if model_type == "custom":
182                show_info("INFO: The image is not rescaled for a custom model.")
183            else:
184                # calculate scale so voxel_size is the same as in training
185                scale = compute_scale_from_voxel_size(voxel_size, model_type)
186                scale_info = list(map(lambda x: np.round(x, 2), scale))
187                show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.")
188
189        # Some models require an additional segmentation for inference or postprocessing.
190        # For these models we read out the 'Extra Segmentation' widget.
191        if model_type == "ribbon":  # Currently only the ribbon model needs the extra seg.
192            extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name)
193            resolution = tuple(voxel_size[ax] for ax in "zyx")
194            kwargs = {"extra_segmentation": extra_seg, "resolution": resolution, "min_membrane_size": 50_000}
195        elif model_type == "cristae":  # Cristae model expects 2 3D volumes
196            kwargs = {
197                "extra_segmentation": self._get_layer_selector_data(self.extra_seg_selector_name),
198                "with_channels": True,
199                "channels_to_standardize": [0]
200            }
201        else:
202            kwargs = {}
203        segmentation = run_segmentation(
204            image, model=model, model_type=model_type, tiling=self.tiling, scale=scale, **kwargs
205        )
206
207        # Add the segmentation layer(s).
208        if isinstance(segmentation, dict):
209            for name, seg in segmentation.items():
210                self.viewer.add_labels(seg, name=name, metadata=metadata)
211        else:
212            self.viewer.add_labels(segmentation, name=f"{model_type}", metadata=metadata)
213        show_info(f"INFO: Segmentation of {model_type} added to layers.")
214
215    def _create_settings_widget(self):
216        setting_values = QWidget()
217        # setting_values.setToolTip(get_tooltip("embedding", "settings"))
218        setting_values.setLayout(QVBoxLayout())
219
220        # Create UI for the device.
221        device = "auto"
222        device_options = ["auto"] + _available_devices()
223
224        self.device_dropdown, layout = self._add_choice_param("device", device, device_options)
225        setting_values.layout().addLayout(layout)
226
227        # Create UI for the tile shape.
228        self.default_tiling = get_default_tiling()
229        self.tiling = copy.deepcopy(self.default_tiling)
230        self.tiling["tile"]["x"], self.tiling["tile"]["y"], self.tiling["tile"]["z"], layout = self._add_shape_param(
231            ("tile_x", "tile_y", "tile_z"),
232            (self.default_tiling["tile"]["x"], self.default_tiling["tile"]["y"], self.default_tiling["tile"]["z"]),
233            min_val=0, max_val=2048, step=16,
234            # tooltip=get_tooltip("embedding", "tiling")
235        )
236        setting_values.layout().addLayout(layout)
237
238        # Create UI for the halo.
239        self.tiling["halo"]["x"], self.tiling["halo"]["y"], self.tiling["halo"]["z"], layout = self._add_shape_param(
240            ("halo_x", "halo_y", "halo_z"),
241            (self.default_tiling["halo"]["x"], self.default_tiling["halo"]["y"], self.default_tiling["halo"]["z"]),
242            min_val=0, max_val=512,
243            # tooltip=get_tooltip("embedding", "halo")
244        )
245        setting_values.layout().addLayout(layout)
246
247        # Read voxel size from layer metadata.
248        self.voxel_size_param, layout = self._add_float_param(
249            "voxel_size", 0.0, min_val=0.0, max_val=100.0,
250        )
251        setting_values.layout().addLayout(layout)
252
253        self.checkpoint_param, layout = self._add_string_param(
254            name="checkpoint", value="", title="Load Custom Model",
255            placeholder="path/to/checkpoint.pt",
256        )
257        setting_values.layout().addLayout(layout)
258
259        # Add selection UI for additional segmentation, which some models require for inference or postproc.
260        self.extra_seg_selector_name = "Extra Segmentation"
261        self.extra_selector_widget = self._create_layer_selector(self.extra_seg_selector_name, layer_type="Labels")
262        setting_values.layout().addWidget(self.extra_selector_widget)
263
264        settings = self._make_collapsible(widget=setting_values, title="Advanced Settings")
265        return settings
class SegmentationWidget(synapse_net.tools.base_widget.BaseWidget):
 94class SegmentationWidget(BaseWidget):
 95    def __init__(self):
 96        super().__init__()
 97
 98        self.viewer = napari.current_viewer()
 99        layout = QVBoxLayout()
100        self.tiling = {}
101
102        # Create the image selection dropdown.
103        self.image_selector_name = "Image data"
104        self.image_selector_widget = self._create_layer_selector(self.image_selector_name, layer_type="Image")
105
106        # Create buttons and widgets.
107        self.predict_button = QPushButton("Run Segmentation")
108        self.predict_button.clicked.connect(self.on_predict)
109        self.model_selector_widget = self.load_model_widget()
110        self.settings = self._create_settings_widget()
111
112        # Add the widgets to the layout.
113        layout.addWidget(self.image_selector_widget)
114        layout.addWidget(self.model_selector_widget)
115        layout.addWidget(self.settings)
116        layout.addWidget(self.predict_button)
117
118        self.setLayout(layout)
119
120    def load_model_widget(self):
121        model_widget = QWidget()
122        title_label = QLabel("Select Model:")
123
124        # Exclude the models that are only offered through the CLI and not in the plugin.
125        model_list = set(_get_model_registry().urls.keys())
126        # These are the models exlcuded due to their specificity and to keep the menu simple.
127        # TODO: we should at some point update the logic here, to make it easier to support further models
128        # without cluttering the UI.
129        excluded_models = ["vesicles_2d_maus"]
130        model_list = [name for name in model_list if name not in excluded_models]
131
132        models = ["- choose -"] + model_list
133        self.model_selector = QComboBox()
134        self.model_selector.addItems(models)
135        # Create a layout and add the title label and combo box
136        layout = QVBoxLayout()
137        layout.addWidget(title_label)
138        layout.addWidget(self.model_selector)
139
140        # Set layout on the model widget
141        model_widget.setLayout(layout)
142        return model_widget
143
144    def on_predict(self):
145        # Get the model and postprocessing settings.
146        model_type = self.model_selector.currentText()
147        custom_model_path = self.checkpoint_param.text()
148        if model_type == "- choose -" and custom_model_path is None:
149            show_info("INFO: Please choose a model.")
150            return
151
152        device = get_device(self.device_dropdown.currentText())
153
154        # Load the model. Override if user chose custom model
155        if custom_model_path:
156            model = _load_custom_model(custom_model_path, device)
157            if model:
158                show_info(f"INFO: Using custom model from path: {custom_model_path}")
159                model_type = "custom"
160            else:
161                show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}")
162                return
163        else:
164            model = get_model(model_type, device)
165
166        # Get the image data.
167        image = self._get_layer_selector_data(self.image_selector_name)
168        if image is None:
169            show_info("INFO: Please choose an image.")
170            return
171
172        # Get the current tiling.
173        self.tiling = _get_current_tiling(self.tiling, self.default_tiling, model_type)
174
175        # Get the voxel size.
176        metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
177        voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False)
178
179        # Determine the scaling based on the voxel size.
180        scale = None
181        if voxel_size:
182            if model_type == "custom":
183                show_info("INFO: The image is not rescaled for a custom model.")
184            else:
185                # calculate scale so voxel_size is the same as in training
186                scale = compute_scale_from_voxel_size(voxel_size, model_type)
187                scale_info = list(map(lambda x: np.round(x, 2), scale))
188                show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.")
189
190        # Some models require an additional segmentation for inference or postprocessing.
191        # For these models we read out the 'Extra Segmentation' widget.
192        if model_type == "ribbon":  # Currently only the ribbon model needs the extra seg.
193            extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name)
194            resolution = tuple(voxel_size[ax] for ax in "zyx")
195            kwargs = {"extra_segmentation": extra_seg, "resolution": resolution, "min_membrane_size": 50_000}
196        elif model_type == "cristae":  # Cristae model expects 2 3D volumes
197            kwargs = {
198                "extra_segmentation": self._get_layer_selector_data(self.extra_seg_selector_name),
199                "with_channels": True,
200                "channels_to_standardize": [0]
201            }
202        else:
203            kwargs = {}
204        segmentation = run_segmentation(
205            image, model=model, model_type=model_type, tiling=self.tiling, scale=scale, **kwargs
206        )
207
208        # Add the segmentation layer(s).
209        if isinstance(segmentation, dict):
210            for name, seg in segmentation.items():
211                self.viewer.add_labels(seg, name=name, metadata=metadata)
212        else:
213            self.viewer.add_labels(segmentation, name=f"{model_type}", metadata=metadata)
214        show_info(f"INFO: Segmentation of {model_type} added to layers.")
215
216    def _create_settings_widget(self):
217        setting_values = QWidget()
218        # setting_values.setToolTip(get_tooltip("embedding", "settings"))
219        setting_values.setLayout(QVBoxLayout())
220
221        # Create UI for the device.
222        device = "auto"
223        device_options = ["auto"] + _available_devices()
224
225        self.device_dropdown, layout = self._add_choice_param("device", device, device_options)
226        setting_values.layout().addLayout(layout)
227
228        # Create UI for the tile shape.
229        self.default_tiling = get_default_tiling()
230        self.tiling = copy.deepcopy(self.default_tiling)
231        self.tiling["tile"]["x"], self.tiling["tile"]["y"], self.tiling["tile"]["z"], layout = self._add_shape_param(
232            ("tile_x", "tile_y", "tile_z"),
233            (self.default_tiling["tile"]["x"], self.default_tiling["tile"]["y"], self.default_tiling["tile"]["z"]),
234            min_val=0, max_val=2048, step=16,
235            # tooltip=get_tooltip("embedding", "tiling")
236        )
237        setting_values.layout().addLayout(layout)
238
239        # Create UI for the halo.
240        self.tiling["halo"]["x"], self.tiling["halo"]["y"], self.tiling["halo"]["z"], layout = self._add_shape_param(
241            ("halo_x", "halo_y", "halo_z"),
242            (self.default_tiling["halo"]["x"], self.default_tiling["halo"]["y"], self.default_tiling["halo"]["z"]),
243            min_val=0, max_val=512,
244            # tooltip=get_tooltip("embedding", "halo")
245        )
246        setting_values.layout().addLayout(layout)
247
248        # Read voxel size from layer metadata.
249        self.voxel_size_param, layout = self._add_float_param(
250            "voxel_size", 0.0, min_val=0.0, max_val=100.0,
251        )
252        setting_values.layout().addLayout(layout)
253
254        self.checkpoint_param, layout = self._add_string_param(
255            name="checkpoint", value="", title="Load Custom Model",
256            placeholder="path/to/checkpoint.pt",
257        )
258        setting_values.layout().addLayout(layout)
259
260        # Add selection UI for additional segmentation, which some models require for inference or postproc.
261        self.extra_seg_selector_name = "Extra Segmentation"
262        self.extra_selector_widget = self._create_layer_selector(self.extra_seg_selector_name, layer_type="Labels")
263        setting_values.layout().addWidget(self.extra_selector_widget)
264
265        settings = self._make_collapsible(widget=setting_values, title="Advanced Settings")
266        return settings

QWidget(parent: Optional[QWidget] = None, flags: Union[Qt.WindowFlags, Qt.WindowType] = Qt.WindowFlags())

viewer
tiling
image_selector_name
image_selector_widget
predict_button
model_selector_widget
settings
def load_model_widget(self):
120    def load_model_widget(self):
121        model_widget = QWidget()
122        title_label = QLabel("Select Model:")
123
124        # Exclude the models that are only offered through the CLI and not in the plugin.
125        model_list = set(_get_model_registry().urls.keys())
126        # These are the models exlcuded due to their specificity and to keep the menu simple.
127        # TODO: we should at some point update the logic here, to make it easier to support further models
128        # without cluttering the UI.
129        excluded_models = ["vesicles_2d_maus"]
130        model_list = [name for name in model_list if name not in excluded_models]
131
132        models = ["- choose -"] + model_list
133        self.model_selector = QComboBox()
134        self.model_selector.addItems(models)
135        # Create a layout and add the title label and combo box
136        layout = QVBoxLayout()
137        layout.addWidget(title_label)
138        layout.addWidget(self.model_selector)
139
140        # Set layout on the model widget
141        model_widget.setLayout(layout)
142        return model_widget
def on_predict(self):
144    def on_predict(self):
145        # Get the model and postprocessing settings.
146        model_type = self.model_selector.currentText()
147        custom_model_path = self.checkpoint_param.text()
148        if model_type == "- choose -" and custom_model_path is None:
149            show_info("INFO: Please choose a model.")
150            return
151
152        device = get_device(self.device_dropdown.currentText())
153
154        # Load the model. Override if user chose custom model
155        if custom_model_path:
156            model = _load_custom_model(custom_model_path, device)
157            if model:
158                show_info(f"INFO: Using custom model from path: {custom_model_path}")
159                model_type = "custom"
160            else:
161                show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}")
162                return
163        else:
164            model = get_model(model_type, device)
165
166        # Get the image data.
167        image = self._get_layer_selector_data(self.image_selector_name)
168        if image is None:
169            show_info("INFO: Please choose an image.")
170            return
171
172        # Get the current tiling.
173        self.tiling = _get_current_tiling(self.tiling, self.default_tiling, model_type)
174
175        # Get the voxel size.
176        metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
177        voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False)
178
179        # Determine the scaling based on the voxel size.
180        scale = None
181        if voxel_size:
182            if model_type == "custom":
183                show_info("INFO: The image is not rescaled for a custom model.")
184            else:
185                # calculate scale so voxel_size is the same as in training
186                scale = compute_scale_from_voxel_size(voxel_size, model_type)
187                scale_info = list(map(lambda x: np.round(x, 2), scale))
188                show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.")
189
190        # Some models require an additional segmentation for inference or postprocessing.
191        # For these models we read out the 'Extra Segmentation' widget.
192        if model_type == "ribbon":  # Currently only the ribbon model needs the extra seg.
193            extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name)
194            resolution = tuple(voxel_size[ax] for ax in "zyx")
195            kwargs = {"extra_segmentation": extra_seg, "resolution": resolution, "min_membrane_size": 50_000}
196        elif model_type == "cristae":  # Cristae model expects 2 3D volumes
197            kwargs = {
198                "extra_segmentation": self._get_layer_selector_data(self.extra_seg_selector_name),
199                "with_channels": True,
200                "channels_to_standardize": [0]
201            }
202        else:
203            kwargs = {}
204        segmentation = run_segmentation(
205            image, model=model, model_type=model_type, tiling=self.tiling, scale=scale, **kwargs
206        )
207
208        # Add the segmentation layer(s).
209        if isinstance(segmentation, dict):
210            for name, seg in segmentation.items():
211                self.viewer.add_labels(seg, name=name, metadata=metadata)
212        else:
213            self.viewer.add_labels(segmentation, name=f"{model_type}", metadata=metadata)
214        show_info(f"INFO: Segmentation of {model_type} added to layers.")