synapse_net.tools.segmentation_widget
1import copy 2import re 3from typing import Optional, Union 4 5import napari 6import numpy as np 7import torch 8 9from napari.utils.notifications import show_info 10from qtpy.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QComboBox 11 12from .base_widget import BaseWidget 13from ..inference.inference import _get_model_registry, get_model, run_segmentation, compute_scale_from_voxel_size 14from ..inference.util import get_default_tiling, get_device 15 16 17def _load_custom_model(model_path: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module: 18 model_path = _clean_filepath(model_path) 19 if device is None: 20 device = get_device(device) 21 try: 22 model = torch.load(model_path, map_location=torch.device(device), weights_only=False) 23 except Exception as e: 24 print(e) 25 print("model path", model_path) 26 return None 27 return model 28 29 30def _available_devices(): 31 available_devices = [] 32 for i in ["cuda", "mps", "cpu"]: 33 try: 34 device = get_device(i) 35 except RuntimeError: 36 pass 37 else: 38 available_devices.append(device) 39 return available_devices 40 41 42def _get_current_tiling(tiling: dict, default_tiling: dict, model_type: str): 43 # get tiling values from qt objects 44 for k, v in tiling.items(): 45 for k2, v2 in v.items(): 46 if isinstance(v2, int): 47 continue 48 elif hasattr(v2, "value"): # If it's a QSpinBox, extract the value 49 tiling[k][k2] = v2.value() 50 else: 51 raise TypeError(f"Unexpected type for tiling value: {type(v2)} at {k}/{k2}") 52 # check if user inputs tiling/halo or not 53 if default_tiling == tiling: 54 if "2d" in model_type: 55 # if its a 2d model expand x,y and set z to 1 56 tiling = { 57 "tile": {"x": 512, "y": 512, "z": 1}, 58 "halo": {"x": 64, "y": 64, "z": 1}, 59 } 60 else: 61 show_info(f"Using custom tiling: {tiling}") 62 if "2d" in model_type: 63 # if its a 2d model set z to 1 64 tiling["tile"]["z"] = 1 65 tiling["halo"]["z"] = 0 66 show_info(f"Using tiling: {tiling}") 67 return tiling 68 69 70def _clean_filepath(filepath): 71 """Cleans a given filepath by: 72 - Removing newline characters (\n) 73 - Removing escape sequences 74 - Stripping the 'file://' prefix if present 75 76 Args: 77 filepath (str): The original filepath 78 79 Returns: 80 str: The cleaned filepath 81 """ 82 # Remove 'file://' prefix if present 83 if filepath.startswith("file://"): 84 filepath = filepath[7:] 85 86 # Remove escape sequences and newlines 87 filepath = re.sub(r'\\.', '', filepath) 88 filepath = filepath.replace('\n', '').replace('\r', '') 89 90 return filepath 91 92 93class SegmentationWidget(BaseWidget): 94 def __init__(self): 95 super().__init__() 96 97 self.viewer = napari.current_viewer() 98 layout = QVBoxLayout() 99 self.tiling = {} 100 101 # Create the image selection dropdown. 102 self.image_selector_name = "Image data" 103 self.image_selector_widget = self._create_layer_selector(self.image_selector_name, layer_type="Image") 104 105 # Create buttons and widgets. 106 self.predict_button = QPushButton("Run Segmentation") 107 self.predict_button.clicked.connect(self.on_predict) 108 self.model_selector_widget = self.load_model_widget() 109 self.settings = self._create_settings_widget() 110 111 # Add the widgets to the layout. 112 layout.addWidget(self.image_selector_widget) 113 layout.addWidget(self.model_selector_widget) 114 layout.addWidget(self.settings) 115 layout.addWidget(self.predict_button) 116 117 self.setLayout(layout) 118 119 def load_model_widget(self): 120 model_widget = QWidget() 121 title_label = QLabel("Select Model:") 122 123 # Exclude the models that are only offered through the CLI and not in the plugin. 124 model_list = set(_get_model_registry().urls.keys()) 125 # These are the models exlcuded due to their specificity and to keep the menu simple. 126 # TODO: we should at some point update the logic here, to make it easier to support further models 127 # without cluttering the UI. 128 excluded_models = ["vesicles_2d_maus"] 129 model_list = [name for name in model_list if name not in excluded_models] 130 131 models = ["- choose -"] + model_list 132 self.model_selector = QComboBox() 133 self.model_selector.addItems(models) 134 # Create a layout and add the title label and combo box 135 layout = QVBoxLayout() 136 layout.addWidget(title_label) 137 layout.addWidget(self.model_selector) 138 139 # Set layout on the model widget 140 model_widget.setLayout(layout) 141 return model_widget 142 143 def on_predict(self): 144 # Get the model and postprocessing settings. 145 model_type = self.model_selector.currentText() 146 custom_model_path = self.checkpoint_param.text() 147 if model_type == "- choose -": 148 show_info("INFO: Please choose a model.") 149 return 150 151 device = get_device(self.device_dropdown.currentText()) 152 153 # Load the model. Override if user chose custom model. 154 rescale_input = True 155 if custom_model_path: 156 model = _load_custom_model(custom_model_path, device) 157 rescale_input = False 158 if model: 159 show_info(f"INFO: Using custom model from path: {custom_model_path}") 160 else: 161 show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}") 162 return 163 else: 164 model = get_model(model_type, device) 165 166 # Get the image data. 167 image = self._get_layer_selector_data(self.image_selector_name) 168 if image is None: 169 show_info("INFO: Please choose an image.") 170 return 171 172 # Get the current tiling. 173 self.tiling = _get_current_tiling(self.tiling, self.default_tiling, model_type) 174 175 # Get the voxel size. 176 metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True) 177 voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False) 178 179 # Determine the scaling based on the voxel size. 180 scale = None 181 if voxel_size and rescale_input: 182 # Calculate scale so voxel_size is the same as in training. 183 scale = compute_scale_from_voxel_size(voxel_size, model_type) 184 scale_info = list(map(lambda x: np.round(x, 2), scale)) 185 show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.") 186 187 # Some models require an additional segmentation for inference or postprocessing. 188 # For these models we read out the 'Extra Segmentation' widget. 189 if model_type == "ribbon": # Currently only the ribbon model needs the extra seg. 190 extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name) 191 resolution = tuple(voxel_size[ax] for ax in "zyx") 192 kwargs = {"extra_segmentation": extra_seg, "resolution": resolution, "min_membrane_size": 50_000} 193 elif model_type.startswith("cristae"): # Cristae model expects 2 3D volumes 194 kwargs = { 195 "extra_segmentation": self._get_layer_selector_data(self.extra_seg_selector_name), 196 "with_channels": True, 197 "channels_to_standardize": [0] 198 } 199 else: 200 kwargs = {} 201 segmentation = run_segmentation( 202 image, model=model, model_type=model_type, tiling=self.tiling, scale=scale, **kwargs 203 ) 204 205 # Add the segmentation layer(s). 206 if isinstance(segmentation, dict): 207 for name, seg in segmentation.items(): 208 self.viewer.add_labels(seg, name=name, metadata=metadata) 209 else: 210 self.viewer.add_labels(segmentation, name=f"{model_type}", metadata=metadata) 211 show_info(f"INFO: Segmentation of {model_type} added to layers.") 212 213 def _create_settings_widget(self): 214 setting_values = QWidget() 215 # setting_values.setToolTip(get_tooltip("embedding", "settings")) 216 setting_values.setLayout(QVBoxLayout()) 217 218 # Create UI for the device. 219 device = "auto" 220 device_options = ["auto"] + _available_devices() 221 222 self.device_dropdown, layout = self._add_choice_param("device", device, device_options) 223 setting_values.layout().addLayout(layout) 224 225 # Create UI for the tile shape. 226 self.default_tiling = get_default_tiling() 227 self.tiling = copy.deepcopy(self.default_tiling) 228 self.tiling["tile"]["x"], self.tiling["tile"]["y"], self.tiling["tile"]["z"], layout = self._add_shape_param( 229 ("tile_x", "tile_y", "tile_z"), 230 (self.default_tiling["tile"]["x"], self.default_tiling["tile"]["y"], self.default_tiling["tile"]["z"]), 231 min_val=0, max_val=2048, step=16, 232 # tooltip=get_tooltip("embedding", "tiling") 233 ) 234 setting_values.layout().addLayout(layout) 235 236 # Create UI for the halo. 237 self.tiling["halo"]["x"], self.tiling["halo"]["y"], self.tiling["halo"]["z"], layout = self._add_shape_param( 238 ("halo_x", "halo_y", "halo_z"), 239 (self.default_tiling["halo"]["x"], self.default_tiling["halo"]["y"], self.default_tiling["halo"]["z"]), 240 min_val=0, max_val=512, 241 # tooltip=get_tooltip("embedding", "halo") 242 ) 243 setting_values.layout().addLayout(layout) 244 245 # Read voxel size from layer metadata. 246 self.voxel_size_param, layout = self._add_float_param( 247 "voxel_size", 0.0, min_val=0.0, max_val=100.0, 248 ) 249 setting_values.layout().addLayout(layout) 250 251 self.checkpoint_param, layout = self._add_string_param( 252 name="checkpoint", value="", title="Load Custom Model", 253 placeholder="path/to/checkpoint.pt", 254 ) 255 setting_values.layout().addLayout(layout) 256 257 # Add selection UI for additional segmentation, which some models require for inference or postproc. 258 self.extra_seg_selector_name = "Extra Segmentation" 259 self.extra_selector_widget = self._create_layer_selector(self.extra_seg_selector_name, layer_type="Labels") 260 setting_values.layout().addWidget(self.extra_selector_widget) 261 262 settings = self._make_collapsible(widget=setting_values, title="Advanced Settings") 263 return settings
94class SegmentationWidget(BaseWidget): 95 def __init__(self): 96 super().__init__() 97 98 self.viewer = napari.current_viewer() 99 layout = QVBoxLayout() 100 self.tiling = {} 101 102 # Create the image selection dropdown. 103 self.image_selector_name = "Image data" 104 self.image_selector_widget = self._create_layer_selector(self.image_selector_name, layer_type="Image") 105 106 # Create buttons and widgets. 107 self.predict_button = QPushButton("Run Segmentation") 108 self.predict_button.clicked.connect(self.on_predict) 109 self.model_selector_widget = self.load_model_widget() 110 self.settings = self._create_settings_widget() 111 112 # Add the widgets to the layout. 113 layout.addWidget(self.image_selector_widget) 114 layout.addWidget(self.model_selector_widget) 115 layout.addWidget(self.settings) 116 layout.addWidget(self.predict_button) 117 118 self.setLayout(layout) 119 120 def load_model_widget(self): 121 model_widget = QWidget() 122 title_label = QLabel("Select Model:") 123 124 # Exclude the models that are only offered through the CLI and not in the plugin. 125 model_list = set(_get_model_registry().urls.keys()) 126 # These are the models exlcuded due to their specificity and to keep the menu simple. 127 # TODO: we should at some point update the logic here, to make it easier to support further models 128 # without cluttering the UI. 129 excluded_models = ["vesicles_2d_maus"] 130 model_list = [name for name in model_list if name not in excluded_models] 131 132 models = ["- choose -"] + model_list 133 self.model_selector = QComboBox() 134 self.model_selector.addItems(models) 135 # Create a layout and add the title label and combo box 136 layout = QVBoxLayout() 137 layout.addWidget(title_label) 138 layout.addWidget(self.model_selector) 139 140 # Set layout on the model widget 141 model_widget.setLayout(layout) 142 return model_widget 143 144 def on_predict(self): 145 # Get the model and postprocessing settings. 146 model_type = self.model_selector.currentText() 147 custom_model_path = self.checkpoint_param.text() 148 if model_type == "- choose -": 149 show_info("INFO: Please choose a model.") 150 return 151 152 device = get_device(self.device_dropdown.currentText()) 153 154 # Load the model. Override if user chose custom model. 155 rescale_input = True 156 if custom_model_path: 157 model = _load_custom_model(custom_model_path, device) 158 rescale_input = False 159 if model: 160 show_info(f"INFO: Using custom model from path: {custom_model_path}") 161 else: 162 show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}") 163 return 164 else: 165 model = get_model(model_type, device) 166 167 # Get the image data. 168 image = self._get_layer_selector_data(self.image_selector_name) 169 if image is None: 170 show_info("INFO: Please choose an image.") 171 return 172 173 # Get the current tiling. 174 self.tiling = _get_current_tiling(self.tiling, self.default_tiling, model_type) 175 176 # Get the voxel size. 177 metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True) 178 voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False) 179 180 # Determine the scaling based on the voxel size. 181 scale = None 182 if voxel_size and rescale_input: 183 # Calculate scale so voxel_size is the same as in training. 184 scale = compute_scale_from_voxel_size(voxel_size, model_type) 185 scale_info = list(map(lambda x: np.round(x, 2), scale)) 186 show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.") 187 188 # Some models require an additional segmentation for inference or postprocessing. 189 # For these models we read out the 'Extra Segmentation' widget. 190 if model_type == "ribbon": # Currently only the ribbon model needs the extra seg. 191 extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name) 192 resolution = tuple(voxel_size[ax] for ax in "zyx") 193 kwargs = {"extra_segmentation": extra_seg, "resolution": resolution, "min_membrane_size": 50_000} 194 elif model_type.startswith("cristae"): # Cristae model expects 2 3D volumes 195 kwargs = { 196 "extra_segmentation": self._get_layer_selector_data(self.extra_seg_selector_name), 197 "with_channels": True, 198 "channels_to_standardize": [0] 199 } 200 else: 201 kwargs = {} 202 segmentation = run_segmentation( 203 image, model=model, model_type=model_type, tiling=self.tiling, scale=scale, **kwargs 204 ) 205 206 # Add the segmentation layer(s). 207 if isinstance(segmentation, dict): 208 for name, seg in segmentation.items(): 209 self.viewer.add_labels(seg, name=name, metadata=metadata) 210 else: 211 self.viewer.add_labels(segmentation, name=f"{model_type}", metadata=metadata) 212 show_info(f"INFO: Segmentation of {model_type} added to layers.") 213 214 def _create_settings_widget(self): 215 setting_values = QWidget() 216 # setting_values.setToolTip(get_tooltip("embedding", "settings")) 217 setting_values.setLayout(QVBoxLayout()) 218 219 # Create UI for the device. 220 device = "auto" 221 device_options = ["auto"] + _available_devices() 222 223 self.device_dropdown, layout = self._add_choice_param("device", device, device_options) 224 setting_values.layout().addLayout(layout) 225 226 # Create UI for the tile shape. 227 self.default_tiling = get_default_tiling() 228 self.tiling = copy.deepcopy(self.default_tiling) 229 self.tiling["tile"]["x"], self.tiling["tile"]["y"], self.tiling["tile"]["z"], layout = self._add_shape_param( 230 ("tile_x", "tile_y", "tile_z"), 231 (self.default_tiling["tile"]["x"], self.default_tiling["tile"]["y"], self.default_tiling["tile"]["z"]), 232 min_val=0, max_val=2048, step=16, 233 # tooltip=get_tooltip("embedding", "tiling") 234 ) 235 setting_values.layout().addLayout(layout) 236 237 # Create UI for the halo. 238 self.tiling["halo"]["x"], self.tiling["halo"]["y"], self.tiling["halo"]["z"], layout = self._add_shape_param( 239 ("halo_x", "halo_y", "halo_z"), 240 (self.default_tiling["halo"]["x"], self.default_tiling["halo"]["y"], self.default_tiling["halo"]["z"]), 241 min_val=0, max_val=512, 242 # tooltip=get_tooltip("embedding", "halo") 243 ) 244 setting_values.layout().addLayout(layout) 245 246 # Read voxel size from layer metadata. 247 self.voxel_size_param, layout = self._add_float_param( 248 "voxel_size", 0.0, min_val=0.0, max_val=100.0, 249 ) 250 setting_values.layout().addLayout(layout) 251 252 self.checkpoint_param, layout = self._add_string_param( 253 name="checkpoint", value="", title="Load Custom Model", 254 placeholder="path/to/checkpoint.pt", 255 ) 256 setting_values.layout().addLayout(layout) 257 258 # Add selection UI for additional segmentation, which some models require for inference or postproc. 259 self.extra_seg_selector_name = "Extra Segmentation" 260 self.extra_selector_widget = self._create_layer_selector(self.extra_seg_selector_name, layer_type="Labels") 261 setting_values.layout().addWidget(self.extra_selector_widget) 262 263 settings = self._make_collapsible(widget=setting_values, title="Advanced Settings") 264 return settings
QWidget(parent: Optional[QWidget] = None, flags: Union[Qt.WindowFlags, Qt.WindowType] = Qt.WindowFlags())
def
load_model_widget(self):
120 def load_model_widget(self): 121 model_widget = QWidget() 122 title_label = QLabel("Select Model:") 123 124 # Exclude the models that are only offered through the CLI and not in the plugin. 125 model_list = set(_get_model_registry().urls.keys()) 126 # These are the models exlcuded due to their specificity and to keep the menu simple. 127 # TODO: we should at some point update the logic here, to make it easier to support further models 128 # without cluttering the UI. 129 excluded_models = ["vesicles_2d_maus"] 130 model_list = [name for name in model_list if name not in excluded_models] 131 132 models = ["- choose -"] + model_list 133 self.model_selector = QComboBox() 134 self.model_selector.addItems(models) 135 # Create a layout and add the title label and combo box 136 layout = QVBoxLayout() 137 layout.addWidget(title_label) 138 layout.addWidget(self.model_selector) 139 140 # Set layout on the model widget 141 model_widget.setLayout(layout) 142 return model_widget
def
on_predict(self):
144 def on_predict(self): 145 # Get the model and postprocessing settings. 146 model_type = self.model_selector.currentText() 147 custom_model_path = self.checkpoint_param.text() 148 if model_type == "- choose -": 149 show_info("INFO: Please choose a model.") 150 return 151 152 device = get_device(self.device_dropdown.currentText()) 153 154 # Load the model. Override if user chose custom model. 155 rescale_input = True 156 if custom_model_path: 157 model = _load_custom_model(custom_model_path, device) 158 rescale_input = False 159 if model: 160 show_info(f"INFO: Using custom model from path: {custom_model_path}") 161 else: 162 show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}") 163 return 164 else: 165 model = get_model(model_type, device) 166 167 # Get the image data. 168 image = self._get_layer_selector_data(self.image_selector_name) 169 if image is None: 170 show_info("INFO: Please choose an image.") 171 return 172 173 # Get the current tiling. 174 self.tiling = _get_current_tiling(self.tiling, self.default_tiling, model_type) 175 176 # Get the voxel size. 177 metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True) 178 voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False) 179 180 # Determine the scaling based on the voxel size. 181 scale = None 182 if voxel_size and rescale_input: 183 # Calculate scale so voxel_size is the same as in training. 184 scale = compute_scale_from_voxel_size(voxel_size, model_type) 185 scale_info = list(map(lambda x: np.round(x, 2), scale)) 186 show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.") 187 188 # Some models require an additional segmentation for inference or postprocessing. 189 # For these models we read out the 'Extra Segmentation' widget. 190 if model_type == "ribbon": # Currently only the ribbon model needs the extra seg. 191 extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name) 192 resolution = tuple(voxel_size[ax] for ax in "zyx") 193 kwargs = {"extra_segmentation": extra_seg, "resolution": resolution, "min_membrane_size": 50_000} 194 elif model_type.startswith("cristae"): # Cristae model expects 2 3D volumes 195 kwargs = { 196 "extra_segmentation": self._get_layer_selector_data(self.extra_seg_selector_name), 197 "with_channels": True, 198 "channels_to_standardize": [0] 199 } 200 else: 201 kwargs = {} 202 segmentation = run_segmentation( 203 image, model=model, model_type=model_type, tiling=self.tiling, scale=scale, **kwargs 204 ) 205 206 # Add the segmentation layer(s). 207 if isinstance(segmentation, dict): 208 for name, seg in segmentation.items(): 209 self.viewer.add_labels(seg, name=name, metadata=metadata) 210 else: 211 self.viewer.add_labels(segmentation, name=f"{model_type}", metadata=metadata) 212 show_info(f"INFO: Segmentation of {model_type} added to layers.")