synapse_net.tools.segmentation_widget

  1import copy
  2import re
  3from typing import Optional, Union
  4
  5import napari
  6import numpy as np
  7import torch
  8
  9from napari.utils.notifications import show_info
 10from qtpy.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QComboBox
 11
 12from .base_widget import BaseWidget
 13from ..inference.inference import _get_model_registry, get_model, run_segmentation, compute_scale_from_voxel_size
 14from ..inference.util import get_default_tiling, get_device
 15
 16
 17def _load_custom_model(model_path: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
 18    model_path = _clean_filepath(model_path)
 19    if device is None:
 20        device = get_device(device)
 21    try:
 22        model = torch.load(model_path, map_location=torch.device(device), weights_only=False)
 23    except Exception as e:
 24        print(e)
 25        print("model path", model_path)
 26        return None
 27    return model
 28
 29
 30def _available_devices():
 31    available_devices = []
 32    for i in ["cuda", "mps", "cpu"]:
 33        try:
 34            device = get_device(i)
 35        except RuntimeError:
 36            pass
 37        else:
 38            available_devices.append(device)
 39    return available_devices
 40
 41
 42def _get_current_tiling(tiling: dict, default_tiling: dict, model_type: str):
 43    # get tiling values from qt objects
 44    for k, v in tiling.items():
 45        for k2, v2 in v.items():
 46            if isinstance(v2, int):
 47                continue
 48            elif hasattr(v2, "value"):  # If it's a QSpinBox, extract the value
 49                tiling[k][k2] = v2.value()
 50            else:
 51                raise TypeError(f"Unexpected type for tiling value: {type(v2)} at {k}/{k2}")
 52    # check if user inputs tiling/halo or not
 53    if default_tiling == tiling:
 54        if "2d" in model_type:
 55            # if its a 2d model expand x,y and set z to 1
 56            tiling = {
 57                "tile": {"x": 512, "y": 512, "z": 1},
 58                "halo": {"x": 64, "y": 64, "z": 1},
 59            }
 60    else:
 61        show_info(f"Using custom tiling: {tiling}")
 62    if "2d" in model_type:
 63        # if its a 2d model set z to 1
 64        tiling["tile"]["z"] = 1
 65        tiling["halo"]["z"] = 0
 66        show_info(f"Using tiling: {tiling}")
 67    return tiling
 68
 69
 70def _clean_filepath(filepath):
 71    """Cleans a given filepath by:
 72    - Removing newline characters (\n)
 73    - Removing escape sequences
 74    - Stripping the 'file://' prefix if present
 75
 76    Args:
 77        filepath (str): The original filepath
 78
 79    Returns:
 80        str: The cleaned filepath
 81    """
 82    # Remove 'file://' prefix if present
 83    if filepath.startswith("file://"):
 84        filepath = filepath[7:]
 85
 86    # Remove escape sequences and newlines
 87    filepath = re.sub(r'\\.', '', filepath)
 88    filepath = filepath.replace('\n', '').replace('\r', '')
 89
 90    return filepath
 91
 92
 93class SegmentationWidget(BaseWidget):
 94    def __init__(self):
 95        super().__init__()
 96
 97        self.viewer = napari.current_viewer()
 98        layout = QVBoxLayout()
 99        self.tiling = {}
100
101        # Create the image selection dropdown.
102        self.image_selector_name = "Image data"
103        self.image_selector_widget = self._create_layer_selector(self.image_selector_name, layer_type="Image")
104
105        # Create buttons and widgets.
106        self.predict_button = QPushButton("Run Segmentation")
107        self.predict_button.clicked.connect(self.on_predict)
108        self.model_selector_widget = self.load_model_widget()
109        self.settings = self._create_settings_widget()
110
111        # Add the widgets to the layout.
112        layout.addWidget(self.image_selector_widget)
113        layout.addWidget(self.model_selector_widget)
114        layout.addWidget(self.settings)
115        layout.addWidget(self.predict_button)
116
117        self.setLayout(layout)
118
119    def load_model_widget(self):
120        model_widget = QWidget()
121        title_label = QLabel("Select Model:")
122
123        # Exclude the models that are only offered through the CLI and not in the plugin.
124        model_list = set(_get_model_registry().urls.keys())
125        excluded_models = ["vesicles_2d_maus", "vesicles_3d_endbulb", "vesicles_3d_innerear"]
126        model_list = [name for name in model_list if name not in excluded_models]
127
128        models = ["- choose -"] + model_list
129        self.model_selector = QComboBox()
130        self.model_selector.addItems(models)
131        # Create a layout and add the title label and combo box
132        layout = QVBoxLayout()
133        layout.addWidget(title_label)
134        layout.addWidget(self.model_selector)
135
136        # Set layout on the model widget
137        model_widget.setLayout(layout)
138        return model_widget
139
140    def on_predict(self):
141        # Get the model and postprocessing settings.
142        model_type = self.model_selector.currentText()
143        custom_model_path = self.checkpoint_param.text()
144        if model_type == "- choose -" and custom_model_path is None:
145            show_info("INFO: Please choose a model.")
146            return
147
148        device = get_device(self.device_dropdown.currentText())
149
150        # Load the model. Override if user chose custom model
151        if custom_model_path:
152            model = _load_custom_model(custom_model_path, device)
153            if model:
154                show_info(f"INFO: Using custom model from path: {custom_model_path}")
155                model_type = "custom"
156            else:
157                show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}")
158                return
159        else:
160            model = get_model(model_type, device)
161
162        # Get the image data.
163        image = self._get_layer_selector_data(self.image_selector_name)
164        if image is None:
165            show_info("INFO: Please choose an image.")
166            return
167
168        # Get the current tiling.
169        self.tiling = _get_current_tiling(self.tiling, self.default_tiling, model_type)
170
171        # Get the voxel size.
172        metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
173        voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False)
174
175        # Determine the scaling based on the voxel size.
176        scale = None
177        if voxel_size:
178            if model_type == "custom":
179                show_info("INFO: The image is not rescaled for a custom model.")
180            else:
181                # calculate scale so voxel_size is the same as in training
182                scale = compute_scale_from_voxel_size(voxel_size, model_type)
183                scale_info = list(map(lambda x: np.round(x, 2), scale))
184                show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.")
185
186        # Some models require an additional segmentation for inference or postprocessing.
187        # For these models we read out the 'Extra Segmentation' widget.
188        if model_type == "ribbon":  # Currently only the ribbon model needs the extra seg.
189            extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name)
190            kwargs = {"extra_segmentation": extra_seg}
191        elif model_type == "cristae":  # Cristae model expects 2 3D volumes
192            kwargs = {
193                "extra_segmentation": self._get_layer_selector_data(self.extra_seg_selector_name),
194                "with_channels": True,
195                "channels_to_standardize": [0]
196            }
197        else:
198            kwargs = {}
199        segmentation = run_segmentation(
200            image, model=model, model_type=model_type, tiling=self.tiling, scale=scale, **kwargs
201        )
202
203        # Add the segmentation layer(s).
204        if isinstance(segmentation, dict):
205            for name, seg in segmentation.items():
206                self.viewer.add_labels(seg, name=name, metadata=metadata)
207        else:
208            self.viewer.add_labels(segmentation, name=f"{model_type}", metadata=metadata)
209        show_info(f"INFO: Segmentation of {model_type} added to layers.")
210
211    def _create_settings_widget(self):
212        setting_values = QWidget()
213        # setting_values.setToolTip(get_tooltip("embedding", "settings"))
214        setting_values.setLayout(QVBoxLayout())
215
216        # Create UI for the device.
217        device = "auto"
218        device_options = ["auto"] + _available_devices()
219
220        self.device_dropdown, layout = self._add_choice_param("device", device, device_options)
221        setting_values.layout().addLayout(layout)
222
223        # Create UI for the tile shape.
224        self.default_tiling = get_default_tiling()
225        self.tiling = copy.deepcopy(self.default_tiling)
226        self.tiling["tile"]["x"], self.tiling["tile"]["y"], self.tiling["tile"]["z"], layout = self._add_shape_param(
227            ("tile_x", "tile_y", "tile_z"),
228            (self.default_tiling["tile"]["x"], self.default_tiling["tile"]["y"], self.default_tiling["tile"]["z"]),
229            min_val=0, max_val=2048, step=16,
230            # tooltip=get_tooltip("embedding", "tiling")
231        )
232        setting_values.layout().addLayout(layout)
233
234        # Create UI for the halo.
235        self.tiling["halo"]["x"], self.tiling["halo"]["y"], self.tiling["halo"]["z"], layout = self._add_shape_param(
236            ("halo_x", "halo_y", "halo_z"),
237            (self.default_tiling["halo"]["x"], self.default_tiling["halo"]["y"], self.default_tiling["halo"]["z"]),
238            min_val=0, max_val=512,
239            # tooltip=get_tooltip("embedding", "halo")
240        )
241        setting_values.layout().addLayout(layout)
242
243        # Read voxel size from layer metadata.
244        self.voxel_size_param, layout = self._add_float_param(
245            "voxel_size", 0.0, min_val=0.0, max_val=100.0,
246        )
247        setting_values.layout().addLayout(layout)
248
249        self.checkpoint_param, layout = self._add_string_param(
250            name="checkpoint", value="", title="Load Custom Model",
251            placeholder="path/to/checkpoint.pt",
252        )
253        setting_values.layout().addLayout(layout)
254
255        # Add selection UI for additional segmentation, which some models require for inference or postproc.
256        self.extra_seg_selector_name = "Extra Segmentation"
257        self.extra_selector_widget = self._create_layer_selector(self.extra_seg_selector_name, layer_type="Labels")
258        setting_values.layout().addWidget(self.extra_selector_widget)
259
260        settings = self._make_collapsible(widget=setting_values, title="Advanced Settings")
261        return settings
class SegmentationWidget(BaseWidget):
    """Dock widget for running SynapseNet segmentation models on an image layer.

    It shows an image-layer selector, a model dropdown, a collapsible
    "Advanced Settings" section and a "Run Segmentation" button. The
    segmentation result is added to the viewer as label layer(s).
    """

    def __init__(self):
        super().__init__()

        self.viewer = napari.current_viewer()
        layout = QVBoxLayout()
        # Filled by _create_settings_widget with {"tile": {...}, "halo": {...}} holding the
        # per-axis input widgets. on_predict reads their current values on every run.
        self.tiling = {}

        # Create the image selection dropdown.
        self.image_selector_name = "Image data"
        self.image_selector_widget = self._create_layer_selector(self.image_selector_name, layer_type="Image")

        # Create buttons and widgets.
        self.predict_button = QPushButton("Run Segmentation")
        self.predict_button.clicked.connect(self.on_predict)
        self.model_selector_widget = self.load_model_widget()
        self.settings = self._create_settings_widget()

        # Add the widgets to the layout.
        layout.addWidget(self.image_selector_widget)
        layout.addWidget(self.model_selector_widget)
        layout.addWidget(self.settings)
        layout.addWidget(self.predict_button)

        self.setLayout(layout)

    def load_model_widget(self):
        """Build the "Select Model:" label plus dropdown. The dropdown is kept as self.model_selector."""
        model_widget = QWidget()
        title_label = QLabel("Select Model:")

        # Exclude the models that are only offered through the CLI and not in the plugin.
        # The registry keys are iterated directly (mapping keys are already unique) rather
        # than through a set: the iteration order of a set of strings changes between
        # interpreter runs, which shuffled the dropdown entries on every launch.
        excluded_models = {"vesicles_2d_maus", "vesicles_3d_endbulb", "vesicles_3d_innerear"}
        model_list = [name for name in _get_model_registry().urls.keys() if name not in excluded_models]

        models = ["- choose -"] + model_list
        self.model_selector = QComboBox()
        self.model_selector.addItems(models)

        # Create a layout and add the title label and combo box.
        layout = QVBoxLayout()
        layout.addWidget(title_label)
        layout.addWidget(self.model_selector)

        # Set layout on the model widget.
        model_widget.setLayout(layout)
        return model_widget

    def on_predict(self):
        """Segment the selected image with the selected (or custom) model.

        A non-empty "Load Custom Model" path takes precedence over the dropdown
        selection. The model type is then reported as "custom" and the image is
        not rescaled. The result is added to the viewer as label layer(s): one
        layer per entry if ``run_segmentation`` returns a dict.
        """
        # Get the model and postprocessing settings.
        model_type = self.model_selector.currentText()
        # The path is a str (it is later cleaned with str methods), so an empty field
        # gives "" rather than None. Surrounding whitespace from pasting is ignored.
        custom_model_path = self.checkpoint_param.text().strip()
        if model_type == "- choose -" and not custom_model_path:
            show_info("INFO: Please choose a model.")
            return

        # Get the image data. This is checked before the (potentially expensive)
        # model load, so that a missing image is reported right away.
        image = self._get_layer_selector_data(self.image_selector_name)
        if image is None:
            show_info("INFO: Please choose an image.")
            return

        device = get_device(self.device_dropdown.currentText())

        # Load the model. Override if the user chose a custom model.
        if custom_model_path:
            model = _load_custom_model(custom_model_path, device)
            # _load_custom_model signals failure with None. Don't rely on the
            # truthiness of the loaded object.
            if model is None:
                show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}")
                return
            show_info(f"INFO: Using custom model from path: {custom_model_path}")
            model_type = "custom"
        else:
            model = get_model(model_type, device)

        # Get the current tiling. Work on a copy of the nested dict: resolving the values
        # into self.tiling would replace the settings widgets with plain ints, so edits
        # made in the UI after the first run would be silently ignored.
        tiling = {key: dict(shape) for key, shape in self.tiling.items()}
        tiling = _get_current_tiling(tiling, self.default_tiling, model_type)

        # Get the voxel size.
        metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
        voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False)

        # Determine the scaling based on the voxel size.
        scale = None
        if voxel_size:
            if model_type == "custom":
                show_info("INFO: The image is not rescaled for a custom model.")
            else:
                # Calculate the scale so that voxel_size is the same as in training.
                scale = compute_scale_from_voxel_size(voxel_size, model_type)
                scale_info = [np.round(x, 2) for x in scale]
                show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.")

        # Some models require an additional segmentation for inference or postprocessing.
        # For these models we read out the 'Extra Segmentation' widget.
        if model_type == "ribbon":  # The ribbon model needs the extra segmentation.
            extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name)
            kwargs = {"extra_segmentation": extra_seg}
        elif model_type == "cristae":  # Cristae model expects 2 3D volumes
            kwargs = {
                "extra_segmentation": self._get_layer_selector_data(self.extra_seg_selector_name),
                "with_channels": True,
                "channels_to_standardize": [0],
            }
        else:
            kwargs = {}
        segmentation = run_segmentation(
            image, model=model, model_type=model_type, tiling=tiling, scale=scale, **kwargs
        )

        # Add the segmentation layer(s).
        if isinstance(segmentation, dict):
            for name, seg in segmentation.items():
                self.viewer.add_labels(seg, name=name, metadata=metadata)
        else:
            self.viewer.add_labels(segmentation, name=f"{model_type}", metadata=metadata)
        show_info(f"INFO: Segmentation of {model_type} added to layers.")

    def _create_settings_widget(self):
        """Build the collapsible "Advanced Settings" section and return it.

        Side effects: sets self.device_dropdown, self.default_tiling, self.tiling
        (nested dict of the tile/halo input widgets), self.voxel_size_param,
        self.checkpoint_param, self.extra_seg_selector_name and
        self.extra_selector_widget.
        """
        setting_values = QWidget()
        # setting_values.setToolTip(get_tooltip("embedding", "settings"))
        setting_values.setLayout(QVBoxLayout())

        # Create UI for the device.
        device = "auto"
        device_options = ["auto"] + _available_devices()

        self.device_dropdown, layout = self._add_choice_param("device", device, device_options)
        setting_values.layout().addLayout(layout)

        # Create UI for the tile shape. self.tiling starts as a copy of the defaults and
        # its per-axis entries are then replaced by the input widgets.
        self.default_tiling = get_default_tiling()
        self.tiling = copy.deepcopy(self.default_tiling)
        self.tiling["tile"]["x"], self.tiling["tile"]["y"], self.tiling["tile"]["z"], layout = self._add_shape_param(
            ("tile_x", "tile_y", "tile_z"),
            (self.default_tiling["tile"]["x"], self.default_tiling["tile"]["y"], self.default_tiling["tile"]["z"]),
            min_val=0, max_val=2048, step=16,
            # tooltip=get_tooltip("embedding", "tiling")
        )
        setting_values.layout().addLayout(layout)

        # Create UI for the halo.
        self.tiling["halo"]["x"], self.tiling["halo"]["y"], self.tiling["halo"]["z"], layout = self._add_shape_param(
            ("halo_x", "halo_y", "halo_z"),
            (self.default_tiling["halo"]["x"], self.default_tiling["halo"]["y"], self.default_tiling["halo"]["z"]),
            min_val=0, max_val=512,
            # tooltip=get_tooltip("embedding", "halo")
        )
        setting_values.layout().addLayout(layout)

        # Read voxel size from layer metadata.
        self.voxel_size_param, layout = self._add_float_param(
            "voxel_size", 0.0, min_val=0.0, max_val=100.0,
        )
        setting_values.layout().addLayout(layout)

        self.checkpoint_param, layout = self._add_string_param(
            name="checkpoint", value="", title="Load Custom Model",
            placeholder="path/to/checkpoint.pt",
        )
        setting_values.layout().addLayout(layout)

        # Add selection UI for additional segmentation, which some models require for inference or postproc.
        self.extra_seg_selector_name = "Extra Segmentation"
        self.extra_selector_widget = self._create_layer_selector(self.extra_seg_selector_name, layer_type="Labels")
        setting_values.layout().addWidget(self.extra_selector_widget)

        settings = self._make_collapsible(widget=setting_values, title="Advanced Settings")
        return settings

QWidget(parent: Optional[QWidget] = None, flags: Union[Qt.WindowFlags, Qt.WindowType] = Qt.WindowFlags())

viewer
tiling
image_selector_name
image_selector_widget
predict_button
model_selector_widget
settings
def load_model_widget(self):
    """Build the "Select Model:" label plus dropdown. The dropdown is kept as self.model_selector."""
    model_widget = QWidget()
    title_label = QLabel("Select Model:")

    # Exclude the models that are only offered through the CLI and not in the plugin.
    # The registry keys are iterated directly (mapping keys are already unique) rather
    # than through a set: the iteration order of a set of strings changes between
    # interpreter runs, which shuffled the dropdown entries on every launch.
    excluded_models = {"vesicles_2d_maus", "vesicles_3d_endbulb", "vesicles_3d_innerear"}
    model_list = [name for name in _get_model_registry().urls.keys() if name not in excluded_models]

    models = ["- choose -"] + model_list
    self.model_selector = QComboBox()
    self.model_selector.addItems(models)

    # Create a layout and add the title label and combo box.
    layout = QVBoxLayout()
    layout.addWidget(title_label)
    layout.addWidget(self.model_selector)

    # Set layout on the model widget.
    model_widget.setLayout(layout)
    return model_widget
def on_predict(self):
    """Segment the selected image with the selected (or custom) model.

    A non-empty "Load Custom Model" path takes precedence over the dropdown
    selection. The model type is then reported as "custom" and the image is
    not rescaled. The result is added to the viewer as label layer(s): one
    layer per entry if ``run_segmentation`` returns a dict.
    """
    # Get the model and postprocessing settings.
    model_type = self.model_selector.currentText()
    # The path is a str (it is later cleaned with str methods), so an empty field
    # gives "" rather than None. Surrounding whitespace from pasting is ignored.
    custom_model_path = self.checkpoint_param.text().strip()
    if model_type == "- choose -" and not custom_model_path:
        show_info("INFO: Please choose a model.")
        return

    # Get the image data. This is checked before the (potentially expensive)
    # model load, so that a missing image is reported right away.
    image = self._get_layer_selector_data(self.image_selector_name)
    if image is None:
        show_info("INFO: Please choose an image.")
        return

    device = get_device(self.device_dropdown.currentText())

    # Load the model. Override if the user chose a custom model.
    if custom_model_path:
        model = _load_custom_model(custom_model_path, device)
        # _load_custom_model signals failure with None. Don't rely on the
        # truthiness of the loaded object.
        if model is None:
            show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}")
            return
        show_info(f"INFO: Using custom model from path: {custom_model_path}")
        model_type = "custom"
    else:
        model = get_model(model_type, device)

    # Get the current tiling. Work on a copy of the nested dict: resolving the values
    # into self.tiling would replace the settings widgets with plain ints, so edits
    # made in the UI after the first run would be silently ignored.
    tiling = {key: dict(shape) for key, shape in self.tiling.items()}
    tiling = _get_current_tiling(tiling, self.default_tiling, model_type)

    # Get the voxel size.
    metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
    voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False)

    # Determine the scaling based on the voxel size.
    scale = None
    if voxel_size:
        if model_type == "custom":
            show_info("INFO: The image is not rescaled for a custom model.")
        else:
            # Calculate the scale so that voxel_size is the same as in training.
            scale = compute_scale_from_voxel_size(voxel_size, model_type)
            scale_info = [np.round(x, 2) for x in scale]
            show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.")

    # Some models require an additional segmentation for inference or postprocessing.
    # For these models we read out the 'Extra Segmentation' widget.
    if model_type == "ribbon":  # The ribbon model needs the extra segmentation.
        extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name)
        kwargs = {"extra_segmentation": extra_seg}
    elif model_type == "cristae":  # Cristae model expects 2 3D volumes
        kwargs = {
            "extra_segmentation": self._get_layer_selector_data(self.extra_seg_selector_name),
            "with_channels": True,
            "channels_to_standardize": [0],
        }
    else:
        kwargs = {}
    segmentation = run_segmentation(
        image, model=model, model_type=model_type, tiling=tiling, scale=scale, **kwargs
    )

    # Add the segmentation layer(s).
    if isinstance(segmentation, dict):
        for name, seg in segmentation.items():
            self.viewer.add_labels(seg, name=name, metadata=metadata)
    else:
        self.viewer.add_labels(segmentation, name=f"{model_type}", metadata=metadata)
    show_info(f"INFO: Segmentation of {model_type} added to layers.")