synapse_net.tools.segmentation_widget
import copy
import re
from typing import Optional, Union

import napari
import numpy as np
import torch

from napari.utils.notifications import show_info
from qtpy.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QComboBox

from .base_widget import BaseWidget
from ..inference.inference import _get_model_registry, get_model, run_segmentation, compute_scale_from_voxel_size
from ..inference.util import get_default_tiling, get_device


def _load_custom_model(
    model_path: str, device: Optional[Union[str, torch.device]] = None
) -> Optional[torch.nn.Module]:
    """Load a model from a user-supplied checkpoint path.

    Args:
        model_path: Path to the checkpoint. It is passed through `_clean_filepath` first.
        device: Device to map the checkpoint to. If None, `get_device` picks one.

    Returns:
        The loaded object, or None if loading raised an exception.
    """
    model_path = _clean_filepath(model_path)
    if device is None:
        device = get_device(device)
    try:
        # NOTE(security): weights_only=False unpickles arbitrary Python objects, so only
        # checkpoints from a trusted source should be loaded through this path.
        model = torch.load(model_path, map_location=torch.device(device), weights_only=False)
    except Exception as e:
        # Best effort: report the problem and let the caller handle the missing model.
        print(e)
        print("model path", model_path)
        return None
    return model


def _available_devices():
    """Return the devices among "cuda", "mps" and "cpu" that `get_device` accepts.

    Returns:
        A list with the return value of `get_device` for every candidate that did
        not raise a RuntimeError.
    """
    available_devices = []
    for candidate in ["cuda", "mps", "cpu"]:
        try:
            device = get_device(candidate)
        except RuntimeError:
            # This device is not usable here, skip it.
            continue
        available_devices.append(device)
    return available_devices


def _get_current_tiling(tiling: dict, default_tiling: dict, model_type: str):
    """Resolve the tiling settings to plain integers and adapt them for 2d models.

    The input dictionary is not modified, so it can keep holding the Qt objects
    and be resolved again on a later call.

    Args:
        tiling: Nested dictionary (e.g. "tile"/"halo" -> "x"/"y"/"z") whose values
            are either ints or Qt objects exposing `value()`.
        default_tiling: The default tiling, with the same layout and int values.
        model_type: The name of the model; models with "2d" in the name get a z-extent of 1.

    Returns:
        A new nested dictionary with integer tiling values.
    """
    # Read the values from the qt objects into a fresh dictionary.
    current = {
        group: {
            axis: entry if isinstance(entry, int) else entry.value()
            for axis, entry in entries.items()
        }
        for group, entries in tiling.items()
    }

    is_2d = "2d" in model_type
    if current == default_tiling:
        # The user did not change tiling/halo.
        if is_2d:
            # If it is a 2d model expand x, y and set z to 1.
            current = {
                "tile": {"x": 512, "y": 512, "z": 1},
                "halo": {"x": 64, "y": 64, "z": 1},
            }
    elif is_2d:
        # If it is a 2d model keep the user's x, y and set z to 1.
        current["tile"]["z"] = 1
        current["halo"]["z"] = 1

    return current


def _clean_filepath(filepath):
    """Clean a given filepath.

    The cleaning consists of:
    - Stripping the 'file://' prefix if present
    - Removing escape sequences (a backslash together with the character after it)
    - Removing newline and carriage return characters

    Args:
        filepath (str): The original filepath

    Returns:
        str: The cleaned filepath
    """
    # Remove 'file://' prefix if present
    if filepath.startswith("file://"):
        filepath = filepath[7:]

    # Remove escape sequences and newlines
    # NOTE(review): this also drops the character following each backslash, which would
    # mangle paths that use backslashes as separators - confirm that this is intended.
    filepath = re.sub(r'\\.', '', filepath)
    filepath = filepath.replace('\n', '').replace('\r', '')

    return filepath


class SegmentationWidget(BaseWidget):
    """Widget for running a segmentation model on an image layer of the napari viewer."""

    def __init__(self):
        super().__init__()

        self.viewer = napari.current_viewer()
        layout = QVBoxLayout()
        # Placeholder; replaced with the tile/halo entries in _create_settings_widget.
        self.tiling = {}

        # Create the image selection dropdown.
        self.image_selector_name = "Image data"
        self.image_selector_widget = self._create_layer_selector(self.image_selector_name, layer_type="Image")

        # Create buttons and widgets.
        self.predict_button = QPushButton("Run Segmentation")
        self.predict_button.clicked.connect(self.on_predict)
        self.model_selector_widget = self.load_model_widget()
        self.settings = self._create_settings_widget()

        # Add the widgets to the layout.
        layout.addWidget(self.image_selector_widget)
        layout.addWidget(self.model_selector_widget)
        layout.addWidget(self.settings)
        layout.addWidget(self.predict_button)

        self.setLayout(layout)

    def load_model_widget(self):
        """Create the model selection widget and store its combo box in `self.model_selector`.

        Returns:
            A widget holding a title label and a combo box that lists "- choose -"
            followed by the keys of the model registry's `urls`.
        """
        model_widget = QWidget()
        title_label = QLabel("Select Model:")

        models = ["- choose -"] + list(_get_model_registry().urls.keys())
        self.model_selector = QComboBox()
        self.model_selector.addItems(models)
        # Create a layout and add the title label and combo box
        layout = QVBoxLayout()
        layout.addWidget(title_label)
        layout.addWidget(self.model_selector)

        # Set layout on the model widget
        model_widget.setLayout(layout)
        return model_widget

    def on_predict(self):
        """Run the selected model on the selected image and add the result to the viewer.

        Shows a notification and returns early if no model or no image is selected,
        or if the custom checkpoint could not be loaded.
        """
        # Get the model and postprocessing settings.
        model_type = self.model_selector.currentText()
        custom_model_path = self.checkpoint_param.text()
        # text() is expected to return a string ("" when nothing was entered), so test
        # for emptiness; a comparison with None would never match.
        if model_type == "- choose -" and not custom_model_path:
            show_info("INFO: Please choose a model.")
            return

        # Get the image data. This is checked before the model is loaded so that a
        # missing image does not cost a model load.
        image = self._get_layer_selector_data(self.image_selector_name)
        if image is None:
            show_info("INFO: Please choose an image.")
            return

        device = get_device(self.device_dropdown.currentText())

        # Load the model. Override if user chose custom model
        if custom_model_path:
            model = _load_custom_model(custom_model_path, device)
            # Failure is signalled by None; do not rely on the truthiness of the model object.
            if model is None:
                show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}")
                return
            show_info(f"INFO: Using custom model from path: {custom_model_path}")
            model_type = "custom"
        else:
            model = get_model(model_type, device)

        # Get the current tiling. It is resolved from a per-group copy and kept in a local
        # variable, so self.tiling keeps its Qt objects and later runs read their current values.
        tiling = _get_current_tiling(
            {group: dict(entries) for group, entries in self.tiling.items()},
            self.default_tiling, model_type,
        )

        # Get the voxel size.
        metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
        voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False)

        # Determine the scaling based on the voxel size.
        scale = None
        if voxel_size:
            if model_type == "custom":
                show_info("INFO: The image is not rescaled for a custom model.")
            else:
                # calculate scale so voxel_size is the same as in training
                scale = compute_scale_from_voxel_size(voxel_size, model_type)
                scale_info = [np.round(factor, 2) for factor in scale]
                show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.")

        # Some models require an additional segmentation for inference or postprocessing.
        # For these models we read out the 'Extra Segmentation' widget.
        kwargs = {}
        if model_type == "ribbon":  # Currently only the ribbon model needs the extra seg.
            kwargs["extra_segmentation"] = self._get_layer_selector_data(self.extra_seg_selector_name)
        segmentation = run_segmentation(
            image, model=model, model_type=model_type, tiling=tiling, scale=scale, **kwargs
        )

        # Add the segmentation layer(s).
        if isinstance(segmentation, dict):
            for name, seg in segmentation.items():
                self.viewer.add_labels(seg, name=name, metadata=metadata)
        else:
            self.viewer.add_labels(segmentation, name=f"{model_type}", metadata=metadata)
        show_info(f"INFO: Segmentation of {model_type} added to layers.")

    def _create_settings_widget(self):
        """Create the collapsible "Advanced Settings" section.

        Stores the created inputs in `self.device_dropdown`, `self.tiling`,
        `self.voxel_size_param`, `self.checkpoint_param` and `self.extra_selector_widget`.

        Returns:
            The collapsible widget wrapping all settings.
        """
        setting_values = QWidget()
        # setting_values.setToolTip(get_tooltip("embedding", "settings"))
        setting_values.setLayout(QVBoxLayout())

        # Create UI for the device.
        device = "auto"
        device_options = ["auto"] + _available_devices()

        self.device_dropdown, layout = self._add_choice_param("device", device, device_options)
        setting_values.layout().addLayout(layout)

        # Create UI for the tile shape.
        # self.tiling starts as a copy of the defaults; its entries are then replaced by
        # the objects returned from _add_shape_param.
        self.default_tiling = get_default_tiling()
        self.tiling = copy.deepcopy(self.default_tiling)
        self.tiling["tile"]["x"], self.tiling["tile"]["y"], self.tiling["tile"]["z"], layout = self._add_shape_param(
            ("tile_x", "tile_y", "tile_z"),
            (self.default_tiling["tile"]["x"], self.default_tiling["tile"]["y"], self.default_tiling["tile"]["z"]),
            min_val=0, max_val=2048, step=16,
            # tooltip=get_tooltip("embedding", "tiling")
        )
        setting_values.layout().addLayout(layout)

        # Create UI for the halo.
        self.tiling["halo"]["x"], self.tiling["halo"]["y"], self.tiling["halo"]["z"], layout = self._add_shape_param(
            ("halo_x", "halo_y", "halo_z"),
            (self.default_tiling["halo"]["x"], self.default_tiling["halo"]["y"], self.default_tiling["halo"]["z"]),
            min_val=0, max_val=512,
            # tooltip=get_tooltip("embedding", "halo")
        )
        setting_values.layout().addLayout(layout)

        # Create UI for the voxel size (defaults to 0.0).
        self.voxel_size_param, layout = self._add_float_param(
            "voxel_size", 0.0, min_val=0.0, max_val=100.0,
        )
        setting_values.layout().addLayout(layout)

        # Create UI for the path to a custom model checkpoint.
        self.checkpoint_param, layout = self._add_string_param(
            name="checkpoint", value="", title="Load Custom Model",
            placeholder="path/to/checkpoint.pt",
        )
        setting_values.layout().addLayout(layout)

        # Add selection UI for additional segmentation, which some models require for inference or postproc.
        self.extra_seg_selector_name = "Extra Segmentation"
        self.extra_selector_widget = self._create_layer_selector(self.extra_seg_selector_name, layer_type="Labels")
        setting_values.layout().addWidget(self.extra_selector_widget)

        settings = self._make_collapsible(widget=setting_values, title="Advanced Settings")
        return settings
class SegmentationWidget(BaseWidget):
    """Widget for running a segmentation model on an image layer of the napari viewer."""

    def __init__(self):
        super().__init__()

        self.viewer = napari.current_viewer()
        layout = QVBoxLayout()
        # Placeholder; replaced with the tile/halo entries in _create_settings_widget.
        self.tiling = {}

        # Create the image selection dropdown.
        self.image_selector_name = "Image data"
        self.image_selector_widget = self._create_layer_selector(self.image_selector_name, layer_type="Image")

        # Create buttons and widgets.
        self.predict_button = QPushButton("Run Segmentation")
        self.predict_button.clicked.connect(self.on_predict)
        self.model_selector_widget = self.load_model_widget()
        self.settings = self._create_settings_widget()

        # Add the widgets to the layout.
        layout.addWidget(self.image_selector_widget)
        layout.addWidget(self.model_selector_widget)
        layout.addWidget(self.settings)
        layout.addWidget(self.predict_button)

        self.setLayout(layout)

    def load_model_widget(self):
        """Create the model selection widget and store its combo box in `self.model_selector`.

        Returns:
            A widget holding a title label and a combo box that lists "- choose -"
            followed by the keys of the model registry's `urls`.
        """
        model_widget = QWidget()
        title_label = QLabel("Select Model:")

        models = ["- choose -"] + list(_get_model_registry().urls.keys())
        self.model_selector = QComboBox()
        self.model_selector.addItems(models)
        # Create a layout and add the title label and combo box
        layout = QVBoxLayout()
        layout.addWidget(title_label)
        layout.addWidget(self.model_selector)

        # Set layout on the model widget
        model_widget.setLayout(layout)
        return model_widget

    def on_predict(self):
        """Run the selected model on the selected image and add the result to the viewer.

        Shows a notification and returns early if no model or no image is selected,
        or if the custom checkpoint could not be loaded.
        """
        # Get the model and postprocessing settings.
        model_type = self.model_selector.currentText()
        custom_model_path = self.checkpoint_param.text()
        # text() is expected to return a string ("" when nothing was entered), so test
        # for emptiness; a comparison with None would never match.
        if model_type == "- choose -" and not custom_model_path:
            show_info("INFO: Please choose a model.")
            return

        # Get the image data. This is checked before the model is loaded so that a
        # missing image does not cost a model load.
        image = self._get_layer_selector_data(self.image_selector_name)
        if image is None:
            show_info("INFO: Please choose an image.")
            return

        device = get_device(self.device_dropdown.currentText())

        # Load the model. Override if user chose custom model
        if custom_model_path:
            model = _load_custom_model(custom_model_path, device)
            # Failure is signalled by None; do not rely on the truthiness of the model object.
            if model is None:
                show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}")
                return
            show_info(f"INFO: Using custom model from path: {custom_model_path}")
            model_type = "custom"
        else:
            model = get_model(model_type, device)

        # Get the current tiling. It is resolved from a per-group copy and kept in a local
        # variable, so self.tiling keeps its Qt objects and later runs read their current values.
        tiling = _get_current_tiling(
            {group: dict(entries) for group, entries in self.tiling.items()},
            self.default_tiling, model_type,
        )

        # Get the voxel size.
        metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
        voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False)

        # Determine the scaling based on the voxel size.
        scale = None
        if voxel_size:
            if model_type == "custom":
                show_info("INFO: The image is not rescaled for a custom model.")
            else:
                # calculate scale so voxel_size is the same as in training
                scale = compute_scale_from_voxel_size(voxel_size, model_type)
                scale_info = [np.round(factor, 2) for factor in scale]
                show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.")

        # Some models require an additional segmentation for inference or postprocessing.
        # For these models we read out the 'Extra Segmentation' widget.
        kwargs = {}
        if model_type == "ribbon":  # Currently only the ribbon model needs the extra seg.
            kwargs["extra_segmentation"] = self._get_layer_selector_data(self.extra_seg_selector_name)
        segmentation = run_segmentation(
            image, model=model, model_type=model_type, tiling=tiling, scale=scale, **kwargs
        )

        # Add the segmentation layer(s).
        if isinstance(segmentation, dict):
            for name, seg in segmentation.items():
                self.viewer.add_labels(seg, name=name, metadata=metadata)
        else:
            self.viewer.add_labels(segmentation, name=f"{model_type}", metadata=metadata)
        show_info(f"INFO: Segmentation of {model_type} added to layers.")

    def _create_settings_widget(self):
        """Create the collapsible "Advanced Settings" section.

        Stores the created inputs in `self.device_dropdown`, `self.tiling`,
        `self.voxel_size_param`, `self.checkpoint_param` and `self.extra_selector_widget`.

        Returns:
            The collapsible widget wrapping all settings.
        """
        setting_values = QWidget()
        # setting_values.setToolTip(get_tooltip("embedding", "settings"))
        setting_values.setLayout(QVBoxLayout())

        # Create UI for the device.
        device = "auto"
        device_options = ["auto"] + _available_devices()

        self.device_dropdown, layout = self._add_choice_param("device", device, device_options)
        setting_values.layout().addLayout(layout)

        # Create UI for the tile shape.
        # self.tiling starts as a copy of the defaults; its entries are then replaced by
        # the objects returned from _add_shape_param.
        self.default_tiling = get_default_tiling()
        self.tiling = copy.deepcopy(self.default_tiling)
        self.tiling["tile"]["x"], self.tiling["tile"]["y"], self.tiling["tile"]["z"], layout = self._add_shape_param(
            ("tile_x", "tile_y", "tile_z"),
            (self.default_tiling["tile"]["x"], self.default_tiling["tile"]["y"], self.default_tiling["tile"]["z"]),
            min_val=0, max_val=2048, step=16,
            # tooltip=get_tooltip("embedding", "tiling")
        )
        setting_values.layout().addLayout(layout)

        # Create UI for the halo.
        self.tiling["halo"]["x"], self.tiling["halo"]["y"], self.tiling["halo"]["z"], layout = self._add_shape_param(
            ("halo_x", "halo_y", "halo_z"),
            (self.default_tiling["halo"]["x"], self.default_tiling["halo"]["y"], self.default_tiling["halo"]["z"]),
            min_val=0, max_val=512,
            # tooltip=get_tooltip("embedding", "halo")
        )
        setting_values.layout().addLayout(layout)

        # Create UI for the voxel size (defaults to 0.0).
        self.voxel_size_param, layout = self._add_float_param(
            "voxel_size", 0.0, min_val=0.0, max_val=100.0,
        )
        setting_values.layout().addLayout(layout)

        # Create UI for the path to a custom model checkpoint.
        self.checkpoint_param, layout = self._add_string_param(
            name="checkpoint", value="", title="Load Custom Model",
            placeholder="path/to/checkpoint.pt",
        )
        setting_values.layout().addLayout(layout)

        # Add selection UI for additional segmentation, which some models require for inference or postproc.
        self.extra_seg_selector_name = "Extra Segmentation"
        self.extra_selector_widget = self._create_layer_selector(self.extra_seg_selector_name, layer_type="Labels")
        setting_values.layout().addWidget(self.extra_selector_widget)

        settings = self._make_collapsible(widget=setting_values, title="Advanced Settings")
        return settings
QWidget(parent: typing.Optional[QWidget] = None, flags: Union[Qt.WindowFlags, Qt.WindowType] = Qt.WindowFlags())
def
load_model_widget(self):
def load_model_widget(self):
    """Build the model selection widget.

    The combo box is stored in `self.model_selector`. It offers "- choose -"
    followed by the keys of the model registry's `urls`.

    Returns:
        A widget that stacks the "Select Model:" label on top of the combo box.
    """
    container = QWidget()
    caption = QLabel("Select Model:")

    # The placeholder entry comes first, followed by all registered model names.
    registry_names = list(_get_model_registry().urls.keys())
    self.model_selector = QComboBox()
    self.model_selector.addItems(["- choose -"] + registry_names)

    # Stack the caption and the combo box vertically inside the container.
    column = QVBoxLayout()
    for child in (caption, self.model_selector):
        column.addWidget(child)
    container.setLayout(column)

    return container
def
on_predict(self):
def on_predict(self):
    """Run the selected model on the selected image and add the result to the viewer.

    Shows a notification and returns early if no model or no image is selected,
    or if the custom checkpoint could not be loaded.
    """
    # Get the model and postprocessing settings.
    model_type = self.model_selector.currentText()
    custom_model_path = self.checkpoint_param.text()
    # text() is expected to return a string ("" when nothing was entered), so test
    # for emptiness; a comparison with None would never match.
    if model_type == "- choose -" and not custom_model_path:
        show_info("INFO: Please choose a model.")
        return

    # Get the image data. This is checked before the model is loaded so that a
    # missing image does not cost a model load.
    image = self._get_layer_selector_data(self.image_selector_name)
    if image is None:
        show_info("INFO: Please choose an image.")
        return

    device = get_device(self.device_dropdown.currentText())

    # Load the model. Override if user chose custom model
    if custom_model_path:
        model = _load_custom_model(custom_model_path, device)
        # Failure is signalled by None; do not rely on the truthiness of the model object.
        if model is None:
            show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}")
            return
        show_info(f"INFO: Using custom model from path: {custom_model_path}")
        model_type = "custom"
    else:
        model = get_model(model_type, device)

    # Get the current tiling. It is resolved from a per-group copy and kept in a local
    # variable, so self.tiling keeps its Qt objects and later runs read their current values.
    tiling = _get_current_tiling(
        {group: dict(entries) for group, entries in self.tiling.items()},
        self.default_tiling, model_type,
    )

    # Get the voxel size.
    metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
    voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False)

    # Determine the scaling based on the voxel size.
    scale = None
    if voxel_size:
        if model_type == "custom":
            show_info("INFO: The image is not rescaled for a custom model.")
        else:
            # calculate scale so voxel_size is the same as in training
            scale = compute_scale_from_voxel_size(voxel_size, model_type)
            scale_info = [np.round(factor, 2) for factor in scale]
            show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.")

    # Some models require an additional segmentation for inference or postprocessing.
    # For these models we read out the 'Extra Segmentation' widget.
    kwargs = {}
    if model_type == "ribbon":  # Currently only the ribbon model needs the extra seg.
        kwargs["extra_segmentation"] = self._get_layer_selector_data(self.extra_seg_selector_name)
    segmentation = run_segmentation(
        image, model=model, model_type=model_type, tiling=tiling, scale=scale, **kwargs
    )

    # Add the segmentation layer(s).
    if isinstance(segmentation, dict):
        for name, seg in segmentation.items():
            self.viewer.add_labels(seg, name=name, metadata=metadata)
    else:
        self.viewer.add_labels(segmentation, name=f"{model_type}", metadata=metadata)
    show_info(f"INFO: Segmentation of {model_type} added to layers.")