synapse_net.tools.segmentation_widget
import copy
import os
import re
from typing import Optional, Union

import napari
import numpy as np
import torch

from napari.utils.notifications import show_info
from qtpy.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QComboBox

from .base_widget import BaseWidget
from ..inference.inference import _get_model_registry, get_model, run_segmentation, compute_scale_from_voxel_size
from ..inference.util import get_default_tiling, get_device


def _load_custom_model(
    model_path: str, device: Optional[Union[str, torch.device]] = None
) -> Optional[torch.nn.Module]:
    """Load a full (pickled) PyTorch model from a user-supplied checkpoint path.

    Args:
        model_path: Path to the checkpoint. It is cleaned with `_clean_filepath` first.
        device: Device to map the model to. If None, `get_device(None)` chooses one.

    Returns:
        The loaded model, or None if `torch.load` raised. The error and the path are printed.
    """
    model_path = _clean_filepath(model_path)
    if device is None:
        device = get_device(device)
    try:
        # NOTE(security): weights_only=False unpickles arbitrary objects, which can execute
        # code. Only load checkpoints that come from a trusted source.
        model = torch.load(model_path, map_location=torch.device(device), weights_only=False)
    except Exception as e:
        print(e)
        print("model path", model_path)
        return None
    return model


def _available_devices():
    """Return the devices among "cuda", "mps" and "cpu" that `get_device` accepts.

    A device is skipped when `get_device` raises RuntimeError for it. The returned
    entries are whatever `get_device` returns for each candidate.
    """
    available_devices = []
    for i in ["cuda", "mps", "cpu"]:
        try:
            device = get_device(i)
        except RuntimeError:
            pass
        else:
            available_devices.append(device)
    return available_devices


def _read_tiling_values(tiling: dict) -> dict:
    """Return a new nested tiling dict with spin-box entries replaced by their integer values.

    The input is not modified.
    """
    values = {}
    for k, v in tiling.items():
        values[k] = {}
        for k2, v2 in v.items():
            if isinstance(v2, int):
                values[k][k2] = v2
            elif hasattr(v2, "value"):  # If it's a QSpinBox, extract the value
                values[k][k2] = v2.value()
            else:
                raise TypeError(f"Unexpected type for tiling value: {type(v2)} at {k}/{k2}")
    return values


def _get_current_tiling(tiling: dict, default_tiling: dict, model_type: str) -> dict:
    """Resolve the tiling to use for a prediction from the UI values.

    Args:
        tiling: Nested dict {"tile": {...}, "halo": {...}} whose leaves are ints or widgets
            exposing `.value()`. It is read but not modified, so the widgets stay in place
            for later runs.
        default_tiling: The default tiling (ints). It is used to detect whether the user
            changed anything.
        model_type: Name of the selected model. Names containing "2d" get 2d tiling.

    Returns:
        A new nested dict of ints.

    Raises:
        TypeError: If a tiling leaf is neither an int nor has a `.value()` method.
    """
    tiling = _read_tiling_values(tiling)
    # check if user inputs tiling/halo or not
    if default_tiling == tiling:
        if "2d" in model_type:
            # if its a 2d model expand x,y and set z to 1
            # NOTE(review): halo z is 1 here but 0 in the custom-tiling branch below -- confirm intended.
            tiling = {
                "tile": {"x": 512, "y": 512, "z": 1},
                "halo": {"x": 64, "y": 64, "z": 1},
            }
    else:
        show_info(f"Using custom tiling: {tiling}")
        if "2d" in model_type:
            # if its a 2d model set z to 1
            tiling["tile"]["z"] = 1
            tiling["halo"]["z"] = 0
    show_info(f"Using tiling: {tiling}")
    return tiling


def _clean_filepath(filepath):
    r"""Clean a filepath that was pasted or dropped into the widget.

    - Strips the 'file://' prefix if present.
    - Outside Windows, resolves backslash escapes by keeping the escaped character
      (e.g. "my\ file.pt" -> "my file.pt"). On Windows backslashes are kept, because
      they are path separators there.
    - Removes newline and carriage-return characters.

    Args:
        filepath (str): The original filepath

    Returns:
        str: The cleaned filepath
    """
    # Remove 'file://' prefix if present
    if filepath.startswith("file://"):
        filepath = filepath[len("file://"):]

    # Resolve escape sequences; dropping the escaped character as well would corrupt the path.
    if os.name != "nt":
        filepath = re.sub(r"\\(.)", r"\1", filepath, flags=re.DOTALL)
    filepath = filepath.replace("\n", "").replace("\r", "")

    return filepath


class SegmentationWidget(BaseWidget):
    """Widget that runs a segmentation model on a selected image layer.

    The result is added to the napari viewer as label layer(s).
    """

    def __init__(self):
        super().__init__()

        self.viewer = napari.current_viewer()
        layout = QVBoxLayout()
        self.tiling = {}

        # Create the image selection dropdown.
        self.image_selector_name = "Image data"
        self.image_selector_widget = self._create_layer_selector(self.image_selector_name, layer_type="Image")

        # Create buttons and widgets.
        self.predict_button = QPushButton("Run Segmentation")
        self.predict_button.clicked.connect(self.on_predict)
        self.model_selector_widget = self.load_model_widget()
        self.settings = self._create_settings_widget()

        # Add the widgets to the layout.
        layout.addWidget(self.image_selector_widget)
        layout.addWidget(self.model_selector_widget)
        layout.addWidget(self.settings)
        layout.addWidget(self.predict_button)

        self.setLayout(layout)

    def load_model_widget(self):
        """Create the "Select Model:" widget and store the dropdown as `self.model_selector`.

        Lists every model in the model registry except the CLI-only ones, after a
        "- choose -" placeholder.
        """
        model_widget = QWidget()
        title_label = QLabel("Select Model:")

        # Exclude the models that are only offered through the CLI and not in the plugin.
        excluded_models = {"vesicles_2d_maus", "vesicles_3d_endbulb", "vesicles_3d_innerear"}
        # Sorted so the dropdown order is stable across sessions (set/hash order is not).
        model_list = sorted(
            name for name in set(_get_model_registry().urls.keys()) if name not in excluded_models
        )

        models = ["- choose -"] + model_list
        self.model_selector = QComboBox()
        self.model_selector.addItems(models)
        # Create a layout and add the title label and combo box
        layout = QVBoxLayout()
        layout.addWidget(title_label)
        layout.addWidget(self.model_selector)

        # Set layout on the model widget
        model_widget.setLayout(layout)
        return model_widget

    def on_predict(self):
        """Run the chosen model on the selected image and add the result as label layer(s).

        The model comes from the dropdown unless a custom checkpoint path is set. In that
        case the checkpoint is loaded and used with model type "custom". Problems are
        reported via `show_info` and abort the run.
        """
        # Get the model choice. The checkpoint field yields a string, empty when left blank,
        # so test for emptiness rather than for None.
        model_type = self.model_selector.currentText()
        custom_model_path = (self.checkpoint_param.text() or "").strip()
        if model_type == "- choose -" and not custom_model_path:
            show_info("INFO: Please choose a model.")
            return

        # Get the image data before loading a model, so nothing is loaded needlessly.
        image = self._get_layer_selector_data(self.image_selector_name)
        if image is None:
            show_info("INFO: Please choose an image.")
            return

        device = get_device(self.device_dropdown.currentText())

        # Load the model. Override if user chose custom model
        if custom_model_path:
            model = _load_custom_model(custom_model_path, device)
            # Compare with None explicitly: containers such as nn.Sequential define __len__,
            # so truthiness is not a reliable success check.
            if model is None:
                show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}")
                return
            show_info(f"INFO: Using custom model from path: {custom_model_path}")
            model_type = "custom"
        else:
            model = get_model(model_type, device)

        # Get the current tiling. `_get_current_tiling` returns a new dict, so self.tiling keeps
        # referencing the spin boxes and UI edits take effect on the next run.
        tiling = _get_current_tiling(self.tiling, self.default_tiling, model_type)

        # Get the voxel size.
        metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
        voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False)

        # Determine the scaling based on the voxel size.
        scale = None
        if voxel_size:
            if model_type == "custom":
                show_info("INFO: The image is not rescaled for a custom model.")
            else:
                # calculate scale so voxel_size is the same as in training
                scale = compute_scale_from_voxel_size(voxel_size, model_type)
                scale_info = [np.round(x, 2) for x in scale]
                show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.")

        # Some models require an additional segmentation for inference or postprocessing.
        # For these models we read out the 'Extra Segmentation' widget.
        if model_type == "ribbon":
            kwargs = {"extra_segmentation": self._get_layer_selector_data(self.extra_seg_selector_name)}
        elif model_type == "cristae":  # Cristae model expects 2 3D volumes
            kwargs = {
                "extra_segmentation": self._get_layer_selector_data(self.extra_seg_selector_name),
                "with_channels": True,
                "channels_to_standardize": [0],
            }
        else:
            kwargs = {}
        segmentation = run_segmentation(
            image, model=model, model_type=model_type, tiling=tiling, scale=scale, **kwargs
        )

        # Add the segmentation layer(s).
        if isinstance(segmentation, dict):
            for name, seg in segmentation.items():
                self.viewer.add_labels(seg, name=name, metadata=metadata)
        else:
            self.viewer.add_labels(segmentation, name=f"{model_type}", metadata=metadata)
        show_info(f"INFO: Segmentation of {model_type} added to layers.")

    def _create_settings_widget(self):
        """Build the collapsible "Advanced Settings" panel.

        Creates the device dropdown, tile/halo inputs, voxel size input, custom
        checkpoint field and extra-segmentation layer selector. Each is stored on `self`
        for `on_predict`. `self.tiling` mirrors `self.default_tiling` but holds the
        widgets returned by `_add_shape_param` instead of ints.
        """
        setting_values = QWidget()
        # setting_values.setToolTip(get_tooltip("embedding", "settings"))
        setting_values.setLayout(QVBoxLayout())

        # Create UI for the device.
        device = "auto"
        device_options = ["auto"] + _available_devices()

        self.device_dropdown, layout = self._add_choice_param("device", device, device_options)
        setting_values.layout().addLayout(layout)

        # Create UI for the tile shape.
        self.default_tiling = get_default_tiling()
        self.tiling = copy.deepcopy(self.default_tiling)
        self.tiling["tile"]["x"], self.tiling["tile"]["y"], self.tiling["tile"]["z"], layout = self._add_shape_param(
            ("tile_x", "tile_y", "tile_z"),
            (self.default_tiling["tile"]["x"], self.default_tiling["tile"]["y"], self.default_tiling["tile"]["z"]),
            min_val=0, max_val=2048, step=16,
            # tooltip=get_tooltip("embedding", "tiling")
        )
        setting_values.layout().addLayout(layout)

        # Create UI for the halo.
        self.tiling["halo"]["x"], self.tiling["halo"]["y"], self.tiling["halo"]["z"], layout = self._add_shape_param(
            ("halo_x", "halo_y", "halo_z"),
            (self.default_tiling["halo"]["x"], self.default_tiling["halo"]["y"], self.default_tiling["halo"]["z"]),
            min_val=0, max_val=512,
            # tooltip=get_tooltip("embedding", "halo")
        )
        setting_values.layout().addLayout(layout)

        # Read voxel size from layer metadata.
        self.voxel_size_param, layout = self._add_float_param(
            "voxel_size", 0.0, min_val=0.0, max_val=100.0,
        )
        setting_values.layout().addLayout(layout)

        self.checkpoint_param, layout = self._add_string_param(
            name="checkpoint", value="", title="Load Custom Model",
            placeholder="path/to/checkpoint.pt",
        )
        setting_values.layout().addLayout(layout)

        # Add selection UI for additional segmentation, which some models require for inference or postproc.
        self.extra_seg_selector_name = "Extra Segmentation"
        self.extra_selector_widget = self._create_layer_selector(self.extra_seg_selector_name, layer_type="Labels")
        setting_values.layout().addWidget(self.extra_selector_widget)

        settings = self._make_collapsible(widget=setting_values, title="Advanced Settings")
        return settings
class SegmentationWidget(BaseWidget):
    """Widget that runs a segmentation model on a selected image layer.

    The result is added to the napari viewer as label layer(s).
    """

    def __init__(self):
        super().__init__()

        self.viewer = napari.current_viewer()
        layout = QVBoxLayout()
        self.tiling = {}

        # Create the image selection dropdown.
        self.image_selector_name = "Image data"
        self.image_selector_widget = self._create_layer_selector(self.image_selector_name, layer_type="Image")

        # Create buttons and widgets.
        self.predict_button = QPushButton("Run Segmentation")
        self.predict_button.clicked.connect(self.on_predict)
        self.model_selector_widget = self.load_model_widget()
        self.settings = self._create_settings_widget()

        # Add the widgets to the layout.
        layout.addWidget(self.image_selector_widget)
        layout.addWidget(self.model_selector_widget)
        layout.addWidget(self.settings)
        layout.addWidget(self.predict_button)

        self.setLayout(layout)

    def load_model_widget(self):
        """Create the "Select Model:" widget and store the dropdown as `self.model_selector`.

        Lists every model in the model registry except the CLI-only ones, after a
        "- choose -" placeholder.
        """
        model_widget = QWidget()
        title_label = QLabel("Select Model:")

        # Exclude the models that are only offered through the CLI and not in the plugin.
        excluded_models = {"vesicles_2d_maus", "vesicles_3d_endbulb", "vesicles_3d_innerear"}
        # Sorted so the dropdown order is stable across sessions (set/hash order is not).
        model_list = sorted(
            name for name in set(_get_model_registry().urls.keys()) if name not in excluded_models
        )

        models = ["- choose -"] + model_list
        self.model_selector = QComboBox()
        self.model_selector.addItems(models)
        # Create a layout and add the title label and combo box
        layout = QVBoxLayout()
        layout.addWidget(title_label)
        layout.addWidget(self.model_selector)

        # Set layout on the model widget
        model_widget.setLayout(layout)
        return model_widget

    def on_predict(self):
        """Run the chosen model on the selected image and add the result as label layer(s).

        The model comes from the dropdown unless a custom checkpoint path is set. In that
        case the checkpoint is loaded and used with model type "custom". Problems are
        reported via `show_info` and abort the run.
        """
        # Get the model choice. The checkpoint field yields a string, empty when left blank,
        # so test for emptiness rather than for None.
        model_type = self.model_selector.currentText()
        custom_model_path = (self.checkpoint_param.text() or "").strip()
        if model_type == "- choose -" and not custom_model_path:
            show_info("INFO: Please choose a model.")
            return

        # Get the image data before loading a model, so nothing is loaded needlessly.
        image = self._get_layer_selector_data(self.image_selector_name)
        if image is None:
            show_info("INFO: Please choose an image.")
            return

        device = get_device(self.device_dropdown.currentText())

        # Load the model. Override if user chose custom model
        if custom_model_path:
            model = _load_custom_model(custom_model_path, device)
            # Compare with None explicitly: containers such as nn.Sequential define __len__,
            # so truthiness is not a reliable success check.
            if model is None:
                show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}")
                return
            show_info(f"INFO: Using custom model from path: {custom_model_path}")
            model_type = "custom"
        else:
            model = get_model(model_type, device)

        # Get the current tiling. Pass a shallow copy of the nested dict so that self.tiling
        # keeps referencing the spin boxes and UI edits take effect on the next run.
        tiling = _get_current_tiling(
            {key: dict(values) for key, values in self.tiling.items()}, self.default_tiling, model_type
        )

        # Get the voxel size.
        metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
        voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False)

        # Determine the scaling based on the voxel size.
        scale = None
        if voxel_size:
            if model_type == "custom":
                show_info("INFO: The image is not rescaled for a custom model.")
            else:
                # calculate scale so voxel_size is the same as in training
                scale = compute_scale_from_voxel_size(voxel_size, model_type)
                scale_info = [np.round(x, 2) for x in scale]
                show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.")

        # Some models require an additional segmentation for inference or postprocessing.
        # For these models we read out the 'Extra Segmentation' widget.
        if model_type == "ribbon":
            kwargs = {"extra_segmentation": self._get_layer_selector_data(self.extra_seg_selector_name)}
        elif model_type == "cristae":  # Cristae model expects 2 3D volumes
            kwargs = {
                "extra_segmentation": self._get_layer_selector_data(self.extra_seg_selector_name),
                "with_channels": True,
                "channels_to_standardize": [0],
            }
        else:
            kwargs = {}
        segmentation = run_segmentation(
            image, model=model, model_type=model_type, tiling=tiling, scale=scale, **kwargs
        )

        # Add the segmentation layer(s).
        if isinstance(segmentation, dict):
            for name, seg in segmentation.items():
                self.viewer.add_labels(seg, name=name, metadata=metadata)
        else:
            self.viewer.add_labels(segmentation, name=f"{model_type}", metadata=metadata)
        show_info(f"INFO: Segmentation of {model_type} added to layers.")

    def _create_settings_widget(self):
        """Build the collapsible "Advanced Settings" panel.

        Creates the device dropdown, tile/halo inputs, voxel size input, custom
        checkpoint field and extra-segmentation layer selector. Each is stored on `self`
        for `on_predict`. `self.tiling` mirrors `self.default_tiling` but holds the
        widgets returned by `_add_shape_param` instead of ints.
        """
        setting_values = QWidget()
        # setting_values.setToolTip(get_tooltip("embedding", "settings"))
        setting_values.setLayout(QVBoxLayout())

        # Create UI for the device.
        device = "auto"
        device_options = ["auto"] + _available_devices()

        self.device_dropdown, layout = self._add_choice_param("device", device, device_options)
        setting_values.layout().addLayout(layout)

        # Create UI for the tile shape.
        self.default_tiling = get_default_tiling()
        self.tiling = copy.deepcopy(self.default_tiling)
        self.tiling["tile"]["x"], self.tiling["tile"]["y"], self.tiling["tile"]["z"], layout = self._add_shape_param(
            ("tile_x", "tile_y", "tile_z"),
            (self.default_tiling["tile"]["x"], self.default_tiling["tile"]["y"], self.default_tiling["tile"]["z"]),
            min_val=0, max_val=2048, step=16,
            # tooltip=get_tooltip("embedding", "tiling")
        )
        setting_values.layout().addLayout(layout)

        # Create UI for the halo.
        self.tiling["halo"]["x"], self.tiling["halo"]["y"], self.tiling["halo"]["z"], layout = self._add_shape_param(
            ("halo_x", "halo_y", "halo_z"),
            (self.default_tiling["halo"]["x"], self.default_tiling["halo"]["y"], self.default_tiling["halo"]["z"]),
            min_val=0, max_val=512,
            # tooltip=get_tooltip("embedding", "halo")
        )
        setting_values.layout().addLayout(layout)

        # Read voxel size from layer metadata.
        self.voxel_size_param, layout = self._add_float_param(
            "voxel_size", 0.0, min_val=0.0, max_val=100.0,
        )
        setting_values.layout().addLayout(layout)

        self.checkpoint_param, layout = self._add_string_param(
            name="checkpoint", value="", title="Load Custom Model",
            placeholder="path/to/checkpoint.pt",
        )
        setting_values.layout().addLayout(layout)

        # Add selection UI for additional segmentation, which some models require for inference or postproc.
        self.extra_seg_selector_name = "Extra Segmentation"
        self.extra_selector_widget = self._create_layer_selector(self.extra_seg_selector_name, layer_type="Labels")
        setting_values.layout().addWidget(self.extra_selector_widget)

        settings = self._make_collapsible(widget=setting_values, title="Advanced Settings")
        return settings
QWidget(parent: Optional[QWidget] = None, flags: Union[Qt.WindowFlags, Qt.WindowType] = Qt.WindowFlags())
def load_model_widget(self):
def load_model_widget(self):
    """Create the "Select Model:" widget and store the dropdown as `self.model_selector`.

    Lists every model in the model registry except the CLI-only ones, after a
    "- choose -" placeholder.
    """
    model_widget = QWidget()
    title_label = QLabel("Select Model:")

    # Exclude the models that are only offered through the CLI and not in the plugin.
    excluded_models = {"vesicles_2d_maus", "vesicles_3d_endbulb", "vesicles_3d_innerear"}
    # Sorted so the dropdown order is stable across sessions (set/hash order is not).
    model_list = sorted(
        name for name in set(_get_model_registry().urls.keys()) if name not in excluded_models
    )

    models = ["- choose -"] + model_list
    self.model_selector = QComboBox()
    self.model_selector.addItems(models)
    # Create a layout and add the title label and combo box
    layout = QVBoxLayout()
    layout.addWidget(title_label)
    layout.addWidget(self.model_selector)

    # Set layout on the model widget
    model_widget.setLayout(layout)
    return model_widget
def on_predict(self):
def on_predict(self):
    """Run the chosen model on the selected image and add the result as label layer(s).

    The model comes from the dropdown unless a custom checkpoint path is set. In that
    case the checkpoint is loaded and used with model type "custom". Problems are
    reported via `show_info` and abort the run.
    """
    # Get the model choice. The checkpoint field yields a string, empty when left blank,
    # so test for emptiness rather than for None.
    model_type = self.model_selector.currentText()
    custom_model_path = (self.checkpoint_param.text() or "").strip()
    if model_type == "- choose -" and not custom_model_path:
        show_info("INFO: Please choose a model.")
        return

    # Get the image data before loading a model, so nothing is loaded needlessly.
    image = self._get_layer_selector_data(self.image_selector_name)
    if image is None:
        show_info("INFO: Please choose an image.")
        return

    device = get_device(self.device_dropdown.currentText())

    # Load the model. Override if user chose custom model
    if custom_model_path:
        model = _load_custom_model(custom_model_path, device)
        # Compare with None explicitly: containers such as nn.Sequential define __len__,
        # so truthiness is not a reliable success check.
        if model is None:
            show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}")
            return
        show_info(f"INFO: Using custom model from path: {custom_model_path}")
        model_type = "custom"
    else:
        model = get_model(model_type, device)

    # Get the current tiling. Pass a shallow copy of the nested dict so that self.tiling
    # keeps referencing the spin boxes and UI edits take effect on the next run.
    tiling = _get_current_tiling(
        {key: dict(values) for key, values in self.tiling.items()}, self.default_tiling, model_type
    )

    # Get the voxel size.
    metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
    voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False)

    # Determine the scaling based on the voxel size.
    scale = None
    if voxel_size:
        if model_type == "custom":
            show_info("INFO: The image is not rescaled for a custom model.")
        else:
            # calculate scale so voxel_size is the same as in training
            scale = compute_scale_from_voxel_size(voxel_size, model_type)
            scale_info = [np.round(x, 2) for x in scale]
            show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.")

    # Some models require an additional segmentation for inference or postprocessing.
    # For these models we read out the 'Extra Segmentation' widget.
    if model_type == "ribbon":
        kwargs = {"extra_segmentation": self._get_layer_selector_data(self.extra_seg_selector_name)}
    elif model_type == "cristae":  # Cristae model expects 2 3D volumes
        kwargs = {
            "extra_segmentation": self._get_layer_selector_data(self.extra_seg_selector_name),
            "with_channels": True,
            "channels_to_standardize": [0],
        }
    else:
        kwargs = {}
    segmentation = run_segmentation(
        image, model=model, model_type=model_type, tiling=tiling, scale=scale, **kwargs
    )

    # Add the segmentation layer(s).
    if isinstance(segmentation, dict):
        for name, seg in segmentation.items():
            self.viewer.add_labels(seg, name=name, metadata=metadata)
    else:
        self.viewer.add_labels(segmentation, name=f"{model_type}", metadata=metadata)
    show_info(f"INFO: Segmentation of {model_type} added to layers.")