synapse_net.tools.segmentation_widget

  1import copy
  2import re
  3from typing import Optional, Union
  4
  5import napari
  6import numpy as np
  7import torch
  8
  9from napari.utils.notifications import show_info
 10from qtpy.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QComboBox
 11
 12from .base_widget import BaseWidget
 13from ..inference.inference import _get_model_registry, get_model, run_segmentation, compute_scale_from_voxel_size
 14from ..inference.util import get_default_tiling, get_device
 15
 16
 17def _load_custom_model(model_path: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
 18    model_path = _clean_filepath(model_path)
 19    if device is None:
 20        device = get_device(device)
 21    try:
 22        model = torch.load(model_path, map_location=torch.device(device), weights_only=False)
 23    except Exception as e:
 24        print(e)
 25        print("model path", model_path)
 26        return None
 27    return model
 28
 29
 30def _available_devices():
 31    available_devices = []
 32    for i in ["cuda", "mps", "cpu"]:
 33        try:
 34            device = get_device(i)
 35        except RuntimeError:
 36            pass
 37        else:
 38            available_devices.append(device)
 39    return available_devices
 40
 41
 42def _get_current_tiling(tiling: dict, default_tiling: dict, model_type: str):
 43    # get tiling values from qt objects
 44    for k, v in tiling.items():
 45        for k2, v2 in v.items():
 46            if isinstance(v2, int):
 47                continue
 48            elif hasattr(v2, "value"):  # If it's a QSpinBox, extract the value
 49                tiling[k][k2] = v2.value()
 50            else:
 51                raise TypeError(f"Unexpected type for tiling value: {type(v2)} at {k}/{k2}")
 52    # check if user inputs tiling/halo or not
 53    if default_tiling == tiling:
 54        if "2d" in model_type:
 55            # if its a 2d model expand x,y and set z to 1
 56            tiling = {
 57                "tile": {"x": 512, "y": 512, "z": 1},
 58                "halo": {"x": 64, "y": 64, "z": 1},
 59            }
 60    else:
 61        show_info(f"Using custom tiling: {tiling}")
 62    if "2d" in model_type:
 63        # if its a 2d model set z to 1
 64        tiling["tile"]["z"] = 1
 65        tiling["halo"]["z"] = 0
 66        show_info(f"Using tiling: {tiling}")
 67    return tiling
 68
 69
 70def _clean_filepath(filepath):
 71    """Cleans a given filepath by:
 72    - Removing newline characters (\n)
 73    - Removing escape sequences
 74    - Stripping the 'file://' prefix if present
 75
 76    Args:
 77        filepath (str): The original filepath
 78
 79    Returns:
 80        str: The cleaned filepath
 81    """
 82    # Remove 'file://' prefix if present
 83    if filepath.startswith("file://"):
 84        filepath = filepath[7:]
 85
 86    # Remove escape sequences and newlines
 87    filepath = re.sub(r'\\.', '', filepath)
 88    filepath = filepath.replace('\n', '').replace('\r', '')
 89
 90    return filepath
 91
 92
 93class SegmentationWidget(BaseWidget):
 94    def __init__(self):
 95        super().__init__()
 96
 97        self.viewer = napari.current_viewer()
 98        layout = QVBoxLayout()
 99        self.tiling = {}
100
101        # Create the image selection dropdown.
102        self.image_selector_name = "Image data"
103        self.image_selector_widget = self._create_layer_selector(self.image_selector_name, layer_type="Image")
104
105        # Create buttons and widgets.
106        self.predict_button = QPushButton("Run Segmentation")
107        self.predict_button.clicked.connect(self.on_predict)
108        self.model_selector_widget = self.load_model_widget()
109        self.settings = self._create_settings_widget()
110
111        # Add the widgets to the layout.
112        layout.addWidget(self.image_selector_widget)
113        layout.addWidget(self.model_selector_widget)
114        layout.addWidget(self.settings)
115        layout.addWidget(self.predict_button)
116
117        self.setLayout(layout)
118
119    def load_model_widget(self):
120        model_widget = QWidget()
121        title_label = QLabel("Select Model:")
122
123        # Exclude the models that are only offered through the CLI and not in the plugin.
124        model_list = set(_get_model_registry().urls.keys())
125        # These are the models exlcuded due to their specificity and to keep the menu simple.
126        # TODO: we should at some point update the logic here, to make it easier to support further models
127        # without cluttering the UI.
128        excluded_models = ["vesicles_2d_maus"]
129        model_list = [name for name in model_list if name not in excluded_models]
130
131        models = ["- choose -"] + model_list
132        self.model_selector = QComboBox()
133        self.model_selector.addItems(models)
134        # Create a layout and add the title label and combo box
135        layout = QVBoxLayout()
136        layout.addWidget(title_label)
137        layout.addWidget(self.model_selector)
138
139        # Set layout on the model widget
140        model_widget.setLayout(layout)
141        return model_widget
142
143    def on_predict(self):
144        # Get the model and postprocessing settings.
145        model_type = self.model_selector.currentText()
146        custom_model_path = self.checkpoint_param.text()
147        if model_type == "- choose -":
148            show_info("INFO: Please choose a model.")
149            return
150
151        device = get_device(self.device_dropdown.currentText())
152
153        # Load the model. Override if user chose custom model.
154        rescale_input = True
155        if custom_model_path:
156            model = _load_custom_model(custom_model_path, device)
157            rescale_input = False
158            if model:
159                show_info(f"INFO: Using custom model from path: {custom_model_path}")
160            else:
161                show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}")
162                return
163        else:
164            model = get_model(model_type, device)
165
166        # Get the image data.
167        image = self._get_layer_selector_data(self.image_selector_name)
168        if image is None:
169            show_info("INFO: Please choose an image.")
170            return
171
172        # Get the current tiling.
173        self.tiling = _get_current_tiling(self.tiling, self.default_tiling, model_type)
174
175        # Get the voxel size.
176        metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
177        voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False)
178
179        # Determine the scaling based on the voxel size.
180        scale = None
181        if voxel_size and rescale_input:
182            # Calculate scale so voxel_size is the same as in training.
183            scale = compute_scale_from_voxel_size(voxel_size, model_type)
184            scale_info = list(map(lambda x: np.round(x, 2), scale))
185            show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.")
186
187        # Some models require an additional segmentation for inference or postprocessing.
188        # For these models we read out the 'Extra Segmentation' widget.
189        if model_type == "ribbon":  # Currently only the ribbon model needs the extra seg.
190            extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name)
191            resolution = tuple(voxel_size[ax] for ax in "zyx")
192            kwargs = {"extra_segmentation": extra_seg, "resolution": resolution, "min_membrane_size": 50_000}
193        elif model_type.startswith("cristae"):  # Cristae model expects 2 3D volumes
194            kwargs = {
195                "extra_segmentation": self._get_layer_selector_data(self.extra_seg_selector_name),
196                "with_channels": True,
197                "channels_to_standardize": [0]
198            }
199        else:
200            kwargs = {}
201        segmentation = run_segmentation(
202            image, model=model, model_type=model_type, tiling=self.tiling, scale=scale, **kwargs
203        )
204
205        # Add the segmentation layer(s).
206        if isinstance(segmentation, dict):
207            for name, seg in segmentation.items():
208                self.viewer.add_labels(seg, name=name, metadata=metadata)
209        else:
210            self.viewer.add_labels(segmentation, name=f"{model_type}", metadata=metadata)
211        show_info(f"INFO: Segmentation of {model_type} added to layers.")
212
213    def _create_settings_widget(self):
214        setting_values = QWidget()
215        # setting_values.setToolTip(get_tooltip("embedding", "settings"))
216        setting_values.setLayout(QVBoxLayout())
217
218        # Create UI for the device.
219        device = "auto"
220        device_options = ["auto"] + _available_devices()
221
222        self.device_dropdown, layout = self._add_choice_param("device", device, device_options)
223        setting_values.layout().addLayout(layout)
224
225        # Create UI for the tile shape.
226        self.default_tiling = get_default_tiling()
227        self.tiling = copy.deepcopy(self.default_tiling)
228        self.tiling["tile"]["x"], self.tiling["tile"]["y"], self.tiling["tile"]["z"], layout = self._add_shape_param(
229            ("tile_x", "tile_y", "tile_z"),
230            (self.default_tiling["tile"]["x"], self.default_tiling["tile"]["y"], self.default_tiling["tile"]["z"]),
231            min_val=0, max_val=2048, step=16,
232            # tooltip=get_tooltip("embedding", "tiling")
233        )
234        setting_values.layout().addLayout(layout)
235
236        # Create UI for the halo.
237        self.tiling["halo"]["x"], self.tiling["halo"]["y"], self.tiling["halo"]["z"], layout = self._add_shape_param(
238            ("halo_x", "halo_y", "halo_z"),
239            (self.default_tiling["halo"]["x"], self.default_tiling["halo"]["y"], self.default_tiling["halo"]["z"]),
240            min_val=0, max_val=512,
241            # tooltip=get_tooltip("embedding", "halo")
242        )
243        setting_values.layout().addLayout(layout)
244
245        # Read voxel size from layer metadata.
246        self.voxel_size_param, layout = self._add_float_param(
247            "voxel_size", 0.0, min_val=0.0, max_val=100.0,
248        )
249        setting_values.layout().addLayout(layout)
250
251        self.checkpoint_param, layout = self._add_string_param(
252            name="checkpoint", value="", title="Load Custom Model",
253            placeholder="path/to/checkpoint.pt",
254        )
255        setting_values.layout().addLayout(layout)
256
257        # Add selection UI for additional segmentation, which some models require for inference or postproc.
258        self.extra_seg_selector_name = "Extra Segmentation"
259        self.extra_selector_widget = self._create_layer_selector(self.extra_seg_selector_name, layer_type="Labels")
260        setting_values.layout().addWidget(self.extra_selector_widget)
261
262        settings = self._make_collapsible(widget=setting_values, title="Advanced Settings")
263        return settings
class SegmentationWidget(BaseWidget):
    """napari widget that segments an image layer with a SynapseNet model.

    The user selects an image layer and a model from the registry, or a custom
    checkpoint under "Advanced Settings". "Run Segmentation" adds the result to the
    viewer as one or more label layers.
    """

    def __init__(self):
        super().__init__()

        self.viewer = napari.current_viewer()
        layout = QVBoxLayout()
        # Replaced in _create_settings_widget by the nested dict of tile/halo widgets.
        self.tiling = {}

        # Create the image selection dropdown.
        self.image_selector_name = "Image data"
        self.image_selector_widget = self._create_layer_selector(self.image_selector_name, layer_type="Image")

        # Create buttons and widgets.
        self.predict_button = QPushButton("Run Segmentation")
        self.predict_button.clicked.connect(self.on_predict)
        self.model_selector_widget = self.load_model_widget()
        self.settings = self._create_settings_widget()

        # Add the widgets to the layout.
        layout.addWidget(self.image_selector_widget)
        layout.addWidget(self.model_selector_widget)
        layout.addWidget(self.settings)
        layout.addWidget(self.predict_button)

        self.setLayout(layout)

    def load_model_widget(self):
        """Create the "Select Model:" label and the model combo box.

        The combo box is stored as ``self.model_selector``. Its first entry is the
        "- choose -" placeholder, followed by the registry models in sorted order.

        Returns:
            QWidget: Container holding the label and the combo box.
        """
        model_widget = QWidget()
        title_label = QLabel("Select Model:")

        # Exclude the models that are only offered through the CLI and not in the plugin.
        # These are the models excluded due to their specificity and to keep the menu simple.
        # TODO: we should at some point update the logic here, to make it easier to support further models
        # without cluttering the UI.
        excluded_models = {"vesicles_2d_maus"}
        # Sort the names: iterating a set of strings gives an order that changes from
        # one Python process to the next, which would reshuffle the menu on every launch.
        model_list = sorted(set(_get_model_registry().urls.keys()) - excluded_models)

        models = ["- choose -"] + model_list
        self.model_selector = QComboBox()
        self.model_selector.addItems(models)
        # Create a layout and add the title label and combo box
        layout = QVBoxLayout()
        layout.addWidget(title_label)
        layout.addWidget(self.model_selector)

        # Set layout on the model widget
        model_widget.setLayout(layout)
        return model_widget

    def on_predict(self):
        """Run segmentation with the current UI settings and add the result to the viewer.

        The method validates the model and image selection, then loads the model (a
        custom checkpoint overrides the registry model). It reads tiling and voxel size
        from the UI, optionally rescales the input, calls ``run_segmentation`` and adds
        the output as label layer(s).
        """
        # Get the model and postprocessing settings.
        model_type = self.model_selector.currentText()
        custom_model_path = self.checkpoint_param.text()
        if model_type == "- choose -":
            show_info("INFO: Please choose a model.")
            return

        # Get the image data. This is checked before loading the model, so that no model
        # is loaded when there is nothing to segment.
        image = self._get_layer_selector_data(self.image_selector_name)
        if image is None:
            show_info("INFO: Please choose an image.")
            return

        device = get_device(self.device_dropdown.currentText())

        # Load the model. Override if user chose custom model.
        rescale_input = True
        if custom_model_path:
            model = _load_custom_model(custom_model_path, device)
            # Custom checkpoints are run without rescaling to the selected model's voxel size.
            rescale_input = False
            # Compare against None explicitly: container modules such as nn.Sequential
            # define __len__, so an empty one would be falsy even though it loaded fine.
            if model is None:
                show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}")
                return
            show_info(f"INFO: Using custom model from path: {custom_model_path}")
        else:
            model = get_model(model_type, device)

        # Get the current tiling as ints. Keep it local: self.tiling holds the tile/halo
        # widgets, and replacing it would make later runs ignore changes in the UI.
        tiling = _get_current_tiling(self.tiling, self.default_tiling, model_type)

        # Get the voxel size.
        metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
        voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False)

        # Determine the scaling based on the voxel size.
        scale = None
        if voxel_size and rescale_input:
            # Calculate scale so voxel_size is the same as in training.
            scale = compute_scale_from_voxel_size(voxel_size, model_type)
            scale_info = [np.round(x, 2) for x in scale]
            show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.")

        # Some models require an additional segmentation for inference or postprocessing.
        # For these models we read out the 'Extra Segmentation' widget.
        if model_type == "ribbon":  # Currently only the ribbon model needs the extra seg.
            # The postprocessing needs a per-axis resolution; without a voxel size the
            # lookup below would fail with a TypeError.
            if not voxel_size:
                show_info("INFO: The ribbon model requires a voxel size. Please set 'voxel_size' in the advanced settings.")
                return
            extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name)
            resolution = tuple(voxel_size[ax] for ax in "zyx")
            kwargs = {"extra_segmentation": extra_seg, "resolution": resolution, "min_membrane_size": 50_000}
        elif model_type.startswith("cristae"):  # Cristae model expects 2 3D volumes
            kwargs = {
                "extra_segmentation": self._get_layer_selector_data(self.extra_seg_selector_name),
                "with_channels": True,
                "channels_to_standardize": [0],
            }
        else:
            kwargs = {}
        segmentation = run_segmentation(
            image, model=model, model_type=model_type, tiling=tiling, scale=scale, **kwargs
        )

        # Add the segmentation layer(s).
        if isinstance(segmentation, dict):
            for name, seg in segmentation.items():
                self.viewer.add_labels(seg, name=name, metadata=metadata)
        else:
            self.viewer.add_labels(segmentation, name=model_type, metadata=metadata)
        show_info(f"INFO: Segmentation of {model_type} added to layers.")

    def _create_settings_widget(self):
        """Build the collapsible "Advanced Settings" panel.

        Creates and stores on ``self`` the following: ``device_dropdown``; the tile and
        halo widgets in ``self.tiling``, with ``self.default_tiling`` holding the
        defaults; ``voxel_size_param``; ``checkpoint_param``; and the
        "Extra Segmentation" layer selector.

        Returns:
            The collapsible widget wrapping all settings.
        """
        setting_values = QWidget()
        # setting_values.setToolTip(get_tooltip("embedding", "settings"))
        setting_values.setLayout(QVBoxLayout())

        # Create UI for the device.
        device = "auto"
        device_options = ["auto"] + _available_devices()

        self.device_dropdown, layout = self._add_choice_param("device", device, device_options)
        setting_values.layout().addLayout(layout)

        # Create UI for the tile shape. self.tiling stores the widgets themselves; their
        # current values are read on every run by _get_current_tiling.
        self.default_tiling = get_default_tiling()
        self.tiling = copy.deepcopy(self.default_tiling)
        default_tile = self.default_tiling["tile"]
        default_halo = self.default_tiling["halo"]
        self.tiling["tile"]["x"], self.tiling["tile"]["y"], self.tiling["tile"]["z"], layout = self._add_shape_param(
            ("tile_x", "tile_y", "tile_z"),
            (default_tile["x"], default_tile["y"], default_tile["z"]),
            min_val=0, max_val=2048, step=16,
            # tooltip=get_tooltip("embedding", "tiling")
        )
        setting_values.layout().addLayout(layout)

        # Create UI for the halo.
        self.tiling["halo"]["x"], self.tiling["halo"]["y"], self.tiling["halo"]["z"], layout = self._add_shape_param(
            ("halo_x", "halo_y", "halo_z"),
            (default_halo["x"], default_halo["y"], default_halo["z"]),
            min_val=0, max_val=512,
            # tooltip=get_tooltip("embedding", "halo")
        )
        setting_values.layout().addLayout(layout)

        # Read voxel size from layer metadata.
        self.voxel_size_param, layout = self._add_float_param(
            "voxel_size", 0.0, min_val=0.0, max_val=100.0,
        )
        setting_values.layout().addLayout(layout)

        self.checkpoint_param, layout = self._add_string_param(
            name="checkpoint", value="", title="Load Custom Model",
            placeholder="path/to/checkpoint.pt",
        )
        setting_values.layout().addLayout(layout)

        # Add selection UI for additional segmentation, which some models require for inference or postproc.
        self.extra_seg_selector_name = "Extra Segmentation"
        self.extra_selector_widget = self._create_layer_selector(self.extra_seg_selector_name, layer_type="Labels")
        setting_values.layout().addWidget(self.extra_selector_widget)

        settings = self._make_collapsible(widget=setting_values, title="Advanced Settings")
        return settings

QWidget(parent: Optional[QWidget] = None, flags: Union[Qt.WindowFlags, Qt.WindowType] = Qt.WindowFlags())

viewer
tiling
image_selector_name
image_selector_widget
predict_button
model_selector_widget
settings
def load_model_widget(self):
    """Create the "Select Model:" label and the model combo box.

    The combo box is stored as ``self.model_selector``. Its first entry is the
    "- choose -" placeholder, followed by the registry models in sorted order.

    Returns:
        QWidget: Container holding the label and the combo box.
    """
    model_widget = QWidget()
    title_label = QLabel("Select Model:")

    # Exclude the models that are only offered through the CLI and not in the plugin.
    # These are the models excluded due to their specificity and to keep the menu simple.
    # TODO: we should at some point update the logic here, to make it easier to support further models
    # without cluttering the UI.
    excluded_models = {"vesicles_2d_maus"}
    # Sort the names: iterating a set of strings gives an order that changes from
    # one Python process to the next, which would reshuffle the menu on every launch.
    model_list = sorted(set(_get_model_registry().urls.keys()) - excluded_models)

    models = ["- choose -"] + model_list
    self.model_selector = QComboBox()
    self.model_selector.addItems(models)
    # Create a layout and add the title label and combo box
    layout = QVBoxLayout()
    layout.addWidget(title_label)
    layout.addWidget(self.model_selector)

    # Set layout on the model widget
    model_widget.setLayout(layout)
    return model_widget
def on_predict(self):
    """Run segmentation with the current UI settings and add the result to the viewer.

    The method validates the model and image selection, then loads the model (a
    custom checkpoint overrides the registry model). It reads tiling and voxel size
    from the UI, optionally rescales the input, calls ``run_segmentation`` and adds
    the output as label layer(s).
    """
    # Get the model and postprocessing settings.
    model_type = self.model_selector.currentText()
    custom_model_path = self.checkpoint_param.text()
    if model_type == "- choose -":
        show_info("INFO: Please choose a model.")
        return

    # Get the image data. This is checked before loading the model, so that no model
    # is loaded when there is nothing to segment.
    image = self._get_layer_selector_data(self.image_selector_name)
    if image is None:
        show_info("INFO: Please choose an image.")
        return

    device = get_device(self.device_dropdown.currentText())

    # Load the model. Override if user chose custom model.
    rescale_input = True
    if custom_model_path:
        model = _load_custom_model(custom_model_path, device)
        # Custom checkpoints are run without rescaling to the selected model's voxel size.
        rescale_input = False
        # Compare against None explicitly: container modules such as nn.Sequential
        # define __len__, so an empty one would be falsy even though it loaded fine.
        if model is None:
            show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}")
            return
        show_info(f"INFO: Using custom model from path: {custom_model_path}")
    else:
        model = get_model(model_type, device)

    # Get the current tiling as ints. Keep it local: self.tiling holds the tile/halo
    # widgets, and replacing it would make later runs ignore changes in the UI.
    tiling = _get_current_tiling(self.tiling, self.default_tiling, model_type)

    # Get the voxel size.
    metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
    voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False)

    # Determine the scaling based on the voxel size.
    scale = None
    if voxel_size and rescale_input:
        # Calculate scale so voxel_size is the same as in training.
        scale = compute_scale_from_voxel_size(voxel_size, model_type)
        scale_info = [np.round(x, 2) for x in scale]
        show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.")

    # Some models require an additional segmentation for inference or postprocessing.
    # For these models we read out the 'Extra Segmentation' widget.
    if model_type == "ribbon":  # Currently only the ribbon model needs the extra seg.
        # The postprocessing needs a per-axis resolution; without a voxel size the
        # lookup below would fail with a TypeError.
        if not voxel_size:
            show_info("INFO: The ribbon model requires a voxel size. Please set 'voxel_size' in the advanced settings.")
            return
        extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name)
        resolution = tuple(voxel_size[ax] for ax in "zyx")
        kwargs = {"extra_segmentation": extra_seg, "resolution": resolution, "min_membrane_size": 50_000}
    elif model_type.startswith("cristae"):  # Cristae model expects 2 3D volumes
        kwargs = {
            "extra_segmentation": self._get_layer_selector_data(self.extra_seg_selector_name),
            "with_channels": True,
            "channels_to_standardize": [0],
        }
    else:
        kwargs = {}
    segmentation = run_segmentation(
        image, model=model, model_type=model_type, tiling=tiling, scale=scale, **kwargs
    )

    # Add the segmentation layer(s).
    if isinstance(segmentation, dict):
        for name, seg in segmentation.items():
            self.viewer.add_labels(seg, name=name, metadata=metadata)
    else:
        self.viewer.add_labels(segmentation, name=model_type, metadata=metadata)
    show_info(f"INFO: Segmentation of {model_type} added to layers.")