synapse_net.tools.segmentation_widget
import copy
import os
import re
from typing import Optional, Union

import napari
import numpy as np
import torch

from napari.utils.notifications import show_info
from qtpy.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QComboBox

from .base_widget import BaseWidget
from ..inference.inference import _get_model_registry, get_model, run_segmentation, compute_scale_from_voxel_size
from ..inference.util import get_default_tiling, get_device


def _load_custom_model(
    model_path: str, device: Optional[Union[str, torch.device]] = None
) -> Optional[torch.nn.Module]:
    """Load a full (pickled) PyTorch model from a user-provided checkpoint path.

    Args:
        model_path: Path to the checkpoint as entered in the UI; it is cleaned with `_clean_filepath` first.
        device: The device to map the model to. If None, `get_device(None)` chooses it.

    Returns:
        The loaded model, or None if loading failed (the error is printed).
    """
    model_path = _clean_filepath(model_path)
    if device is None:
        device = get_device(device)
    try:
        # NOTE: weights_only=False unpickles arbitrary objects, which can execute code.
        # Only load checkpoints from trusted sources.
        model = torch.load(model_path, map_location=torch.device(device), weights_only=False)
    except Exception as e:
        # Best effort: the caller reports the failure to the user.
        print(e)
        print("model path", model_path)
        return None
    return model


def _available_devices():
    """Return the devices among "cuda", "mps" and "cpu" for which `get_device` succeeds."""
    available_devices = []
    for name in ("cuda", "mps", "cpu"):
        try:
            device = get_device(name)
        except RuntimeError:
            # This device type is not available on this machine.
            continue
        available_devices.append(device)
    return available_devices


def _get_current_tiling(tiling: dict, default_tiling: dict, model_type: str) -> dict:
    """Resolve the tiling to use for a prediction from the tiling settings.

    The values are read into a NEW dict. `tiling` itself is not modified, so it keeps
    referencing the spin boxes and later changes in the UI are picked up by the next call.

    Args:
        tiling: Nested dict {"tile": {...}, "halo": {...}} whose leaves are ints or objects
            with a `value()` method (the spin boxes of the settings panel).
        default_tiling: The default tiling; used to detect whether the user changed the values.
        model_type: Name of the model; names containing "2d" get a z tile size of 1.

    Returns:
        A new nested dict with the integer tile and halo sizes.

    Raises:
        TypeError: If a leaf is neither an int nor has a `value()` method.
    """
    current = {}
    for key, values in tiling.items():
        current[key] = {}
        for axis, value in values.items():
            if isinstance(value, int):
                current[key][axis] = value
            elif hasattr(value, "value"):  # If it's a QSpinBox, extract the value
                current[key][axis] = value.value()
            else:
                raise TypeError(f"Unexpected type for tiling value: {type(value)} at {key}/{axis}")

    # Check whether the user changed the tiling/halo.
    if current == default_tiling:
        if "2d" in model_type:
            # For a 2d model with default settings, use larger x/y tiles and a z tile of 1.
            current = {
                "tile": {"x": 512, "y": 512, "z": 1},
                "halo": {"x": 64, "y": 64, "z": 1},
            }
    else:
        show_info(f"Using custom tiling: {current}")
        if "2d" in model_type:
            # For a 2d model, keep the custom x/y values but set the z tile to 1 and the z halo to 0.
            current["tile"]["z"] = 1
            current["halo"]["z"] = 0
    show_info(f"Using tiling: {current}")
    return current


def _clean_filepath(filepath):
    """Clean a filepath entered in the UI.

    - Removes newline and carriage-return characters and surrounding whitespace.
    - Strips the 'file://' prefix if present.
    - On POSIX systems, resolves backslash escapes by keeping the escaped character
      (e.g. a shell-escaped space becomes a plain space). On Windows the backslash is the
      path separator, so it is left untouched.

    Args:
        filepath (str): The original filepath.

    Returns:
        str: The cleaned filepath.
    """
    # Remove newlines first so that a leading newline does not hide the 'file://' prefix.
    filepath = filepath.replace("\n", "").replace("\r", "").strip()

    # Remove 'file://' prefix if present.
    prefix = "file://"
    if filepath.startswith(prefix):
        filepath = filepath[len(prefix):]

    # Unescape backslash sequences; deleting them would also drop the escaped character.
    if os.name != "nt":
        filepath = re.sub(r"\\(.)", r"\1", filepath)

    return filepath


class SegmentationWidget(BaseWidget):
    """Napari widget to run a pretrained (or custom) segmentation model on an image layer."""

    def __init__(self):
        super().__init__()

        self.viewer = napari.current_viewer()
        layout = QVBoxLayout()
        # Filled with the tile/halo spin boxes in _create_settings_widget.
        self.tiling = {}

        # Create the image selection dropdown.
        self.image_selector_name = "Image data"
        self.image_selector_widget = self._create_layer_selector(self.image_selector_name, layer_type="Image")

        # Create buttons and widgets.
        self.predict_button = QPushButton("Run Segmentation")
        self.predict_button.clicked.connect(self.on_predict)
        self.model_selector_widget = self.load_model_widget()
        self.settings = self._create_settings_widget()

        # Add the widgets to the layout.
        layout.addWidget(self.image_selector_widget)
        layout.addWidget(self.model_selector_widget)
        layout.addWidget(self.settings)
        layout.addWidget(self.predict_button)

        self.setLayout(layout)

    def load_model_widget(self):
        """Create the model selection widget (title label + combo box of registered models)."""
        model_widget = QWidget()
        title_label = QLabel("Select Model:")

        # Exclude the models that are only offered through the CLI and not in the plugin.
        # These are the models excluded due to their specificity and to keep the menu simple.
        # TODO: we should at some point update the logic here, to make it easier to support further models
        # without cluttering the UI.
        excluded_models = {"vesicles_2d_maus"}
        # Sorted, so the menu order is stable across sessions (set order of strings is not).
        model_list = sorted(
            name for name in _get_model_registry().urls.keys() if name not in excluded_models
        )

        models = ["- choose -"] + model_list
        self.model_selector = QComboBox()
        self.model_selector.addItems(models)
        # Create a layout and add the title label and combo box
        layout = QVBoxLayout()
        layout.addWidget(title_label)
        layout.addWidget(self.model_selector)

        # Set layout on the model widget
        model_widget.setLayout(layout)
        return model_widget

    def on_predict(self):
        """Run the selected (or custom) model on the selected image and add the result as label layer(s)."""
        # Get the model and postprocessing settings.
        model_type = self.model_selector.currentText()
        custom_model_path = self.checkpoint_param.text()
        # text() returns an empty string (not None) when no custom model path was entered.
        if model_type == "- choose -" and not custom_model_path:
            show_info("INFO: Please choose a model.")
            return

        device = get_device(self.device_dropdown.currentText())

        # Load the model. Override if user chose custom model
        if custom_model_path:
            model = _load_custom_model(custom_model_path, device)
            if model is None:
                show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}")
                return
            show_info(f"INFO: Using custom model from path: {custom_model_path}")
            model_type = "custom"
        else:
            model = get_model(model_type, device)

        # Get the image data.
        image = self._get_layer_selector_data(self.image_selector_name)
        if image is None:
            show_info("INFO: Please choose an image.")
            return

        # Get the current tiling. The result is kept local: self.tiling must keep the
        # spin boxes so that later changes in the UI affect the next prediction.
        tiling = _get_current_tiling(self.tiling, self.default_tiling, model_type)

        # Get the voxel size.
        metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
        voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False)

        # Determine the scaling based on the voxel size.
        scale = None
        if voxel_size:
            if model_type == "custom":
                show_info("INFO: The image is not rescaled for a custom model.")
            else:
                # calculate scale so voxel_size is the same as in training
                scale = compute_scale_from_voxel_size(voxel_size, model_type)
                scale_info = [np.round(x, 2) for x in scale]
                show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.")

        # Some models require an additional segmentation for inference or postprocessing.
        # For these models we read out the 'Extra Segmentation' widget.
        if model_type == "ribbon":  # Currently only the ribbon model needs the extra seg.
            if not voxel_size:
                # The ribbon postprocessing needs the resolution; fail with a message instead of a TypeError.
                show_info("INFO: The ribbon model requires a voxel size. Please set it in the 'Advanced Settings'.")
                return
            extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name)
            resolution = tuple(voxel_size[ax] for ax in "zyx")
            kwargs = {"extra_segmentation": extra_seg, "resolution": resolution, "min_membrane_size": 50_000}
        elif model_type == "cristae":  # Cristae model expects 2 3D volumes
            kwargs = {
                "extra_segmentation": self._get_layer_selector_data(self.extra_seg_selector_name),
                "with_channels": True,
                "channels_to_standardize": [0],
            }
        else:
            kwargs = {}
        segmentation = run_segmentation(
            image, model=model, model_type=model_type, tiling=tiling, scale=scale, **kwargs
        )

        # Add the segmentation layer(s).
        if isinstance(segmentation, dict):
            for name, seg in segmentation.items():
                self.viewer.add_labels(seg, name=name, metadata=metadata)
        else:
            self.viewer.add_labels(segmentation, name=f"{model_type}", metadata=metadata)
        show_info(f"INFO: Segmentation of {model_type} added to layers.")

    def _create_settings_widget(self):
        """Create the collapsible 'Advanced Settings' panel and store its input widgets on self."""
        setting_values = QWidget()
        # setting_values.setToolTip(get_tooltip("embedding", "settings"))
        setting_values.setLayout(QVBoxLayout())

        # Create UI for the device.
        device = "auto"
        device_options = ["auto"] + _available_devices()

        self.device_dropdown, layout = self._add_choice_param("device", device, device_options)
        setting_values.layout().addLayout(layout)

        # Create UI for the tile shape.
        # self.tiling starts as a copy of the defaults; its leaves are then replaced by the
        # widgets returned from _add_shape_param, which on_predict reads via _get_current_tiling.
        self.default_tiling = get_default_tiling()
        self.tiling = copy.deepcopy(self.default_tiling)
        self.tiling["tile"]["x"], self.tiling["tile"]["y"], self.tiling["tile"]["z"], layout = self._add_shape_param(
            ("tile_x", "tile_y", "tile_z"),
            (self.default_tiling["tile"]["x"], self.default_tiling["tile"]["y"], self.default_tiling["tile"]["z"]),
            min_val=0, max_val=2048, step=16,
            # tooltip=get_tooltip("embedding", "tiling")
        )
        setting_values.layout().addLayout(layout)

        # Create UI for the halo.
        self.tiling["halo"]["x"], self.tiling["halo"]["y"], self.tiling["halo"]["z"], layout = self._add_shape_param(
            ("halo_x", "halo_y", "halo_z"),
            (self.default_tiling["halo"]["x"], self.default_tiling["halo"]["y"], self.default_tiling["halo"]["z"]),
            min_val=0, max_val=512,
            # tooltip=get_tooltip("embedding", "halo")
        )
        setting_values.layout().addLayout(layout)

        # Read voxel size from layer metadata.
        self.voxel_size_param, layout = self._add_float_param(
            "voxel_size", 0.0, min_val=0.0, max_val=100.0,
        )
        setting_values.layout().addLayout(layout)

        self.checkpoint_param, layout = self._add_string_param(
            name="checkpoint", value="", title="Load Custom Model",
            placeholder="path/to/checkpoint.pt",
        )
        setting_values.layout().addLayout(layout)

        # Add selection UI for additional segmentation, which some models require for inference or postproc.
        self.extra_seg_selector_name = "Extra Segmentation"
        self.extra_selector_widget = self._create_layer_selector(self.extra_seg_selector_name, layer_type="Labels")
        setting_values.layout().addWidget(self.extra_selector_widget)

        settings = self._make_collapsible(widget=setting_values, title="Advanced Settings")
        return settings
class SegmentationWidget(BaseWidget):
    """Napari widget to run a pretrained (or custom) segmentation model on an image layer."""

    def __init__(self):
        super().__init__()

        self.viewer = napari.current_viewer()
        layout = QVBoxLayout()
        # Filled with the tile/halo spin boxes in _create_settings_widget.
        self.tiling = {}

        # Create the image selection dropdown.
        self.image_selector_name = "Image data"
        self.image_selector_widget = self._create_layer_selector(self.image_selector_name, layer_type="Image")

        # Create buttons and widgets.
        self.predict_button = QPushButton("Run Segmentation")
        self.predict_button.clicked.connect(self.on_predict)
        self.model_selector_widget = self.load_model_widget()
        self.settings = self._create_settings_widget()

        # Add the widgets to the layout.
        layout.addWidget(self.image_selector_widget)
        layout.addWidget(self.model_selector_widget)
        layout.addWidget(self.settings)
        layout.addWidget(self.predict_button)

        self.setLayout(layout)

    def load_model_widget(self):
        """Create the model selection widget (title label + combo box of registered models)."""
        model_widget = QWidget()
        title_label = QLabel("Select Model:")

        # Exclude the models that are only offered through the CLI and not in the plugin.
        # These are the models excluded due to their specificity and to keep the menu simple.
        # TODO: we should at some point update the logic here, to make it easier to support further models
        # without cluttering the UI.
        excluded_models = {"vesicles_2d_maus"}
        # Sorted, so the menu order is stable across sessions (set order of strings is not).
        model_list = sorted(
            name for name in _get_model_registry().urls.keys() if name not in excluded_models
        )

        models = ["- choose -"] + model_list
        self.model_selector = QComboBox()
        self.model_selector.addItems(models)
        # Create a layout and add the title label and combo box
        layout = QVBoxLayout()
        layout.addWidget(title_label)
        layout.addWidget(self.model_selector)

        # Set layout on the model widget
        model_widget.setLayout(layout)
        return model_widget

    def on_predict(self):
        """Run the selected (or custom) model on the selected image and add the result as label layer(s)."""
        # Get the model and postprocessing settings.
        model_type = self.model_selector.currentText()
        custom_model_path = self.checkpoint_param.text()
        # text() returns an empty string (not None) when no custom model path was entered.
        if model_type == "- choose -" and not custom_model_path:
            show_info("INFO: Please choose a model.")
            return

        device = get_device(self.device_dropdown.currentText())

        # Load the model. Override if user chose custom model
        if custom_model_path:
            model = _load_custom_model(custom_model_path, device)
            if model is None:
                show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}")
                return
            show_info(f"INFO: Using custom model from path: {custom_model_path}")
            model_type = "custom"
        else:
            model = get_model(model_type, device)

        # Get the image data.
        image = self._get_layer_selector_data(self.image_selector_name)
        if image is None:
            show_info("INFO: Please choose an image.")
            return

        # Get the current tiling from a shallow copy of the widget mapping, and keep the result local.
        # self.tiling must keep referencing the spin boxes so later changes in the UI affect the next run.
        tiling_widgets = {key: dict(values) for key, values in self.tiling.items()}
        tiling = _get_current_tiling(tiling_widgets, self.default_tiling, model_type)

        # Get the voxel size.
        metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
        voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False)

        # Determine the scaling based on the voxel size.
        scale = None
        if voxel_size:
            if model_type == "custom":
                show_info("INFO: The image is not rescaled for a custom model.")
            else:
                # calculate scale so voxel_size is the same as in training
                scale = compute_scale_from_voxel_size(voxel_size, model_type)
                scale_info = [np.round(x, 2) for x in scale]
                show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.")

        # Some models require an additional segmentation for inference or postprocessing.
        # For these models we read out the 'Extra Segmentation' widget.
        if model_type == "ribbon":  # Currently only the ribbon model needs the extra seg.
            if not voxel_size:
                # The ribbon postprocessing needs the resolution; fail with a message instead of a TypeError.
                show_info("INFO: The ribbon model requires a voxel size. Please set it in the 'Advanced Settings'.")
                return
            extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name)
            resolution = tuple(voxel_size[ax] for ax in "zyx")
            kwargs = {"extra_segmentation": extra_seg, "resolution": resolution, "min_membrane_size": 50_000}
        elif model_type == "cristae":  # Cristae model expects 2 3D volumes
            kwargs = {
                "extra_segmentation": self._get_layer_selector_data(self.extra_seg_selector_name),
                "with_channels": True,
                "channels_to_standardize": [0],
            }
        else:
            kwargs = {}
        segmentation = run_segmentation(
            image, model=model, model_type=model_type, tiling=tiling, scale=scale, **kwargs
        )

        # Add the segmentation layer(s).
        if isinstance(segmentation, dict):
            for name, seg in segmentation.items():
                self.viewer.add_labels(seg, name=name, metadata=metadata)
        else:
            self.viewer.add_labels(segmentation, name=f"{model_type}", metadata=metadata)
        show_info(f"INFO: Segmentation of {model_type} added to layers.")

    def _create_settings_widget(self):
        """Create the collapsible 'Advanced Settings' panel and store its input widgets on self."""
        setting_values = QWidget()
        # setting_values.setToolTip(get_tooltip("embedding", "settings"))
        setting_values.setLayout(QVBoxLayout())

        # Create UI for the device.
        device = "auto"
        device_options = ["auto"] + _available_devices()

        self.device_dropdown, layout = self._add_choice_param("device", device, device_options)
        setting_values.layout().addLayout(layout)

        # Create UI for the tile shape.
        # self.tiling starts as a copy of the defaults; its leaves are then replaced by the
        # widgets returned from _add_shape_param, which on_predict reads via _get_current_tiling.
        self.default_tiling = get_default_tiling()
        self.tiling = copy.deepcopy(self.default_tiling)
        self.tiling["tile"]["x"], self.tiling["tile"]["y"], self.tiling["tile"]["z"], layout = self._add_shape_param(
            ("tile_x", "tile_y", "tile_z"),
            (self.default_tiling["tile"]["x"], self.default_tiling["tile"]["y"], self.default_tiling["tile"]["z"]),
            min_val=0, max_val=2048, step=16,
            # tooltip=get_tooltip("embedding", "tiling")
        )
        setting_values.layout().addLayout(layout)

        # Create UI for the halo.
        self.tiling["halo"]["x"], self.tiling["halo"]["y"], self.tiling["halo"]["z"], layout = self._add_shape_param(
            ("halo_x", "halo_y", "halo_z"),
            (self.default_tiling["halo"]["x"], self.default_tiling["halo"]["y"], self.default_tiling["halo"]["z"]),
            min_val=0, max_val=512,
            # tooltip=get_tooltip("embedding", "halo")
        )
        setting_values.layout().addLayout(layout)

        # Read voxel size from layer metadata.
        self.voxel_size_param, layout = self._add_float_param(
            "voxel_size", 0.0, min_val=0.0, max_val=100.0,
        )
        setting_values.layout().addLayout(layout)

        self.checkpoint_param, layout = self._add_string_param(
            name="checkpoint", value="", title="Load Custom Model",
            placeholder="path/to/checkpoint.pt",
        )
        setting_values.layout().addLayout(layout)

        # Add selection UI for additional segmentation, which some models require for inference or postproc.
        self.extra_seg_selector_name = "Extra Segmentation"
        self.extra_selector_widget = self._create_layer_selector(self.extra_seg_selector_name, layer_type="Labels")
        setting_values.layout().addWidget(self.extra_selector_widget)

        settings = self._make_collapsible(widget=setting_values, title="Advanced Settings")
        return settings
146 model_type = self.model_selector.currentText() 147 custom_model_path = self.checkpoint_param.text() 148 if model_type == "- choose -" and custom_model_path is None: 149 show_info("INFO: Please choose a model.") 150 return 151 152 device = get_device(self.device_dropdown.currentText()) 153 154 # Load the model. Override if user chose custom model 155 if custom_model_path: 156 model = _load_custom_model(custom_model_path, device) 157 if model: 158 show_info(f"INFO: Using custom model from path: {custom_model_path}") 159 model_type = "custom" 160 else: 161 show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}") 162 return 163 else: 164 model = get_model(model_type, device) 165 166 # Get the image data. 167 image = self._get_layer_selector_data(self.image_selector_name) 168 if image is None: 169 show_info("INFO: Please choose an image.") 170 return 171 172 # Get the current tiling. 173 self.tiling = _get_current_tiling(self.tiling, self.default_tiling, model_type) 174 175 # Get the voxel size. 176 metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True) 177 voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False) 178 179 # Determine the scaling based on the voxel size. 180 scale = None 181 if voxel_size: 182 if model_type == "custom": 183 show_info("INFO: The image is not rescaled for a custom model.") 184 else: 185 # calculate scale so voxel_size is the same as in training 186 scale = compute_scale_from_voxel_size(voxel_size, model_type) 187 scale_info = list(map(lambda x: np.round(x, 2), scale)) 188 show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.") 189 190 # Some models require an additional segmentation for inference or postprocessing. 191 # For these models we read out the 'Extra Segmentation' widget. 192 if model_type == "ribbon": # Currently only the ribbon model needs the extra seg. 
193 extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name) 194 resolution = tuple(voxel_size[ax] for ax in "zyx") 195 kwargs = {"extra_segmentation": extra_seg, "resolution": resolution, "min_membrane_size": 50_000} 196 elif model_type == "cristae": # Cristae model expects 2 3D volumes 197 kwargs = { 198 "extra_segmentation": self._get_layer_selector_data(self.extra_seg_selector_name), 199 "with_channels": True, 200 "channels_to_standardize": [0] 201 } 202 else: 203 kwargs = {} 204 segmentation = run_segmentation( 205 image, model=model, model_type=model_type, tiling=self.tiling, scale=scale, **kwargs 206 ) 207 208 # Add the segmentation layer(s). 209 if isinstance(segmentation, dict): 210 for name, seg in segmentation.items(): 211 self.viewer.add_labels(seg, name=name, metadata=metadata) 212 else: 213 self.viewer.add_labels(segmentation, name=f"{model_type}", metadata=metadata) 214 show_info(f"INFO: Segmentation of {model_type} added to layers.") 215 216 def _create_settings_widget(self): 217 setting_values = QWidget() 218 # setting_values.setToolTip(get_tooltip("embedding", "settings")) 219 setting_values.setLayout(QVBoxLayout()) 220 221 # Create UI for the device. 222 device = "auto" 223 device_options = ["auto"] + _available_devices() 224 225 self.device_dropdown, layout = self._add_choice_param("device", device, device_options) 226 setting_values.layout().addLayout(layout) 227 228 # Create UI for the tile shape. 229 self.default_tiling = get_default_tiling() 230 self.tiling = copy.deepcopy(self.default_tiling) 231 self.tiling["tile"]["x"], self.tiling["tile"]["y"], self.tiling["tile"]["z"], layout = self._add_shape_param( 232 ("tile_x", "tile_y", "tile_z"), 233 (self.default_tiling["tile"]["x"], self.default_tiling["tile"]["y"], self.default_tiling["tile"]["z"]), 234 min_val=0, max_val=2048, step=16, 235 # tooltip=get_tooltip("embedding", "tiling") 236 ) 237 setting_values.layout().addLayout(layout) 238 239 # Create UI for the halo. 
240 self.tiling["halo"]["x"], self.tiling["halo"]["y"], self.tiling["halo"]["z"], layout = self._add_shape_param( 241 ("halo_x", "halo_y", "halo_z"), 242 (self.default_tiling["halo"]["x"], self.default_tiling["halo"]["y"], self.default_tiling["halo"]["z"]), 243 min_val=0, max_val=512, 244 # tooltip=get_tooltip("embedding", "halo") 245 ) 246 setting_values.layout().addLayout(layout) 247 248 # Read voxel size from layer metadata. 249 self.voxel_size_param, layout = self._add_float_param( 250 "voxel_size", 0.0, min_val=0.0, max_val=100.0, 251 ) 252 setting_values.layout().addLayout(layout) 253 254 self.checkpoint_param, layout = self._add_string_param( 255 name="checkpoint", value="", title="Load Custom Model", 256 placeholder="path/to/checkpoint.pt", 257 ) 258 setting_values.layout().addLayout(layout) 259 260 # Add selection UI for additional segmentation, which some models require for inference or postproc. 261 self.extra_seg_selector_name = "Extra Segmentation" 262 self.extra_selector_widget = self._create_layer_selector(self.extra_seg_selector_name, layer_type="Labels") 263 setting_values.layout().addWidget(self.extra_selector_widget) 264 265 settings = self._make_collapsible(widget=setting_values, title="Advanced Settings") 266 return settings
QWidget(parent: Optional[QWidget] = None, flags: Union[Qt.WindowFlags, Qt.WindowType] = Qt.WindowFlags())
def
load_model_widget(self):
def load_model_widget(self):
    """Create the model selection widget (title label + combo box of registered models)."""
    model_widget = QWidget()
    title_label = QLabel("Select Model:")

    # Exclude the models that are only offered through the CLI and not in the plugin.
    # These are the models excluded due to their specificity and to keep the menu simple.
    # TODO: we should at some point update the logic here, to make it easier to support further models
    # without cluttering the UI.
    excluded_models = {"vesicles_2d_maus"}
    # Sorted, so the menu order is stable across sessions (set order of strings is not).
    model_list = sorted(
        name for name in _get_model_registry().urls.keys() if name not in excluded_models
    )

    models = ["- choose -"] + model_list
    self.model_selector = QComboBox()
    self.model_selector.addItems(models)
    # Create a layout and add the title label and combo box
    layout = QVBoxLayout()
    layout.addWidget(title_label)
    layout.addWidget(self.model_selector)

    # Set layout on the model widget
    model_widget.setLayout(layout)
    return model_widget
def
on_predict(self):
def on_predict(self):
    """Run the selected (or custom) model on the selected image and add the result as label layer(s)."""
    # Get the model and postprocessing settings.
    model_type = self.model_selector.currentText()
    custom_model_path = self.checkpoint_param.text()
    # text() returns an empty string (not None) when no custom model path was entered.
    if model_type == "- choose -" and not custom_model_path:
        show_info("INFO: Please choose a model.")
        return

    device = get_device(self.device_dropdown.currentText())

    # Load the model. Override if user chose custom model
    if custom_model_path:
        model = _load_custom_model(custom_model_path, device)
        if model is None:
            show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}")
            return
        show_info(f"INFO: Using custom model from path: {custom_model_path}")
        model_type = "custom"
    else:
        model = get_model(model_type, device)

    # Get the image data.
    image = self._get_layer_selector_data(self.image_selector_name)
    if image is None:
        show_info("INFO: Please choose an image.")
        return

    # Get the current tiling from a shallow copy of the widget mapping, and keep the result local.
    # self.tiling must keep referencing the spin boxes so later changes in the UI affect the next run.
    tiling_widgets = {key: dict(values) for key, values in self.tiling.items()}
    tiling = _get_current_tiling(tiling_widgets, self.default_tiling, model_type)

    # Get the voxel size.
    metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
    voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False)

    # Determine the scaling based on the voxel size.
    scale = None
    if voxel_size:
        if model_type == "custom":
            show_info("INFO: The image is not rescaled for a custom model.")
        else:
            # calculate scale so voxel_size is the same as in training
            scale = compute_scale_from_voxel_size(voxel_size, model_type)
            scale_info = [np.round(x, 2) for x in scale]
            show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.")

    # Some models require an additional segmentation for inference or postprocessing.
    # For these models we read out the 'Extra Segmentation' widget.
    if model_type == "ribbon":  # Currently only the ribbon model needs the extra seg.
        if not voxel_size:
            # The ribbon postprocessing needs the resolution; fail with a message instead of a TypeError.
            show_info("INFO: The ribbon model requires a voxel size. Please set it in the 'Advanced Settings'.")
            return
        extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name)
        resolution = tuple(voxel_size[ax] for ax in "zyx")
        kwargs = {"extra_segmentation": extra_seg, "resolution": resolution, "min_membrane_size": 50_000}
    elif model_type == "cristae":  # Cristae model expects 2 3D volumes
        kwargs = {
            "extra_segmentation": self._get_layer_selector_data(self.extra_seg_selector_name),
            "with_channels": True,
            "channels_to_standardize": [0],
        }
    else:
        kwargs = {}
    segmentation = run_segmentation(
        image, model=model, model_type=model_type, tiling=tiling, scale=scale, **kwargs
    )

    # Add the segmentation layer(s).
    if isinstance(segmentation, dict):
        for name, seg in segmentation.items():
            self.viewer.add_labels(seg, name=name, metadata=metadata)
    else:
        self.viewer.add_labels(segmentation, name=f"{model_type}", metadata=metadata)
    show_info(f"INFO: Segmentation of {model_type} added to layers.")