synapse_net.tools.segmentation_widget

  1import copy
  2import re
  3from typing import Optional, Union
  4
  5import napari
  6import numpy as np
  7import torch
  8
  9from napari.utils.notifications import show_info
 10from qtpy.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QComboBox
 11
 12from .base_widget import BaseWidget
 13from ..inference.inference import _get_model_registry, get_model, run_segmentation, compute_scale_from_voxel_size
 14from ..inference.util import get_default_tiling, get_device
 15
 16
 17def _load_custom_model(model_path: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
 18    model_path = _clean_filepath(model_path)
 19    if device is None:
 20        device = get_device(device)
 21    try:
 22        model = torch.load(model_path, map_location=torch.device(device), weights_only=False)
 23    except Exception as e:
 24        print(e)
 25        print("model path", model_path)
 26        return None
 27    return model
 28
 29
 30def _available_devices():
 31    available_devices = []
 32    for i in ["cuda", "mps", "cpu"]:
 33        try:
 34            device = get_device(i)
 35        except RuntimeError:
 36            pass
 37        else:
 38            available_devices.append(device)
 39    return available_devices
 40
 41
 42def _get_current_tiling(tiling: dict, default_tiling: dict, model_type: str):
 43    # get tiling values from qt objects
 44    for k, v in tiling.items():
 45        for k2, v2 in v.items():
 46            if isinstance(v2, int):
 47                continue
 48            tiling[k][k2] = v2.value()
 49    # check if user inputs tiling/halo or not
 50    if default_tiling == tiling:
 51        if "2d" in model_type:
 52            # if its a 2d model expand x,y and set z to 1
 53            tiling = {
 54                "tile": {"x": 512, "y": 512, "z": 1},
 55                "halo": {"x": 64, "y": 64, "z": 1},
 56            }
 57    elif "2d" in model_type:
 58        # if its a 2d model set z to 1
 59        tiling["tile"]["z"] = 1
 60        tiling["halo"]["z"] = 1
 61
 62    return tiling
 63
 64
 65def _clean_filepath(filepath):
 66    """Cleans a given filepath by:
 67    - Removing newline characters (\n)
 68    - Removing escape sequences
 69    - Stripping the 'file://' prefix if present
 70
 71    Args:
 72        filepath (str): The original filepath
 73
 74    Returns:
 75        str: The cleaned filepath
 76    """
 77    # Remove 'file://' prefix if present
 78    if filepath.startswith("file://"):
 79        filepath = filepath[7:]
 80
 81    # Remove escape sequences and newlines
 82    filepath = re.sub(r'\\.', '', filepath)
 83    filepath = filepath.replace('\n', '').replace('\r', '')
 84
 85    return filepath
 86
 87
 88class SegmentationWidget(BaseWidget):
 89    def __init__(self):
 90        super().__init__()
 91
 92        self.viewer = napari.current_viewer()
 93        layout = QVBoxLayout()
 94        self.tiling = {}
 95
 96        # Create the image selection dropdown.
 97        self.image_selector_name = "Image data"
 98        self.image_selector_widget = self._create_layer_selector(self.image_selector_name, layer_type="Image")
 99
100        # Create buttons and widgets.
101        self.predict_button = QPushButton("Run Segmentation")
102        self.predict_button.clicked.connect(self.on_predict)
103        self.model_selector_widget = self.load_model_widget()
104        self.settings = self._create_settings_widget()
105
106        # Add the widgets to the layout.
107        layout.addWidget(self.image_selector_widget)
108        layout.addWidget(self.model_selector_widget)
109        layout.addWidget(self.settings)
110        layout.addWidget(self.predict_button)
111
112        self.setLayout(layout)
113
114    def load_model_widget(self):
115        model_widget = QWidget()
116        title_label = QLabel("Select Model:")
117
118        models = ["- choose -"] + list(_get_model_registry().urls.keys())
119        self.model_selector = QComboBox()
120        self.model_selector.addItems(models)
121        # Create a layout and add the title label and combo box
122        layout = QVBoxLayout()
123        layout.addWidget(title_label)
124        layout.addWidget(self.model_selector)
125
126        # Set layout on the model widget
127        model_widget.setLayout(layout)
128        return model_widget
129
130    def on_predict(self):
131        # Get the model and postprocessing settings.
132        model_type = self.model_selector.currentText()
133        custom_model_path = self.checkpoint_param.text()
134        if model_type == "- choose -" and custom_model_path is None:
135            show_info("INFO: Please choose a model.")
136            return
137
138        device = get_device(self.device_dropdown.currentText())
139
140        # Load the model. Override if user chose custom model
141        if custom_model_path:
142            model = _load_custom_model(custom_model_path, device)
143            if model:
144                show_info(f"INFO: Using custom model from path: {custom_model_path}")
145                model_type = "custom"
146            else:
147                show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}")
148                return
149        else:
150            model = get_model(model_type, device)
151
152        # Get the image data.
153        image = self._get_layer_selector_data(self.image_selector_name)
154        if image is None:
155            show_info("INFO: Please choose an image.")
156            return
157
158        # Get the current tiling.
159        self.tiling = _get_current_tiling(self.tiling, self.default_tiling, model_type)
160
161        # Get the voxel size.
162        metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
163        voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False)
164
165        # Determine the scaling based on the voxel size.
166        scale = None
167        if voxel_size:
168            if model_type == "custom":
169                show_info("INFO: The image is not rescaled for a custom model.")
170            else:
171                # calculate scale so voxel_size is the same as in training
172                scale = compute_scale_from_voxel_size(voxel_size, model_type)
173                scale_info = list(map(lambda x: np.round(x, 2), scale))
174                show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.")
175
176        # Some models require an additional segmentation for inference or postprocessing.
177        # For these models we read out the 'Extra Segmentation' widget.
178        if model_type == "ribbon":  # Currently only the ribbon model needs the extra seg.
179            extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name)
180            kwargs = {"extra_segmentation": extra_seg}
181        else:
182            kwargs = {}
183        segmentation = run_segmentation(
184            image, model=model, model_type=model_type, tiling=self.tiling, scale=scale, **kwargs
185        )
186
187        # Add the segmentation layer(s).
188        if isinstance(segmentation, dict):
189            for name, seg in segmentation.items():
190                self.viewer.add_labels(seg, name=name, metadata=metadata)
191        else:
192            self.viewer.add_labels(segmentation, name=f"{model_type}", metadata=metadata)
193        show_info(f"INFO: Segmentation of {model_type} added to layers.")
194
195    def _create_settings_widget(self):
196        setting_values = QWidget()
197        # setting_values.setToolTip(get_tooltip("embedding", "settings"))
198        setting_values.setLayout(QVBoxLayout())
199
200        # Create UI for the device.
201        device = "auto"
202        device_options = ["auto"] + _available_devices()
203
204        self.device_dropdown, layout = self._add_choice_param("device", device, device_options)
205        setting_values.layout().addLayout(layout)
206
207        # Create UI for the tile shape.
208        self.default_tiling = get_default_tiling()
209        self.tiling = copy.deepcopy(self.default_tiling)
210        self.tiling["tile"]["x"], self.tiling["tile"]["y"], self.tiling["tile"]["z"], layout = self._add_shape_param(
211            ("tile_x", "tile_y", "tile_z"),
212            (self.default_tiling["tile"]["x"], self.default_tiling["tile"]["y"], self.default_tiling["tile"]["z"]),
213            min_val=0, max_val=2048, step=16,
214            # tooltip=get_tooltip("embedding", "tiling")
215        )
216        setting_values.layout().addLayout(layout)
217
218        # Create UI for the halo.
219        self.tiling["halo"]["x"], self.tiling["halo"]["y"], self.tiling["halo"]["z"], layout = self._add_shape_param(
220            ("halo_x", "halo_y", "halo_z"),
221            (self.default_tiling["halo"]["x"], self.default_tiling["halo"]["y"], self.default_tiling["halo"]["z"]),
222            min_val=0, max_val=512,
223            # tooltip=get_tooltip("embedding", "halo")
224        )
225        setting_values.layout().addLayout(layout)
226
227        # Read voxel size from layer metadata.
228        self.voxel_size_param, layout = self._add_float_param(
229            "voxel_size", 0.0, min_val=0.0, max_val=100.0,
230        )
231        setting_values.layout().addLayout(layout)
232
233        self.checkpoint_param, layout = self._add_string_param(
234            name="checkpoint", value="", title="Load Custom Model",
235            placeholder="path/to/checkpoint.pt",
236        )
237        setting_values.layout().addLayout(layout)
238
239        # Add selection UI for additional segmentation, which some models require for inference or postproc.
240        self.extra_seg_selector_name = "Extra Segmentation"
241        self.extra_selector_widget = self._create_layer_selector(self.extra_seg_selector_name, layer_type="Labels")
242        setting_values.layout().addWidget(self.extra_selector_widget)
243
244        settings = self._make_collapsible(widget=setting_values, title="Advanced Settings")
245        return settings
class SegmentationWidget(synapse_net.tools.base_widget.BaseWidget):
 89class SegmentationWidget(BaseWidget):
 90    def __init__(self):
 91        super().__init__()
 92
 93        self.viewer = napari.current_viewer()
 94        layout = QVBoxLayout()
 95        self.tiling = {}
 96
 97        # Create the image selection dropdown.
 98        self.image_selector_name = "Image data"
 99        self.image_selector_widget = self._create_layer_selector(self.image_selector_name, layer_type="Image")
100
101        # Create buttons and widgets.
102        self.predict_button = QPushButton("Run Segmentation")
103        self.predict_button.clicked.connect(self.on_predict)
104        self.model_selector_widget = self.load_model_widget()
105        self.settings = self._create_settings_widget()
106
107        # Add the widgets to the layout.
108        layout.addWidget(self.image_selector_widget)
109        layout.addWidget(self.model_selector_widget)
110        layout.addWidget(self.settings)
111        layout.addWidget(self.predict_button)
112
113        self.setLayout(layout)
114
115    def load_model_widget(self):
116        model_widget = QWidget()
117        title_label = QLabel("Select Model:")
118
119        models = ["- choose -"] + list(_get_model_registry().urls.keys())
120        self.model_selector = QComboBox()
121        self.model_selector.addItems(models)
122        # Create a layout and add the title label and combo box
123        layout = QVBoxLayout()
124        layout.addWidget(title_label)
125        layout.addWidget(self.model_selector)
126
127        # Set layout on the model widget
128        model_widget.setLayout(layout)
129        return model_widget
130
131    def on_predict(self):
132        # Get the model and postprocessing settings.
133        model_type = self.model_selector.currentText()
134        custom_model_path = self.checkpoint_param.text()
135        if model_type == "- choose -" and custom_model_path is None:
136            show_info("INFO: Please choose a model.")
137            return
138
139        device = get_device(self.device_dropdown.currentText())
140
141        # Load the model. Override if user chose custom model
142        if custom_model_path:
143            model = _load_custom_model(custom_model_path, device)
144            if model:
145                show_info(f"INFO: Using custom model from path: {custom_model_path}")
146                model_type = "custom"
147            else:
148                show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}")
149                return
150        else:
151            model = get_model(model_type, device)
152
153        # Get the image data.
154        image = self._get_layer_selector_data(self.image_selector_name)
155        if image is None:
156            show_info("INFO: Please choose an image.")
157            return
158
159        # Get the current tiling.
160        self.tiling = _get_current_tiling(self.tiling, self.default_tiling, model_type)
161
162        # Get the voxel size.
163        metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
164        voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False)
165
166        # Determine the scaling based on the voxel size.
167        scale = None
168        if voxel_size:
169            if model_type == "custom":
170                show_info("INFO: The image is not rescaled for a custom model.")
171            else:
172                # calculate scale so voxel_size is the same as in training
173                scale = compute_scale_from_voxel_size(voxel_size, model_type)
174                scale_info = list(map(lambda x: np.round(x, 2), scale))
175                show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.")
176
177        # Some models require an additional segmentation for inference or postprocessing.
178        # For these models we read out the 'Extra Segmentation' widget.
179        if model_type == "ribbon":  # Currently only the ribbon model needs the extra seg.
180            extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name)
181            kwargs = {"extra_segmentation": extra_seg}
182        else:
183            kwargs = {}
184        segmentation = run_segmentation(
185            image, model=model, model_type=model_type, tiling=self.tiling, scale=scale, **kwargs
186        )
187
188        # Add the segmentation layer(s).
189        if isinstance(segmentation, dict):
190            for name, seg in segmentation.items():
191                self.viewer.add_labels(seg, name=name, metadata=metadata)
192        else:
193            self.viewer.add_labels(segmentation, name=f"{model_type}", metadata=metadata)
194        show_info(f"INFO: Segmentation of {model_type} added to layers.")
195
196    def _create_settings_widget(self):
197        setting_values = QWidget()
198        # setting_values.setToolTip(get_tooltip("embedding", "settings"))
199        setting_values.setLayout(QVBoxLayout())
200
201        # Create UI for the device.
202        device = "auto"
203        device_options = ["auto"] + _available_devices()
204
205        self.device_dropdown, layout = self._add_choice_param("device", device, device_options)
206        setting_values.layout().addLayout(layout)
207
208        # Create UI for the tile shape.
209        self.default_tiling = get_default_tiling()
210        self.tiling = copy.deepcopy(self.default_tiling)
211        self.tiling["tile"]["x"], self.tiling["tile"]["y"], self.tiling["tile"]["z"], layout = self._add_shape_param(
212            ("tile_x", "tile_y", "tile_z"),
213            (self.default_tiling["tile"]["x"], self.default_tiling["tile"]["y"], self.default_tiling["tile"]["z"]),
214            min_val=0, max_val=2048, step=16,
215            # tooltip=get_tooltip("embedding", "tiling")
216        )
217        setting_values.layout().addLayout(layout)
218
219        # Create UI for the halo.
220        self.tiling["halo"]["x"], self.tiling["halo"]["y"], self.tiling["halo"]["z"], layout = self._add_shape_param(
221            ("halo_x", "halo_y", "halo_z"),
222            (self.default_tiling["halo"]["x"], self.default_tiling["halo"]["y"], self.default_tiling["halo"]["z"]),
223            min_val=0, max_val=512,
224            # tooltip=get_tooltip("embedding", "halo")
225        )
226        setting_values.layout().addLayout(layout)
227
228        # Read voxel size from layer metadata.
229        self.voxel_size_param, layout = self._add_float_param(
230            "voxel_size", 0.0, min_val=0.0, max_val=100.0,
231        )
232        setting_values.layout().addLayout(layout)
233
234        self.checkpoint_param, layout = self._add_string_param(
235            name="checkpoint", value="", title="Load Custom Model",
236            placeholder="path/to/checkpoint.pt",
237        )
238        setting_values.layout().addLayout(layout)
239
240        # Add selection UI for additional segmentation, which some models require for inference or postproc.
241        self.extra_seg_selector_name = "Extra Segmentation"
242        self.extra_selector_widget = self._create_layer_selector(self.extra_seg_selector_name, layer_type="Labels")
243        setting_values.layout().addWidget(self.extra_selector_widget)
244
245        settings = self._make_collapsible(widget=setting_values, title="Advanced Settings")
246        return settings

QWidget(parent: typing.Optional[QWidget] = None, flags: Union[Qt.WindowFlags, Qt.WindowType] = Qt.WindowFlags())

viewer
tiling
image_selector_name
image_selector_widget
predict_button
model_selector_widget
settings
def load_model_widget(self):
115    def load_model_widget(self):
116        model_widget = QWidget()
117        title_label = QLabel("Select Model:")
118
119        models = ["- choose -"] + list(_get_model_registry().urls.keys())
120        self.model_selector = QComboBox()
121        self.model_selector.addItems(models)
122        # Create a layout and add the title label and combo box
123        layout = QVBoxLayout()
124        layout.addWidget(title_label)
125        layout.addWidget(self.model_selector)
126
127        # Set layout on the model widget
128        model_widget.setLayout(layout)
129        return model_widget
def on_predict(self):
131    def on_predict(self):
132        # Get the model and postprocessing settings.
133        model_type = self.model_selector.currentText()
134        custom_model_path = self.checkpoint_param.text()
135        if model_type == "- choose -" and custom_model_path is None:
136            show_info("INFO: Please choose a model.")
137            return
138
139        device = get_device(self.device_dropdown.currentText())
140
141        # Load the model. Override if user chose custom model
142        if custom_model_path:
143            model = _load_custom_model(custom_model_path, device)
144            if model:
145                show_info(f"INFO: Using custom model from path: {custom_model_path}")
146                model_type = "custom"
147            else:
148                show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}")
149                return
150        else:
151            model = get_model(model_type, device)
152
153        # Get the image data.
154        image = self._get_layer_selector_data(self.image_selector_name)
155        if image is None:
156            show_info("INFO: Please choose an image.")
157            return
158
159        # Get the current tiling.
160        self.tiling = _get_current_tiling(self.tiling, self.default_tiling, model_type)
161
162        # Get the voxel size.
163        metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
164        voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False)
165
166        # Determine the scaling based on the voxel size.
167        scale = None
168        if voxel_size:
169            if model_type == "custom":
170                show_info("INFO: The image is not rescaled for a custom model.")
171            else:
172                # calculate scale so voxel_size is the same as in training
173                scale = compute_scale_from_voxel_size(voxel_size, model_type)
174                scale_info = list(map(lambda x: np.round(x, 2), scale))
175                show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.")
176
177        # Some models require an additional segmentation for inference or postprocessing.
178        # For these models we read out the 'Extra Segmentation' widget.
179        if model_type == "ribbon":  # Currently only the ribbon model needs the extra seg.
180            extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name)
181            kwargs = {"extra_segmentation": extra_seg}
182        else:
183            kwargs = {}
184        segmentation = run_segmentation(
185            image, model=model, model_type=model_type, tiling=self.tiling, scale=scale, **kwargs
186        )
187
188        # Add the segmentation layer(s).
189        if isinstance(segmentation, dict):
190            for name, seg in segmentation.items():
191                self.viewer.add_labels(seg, name=name, metadata=metadata)
192        else:
193            self.viewer.add_labels(segmentation, name=f"{model_type}", metadata=metadata)
194        show_info(f"INFO: Segmentation of {model_type} added to layers.")