synapse_net.inference.inference
import os
from typing import Dict, List, Optional, Union

import torch
import numpy as np
import pooch

from .active_zone import segment_active_zone
from .compartments import segment_compartments
from .mitochondria import segment_mitochondria
from .ribbon_synapse import segment_ribbon_synapse_structures
from .vesicles import segment_vesicles
from .cristae import segment_cristae
from .util import get_device
from ..file_utils import get_cache_dir


#
# Functions to access SynapseNet's pretrained models.
#


def _get_model_registry():
    """Create the pooch registry for all pretrained SynapseNet models.

    The registry maps each model name to its SHA256 checksum and download URL.
    Models are cached under '<cache_dir>/models' and only downloaded once.
    """
    # SHA256 checksums; pooch verifies each downloaded file against these.
    registry = {
        "active_zone": "c23652a8fe06daa113546af6d3200c4c1dcc79917056c6ed7357b8c93548372a",
        "compartments": "527983720f9eb215c45c4f4493851fd6551810361eda7b79f185a0d304274ee1",
        "mitochondria": "24625018a5968b36f39fa9d73b121a32e8f66d0f2c0540d3df2e1e39b3d58186",
        "mitochondria2": "0ec4c48fb67ebcdf1c2a86710e1d5e40519758b867e49a6999d155e8eb15d459",
        "cristae": "f96c90484f4ea92ac0515a06e389cc117580f02c2aacdc44b5828820cf38c3c3",
        "cristae2": "0864945698862df043adc51c0034289a579b0622a61164e5ebd00a24ee25d075",
        "cristae3": "5cb8699487bd21204071cbfb784a7f5c2bb6ab5f347a9d02a913aa27ae70eca4",
        "ribbon": "7c947f0ddfabe51a41d9d05c0a6ca7d6b238f43df2af8fffed5552d09bb075a9",
        "vesicles_2d": "eb0b74f7000a0e6a25b626078e76a9452019f2d1ea6cf2033073656f4f055df1",
        "vesicles_3d": "b329ec1f57f305099c984fbb3d7f6ae4b0ff51ec2fa0fa586df52dad6b84cf29",
        "vesicles_cryo": "782f5a21c3cda82c4e4eaeccc754774d5aaed5929f8496eb018aad7daf91661b",
        # Additional models that are only available in the CLI, not in the plugin model selection.
        "vesicles_2d_maus": "01506895df6343fc33ffc9c9eb3f975bf42eb4eaaaf4848bac83b57f1b46e460",
        "vesicles_3d_endbulb": "8582c7e3e5f16ef2bf34d6f9e34644862ca3c76835c9e7d44475c9dd7891d228",
        "vesicles_3d_innerear": "924f0f7cfb648a3a6931c1d48d8b1fdc6c0c0d2cb3330fe2cae49d13e7c3b69d",
    }
    urls = {
        "active_zone": "https://owncloud.gwdg.de/index.php/s/wpea9FH9waG4zJd/download",
        "compartments": "https://owncloud.gwdg.de/index.php/s/DnFDeTmDDmZrDDX/download",
        "mitochondria": "https://owncloud.gwdg.de/index.php/s/1T542uvzfuruahD/download",
        "mitochondria2": "https://owncloud.gwdg.de/index.php/s/jivHzhpsqXN3PoH/download",
        "cristae": "https://owncloud.gwdg.de/index.php/s/Df7OUOyQ1Kc2eEO/download",
        "cristae2": "https://owncloud.gwdg.de/index.php/s/qe0R5pRgH2m0pQ5/download",
        "cristae3": "https://owncloud.gwdg.de/index.php/s/51X1et8d7pOfdhU/download",
        "ribbon": "https://owncloud.gwdg.de/index.php/s/S3b5l0liPP1XPYA/download",
        "vesicles_2d": "https://owncloud.gwdg.de/index.php/s/d72QIvdX6LsgXip/download",
        "vesicles_3d": "https://owncloud.gwdg.de/index.php/s/A425mkAOSqePDhx/download",
        "vesicles_cryo": "https://owncloud.gwdg.de/index.php/s/e2lVdxjCJuZkLJm/download",
        # Additional models that are only available in the CLI, not in the plugin model selection.
        "vesicles_2d_maus": "https://owncloud.gwdg.de/index.php/s/sZ8woLr0zs5zOpv/download",
        "vesicles_3d_endbulb": "https://owncloud.gwdg.de/index.php/s/16tmnWrEDpYIMzU/download",
        "vesicles_3d_innerear": "https://owncloud.gwdg.de/index.php/s/UFUCYivsCxrqISX/download",
    }
    cache_dir = get_cache_dir()
    # base_url is empty because every entry carries its own full download URL.
    models = pooch.create(
        path=os.path.join(cache_dir, "models"),
        base_url="",
        registry=registry,
        urls=urls,
    )
    return models


def get_available_models() -> List[str]:
    """Get the names of all available pretrained models.

    Returns:
        The list of available model names.
    """
    model_registry = _get_model_registry()
    return list(model_registry.urls.keys())


def get_model_path(model_type: str) -> str:
    """Get the local path to a pretrained model.

    The model is downloaded on first access and cached; subsequent calls
    return the cached, checksum-verified file.

    Args:
        model_type: The model type.

    Returns:
        The local path to the model.
    """
    model_registry = _get_model_registry()
    model_path = model_registry.fetch(model_type)
    return model_path


def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
    """Get the model for a specific segmentation type.

    Args:
        model_type: The model for one of the following segmentation tasks:
            'vesicles_3d', 'active_zone', 'compartments', 'mitochondria', 'ribbon', 'vesicles_2d', 'vesicles_cryo'.
            See `get_available_models` for the full list of valid model names.
        device: The device to use. By default the device is picked by `get_device`.

    Returns:
        The model.
    """
    if device is None:
        device = get_device(device)
    model_path = get_model_path(model_type)
    # weights_only=False: the checkpoint stores the full pickled model, not a
    # state dict. This is acceptable here because the file comes from
    # SynapseNet's own registry and its checksum is verified by pooch.
    model = torch.load(model_path, weights_only=False)
    model.to(device)
    return model


#
# Functions for training resolution / voxel size.
#


def get_model_training_resolution(model_type: str) -> Dict[str, float]:
    """Get the average resolution / voxel size of the training data for a given pretrained model.

    Args:
        model_type: The name of the pretrained model.

    Returns:
        Mapping of axis (x, y, z) to the voxel size (in nm) of that axis.
    """
    resolutions = {
        "active_zone": {"x": 1.38, "y": 1.38, "z": 1.38},
        "compartments": {"x": 3.47, "y": 3.47, "z": 3.47},
        "mitochondria": {"x": 2.07, "y": 2.07, "z": 2.07},
        "mitochondria2": {"x": 2.87, "y": 2.87, "z": 2.87},
        "cristae": {"x": 1.44, "y": 1.44, "z": 1.44},
        "cristae2": {"x": 1.44, "y": 1.44, "z": 1.44},
        "cristae3": {"x": 1.44, "y": 1.44, "z": 1.44},
        "ribbon": {"x": 1.188, "y": 1.188, "z": 1.188},
        "vesicles_2d": {"x": 1.35, "y": 1.35},
        "vesicles_3d": {"x": 1.35, "y": 1.35, "z": 1.35},
        "vesicles_cryo": {"x": 1.35, "y": 1.35, "z": 0.88},
        # TODO add the correct resolutions, these are the resolutions of the source models.
        "vesicles_2d_maus": {"x": 1.35, "y": 1.35},
        "vesicles_3d_endbulb": {"x": 1.35, "y": 1.35, "z": 1.35},
        "vesicles_3d_innerear": {"x": 1.35, "y": 1.35, "z": 1.35},
    }
    return resolutions[model_type]


def compute_scale_from_voxel_size(
    voxel_size: Dict[str, float],
    model_type: str
) -> List[float]:
    """Compute the appropriate scale factor for inference with a given pretrained model.

    Args:
        voxel_size: The voxel size of the data for inference, mapping axis name to size (in nm).
        model_type: The name of the pretrained model.

    Returns:
        The scale factor, as a list in xy(z) order.
    """
    training_voxel_size = get_model_training_resolution(model_type)
    # NOTE(review): the scale list is built in xy(z) order; an earlier docstring
    # claimed zyx order — confirm which order consumers expect. For the common
    # isotropic voxel sizes the distinction does not matter.
    scale = [
        voxel_size["x"] / training_voxel_size["x"],
        voxel_size["y"] / training_voxel_size["y"],
    ]
    # Only append a z-scale when both the data and the model training data are 3d.
    if len(voxel_size) == 3 and len(training_voxel_size) == 3:
        scale.append(
            voxel_size["z"] / training_voxel_size["z"]
        )
    return scale


#
# Convenience functions for segmentation.
172# 173 174 175def _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons, resolution, min_membrane_size): 176 from synapse_net.inference.postprocessing import ( 177 segment_ribbon, segment_presynaptic_density, segment_membrane_distance_based, 178 ) 179 180 ribbon = segment_ribbon( 181 predictions["ribbon"], vesicles, n_slices_exclude=n_slices_exclude, n_ribbons=n_ribbons, 182 max_vesicle_distance=40, 183 ) 184 PD = segment_presynaptic_density( 185 predictions["PD"], ribbon, n_slices_exclude=n_slices_exclude, max_distance_to_ribbon=40, 186 ) 187 ref_segmentation = PD if PD.sum() > 0 else ribbon 188 membrane = segment_membrane_distance_based( 189 predictions["membrane"], ref_segmentation, max_distance=500, n_slices_exclude=n_slices_exclude, 190 resolution=resolution, min_size=min_membrane_size, 191 ) 192 193 segmentations = {"ribbon": ribbon, "PD": PD, "membrane": membrane} 194 return segmentations 195 196 197def _segment_ribbon_AZ(image, model, tiling, scale, verbose, return_predictions=False, **kwargs): 198 # Parse additional keyword arguments from the kwargs. 199 vesicles = kwargs.pop("extra_segmentation") 200 threshold = kwargs.pop("threshold", 0.5) 201 n_slices_exclude = kwargs.pop("n_slices_exclude", 20) 202 n_ribbons = kwargs.pop("n_slices_exclude", 1) 203 resolution = kwargs.pop("resolution", None) 204 min_membrane_size = kwargs.pop("min_membrane_size", 0) 205 206 predictions = segment_ribbon_synapse_structures( 207 image, model=model, tiling=tiling, scale=scale, verbose=verbose, threshold=threshold, **kwargs 208 ) 209 210 # Otherwise, just return the predictions. 211 if vesicles is None: 212 if verbose: 213 print("Vesicle segmentation was not passed, WILL NOT run post-processing.") 214 segmentations = predictions 215 216 # If the vesicles were passed then run additional post-processing. 
217 else: 218 if verbose: 219 print("Vesicle segmentation was passed, WILL run post-processing.") 220 segmentations = _ribbon_AZ_postprocessing( 221 predictions, vesicles, n_slices_exclude, n_ribbons, resolution, min_membrane_size 222 ) 223 224 if return_predictions: 225 return segmentations, predictions 226 return segmentations 227 228 229def run_segmentation( 230 image: np.ndarray, 231 model: torch.nn.Module, 232 model_type: str, 233 tiling: Optional[Dict[str, Dict[str, int]]] = None, 234 scale: Optional[List[float]] = None, 235 verbose: bool = False, 236 **kwargs, 237) -> np.ndarray | Dict[str, np.ndarray]: 238 """Run synaptic structure segmentation. 239 240 Args: 241 image: The input image or image volume. 242 model: The segmentation model. 243 model_type: The model type. This will determine which segmentation post-processing is used. 244 tiling: The tiling settings for inference. 245 scale: A scale factor for resizing the input before applying the model. 246 The output will be scaled back to the initial size. 247 verbose: Whether to print detailed information about the prediction and segmentation. 248 kwargs: Optional parameters for the segmentation function. 249 250 Returns: 251 The segmentation. For models that return multiple segmentations, this function returns a dictionary. 
252 """ 253 if model_type.startswith("vesicles"): 254 segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) 255 elif model_type == "mitochondria" or model_type == "mitochondria2": 256 segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) 257 elif model_type == "active_zone": 258 segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) 259 elif model_type == "compartments": 260 segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) 261 elif model_type == "ribbon": 262 segmentation = _segment_ribbon_AZ(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) 263 elif model_type == "cristae" or model_type == "cristae2" or model_type == "cristae3": 264 segmentation = segment_cristae(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) 265 else: 266 raise ValueError(f"Unknown model type: {model_type}") 267 return segmentation
def
get_available_models() -> List[str]:
69def get_available_models() -> List[str]: 70 """Get the names of all available pretrained models. 71 72 Returns: 73 The list of available model names. 74 """ 75 model_registry = _get_model_registry() 76 return list(model_registry.urls.keys())
Get the names of all available pretrained models.
Returns:
The list of available model names.
def
get_model_path(model_type: str) -> str:
79def get_model_path(model_type: str) -> str: 80 """Get the local path to a pretrained model. 81 82 Args: 83 The model type. 84 85 Returns: 86 The local path to the model. 87 """ 88 model_registry = _get_model_registry() 89 model_path = model_registry.fetch(model_type) 90 return model_path
Get the local path to a pretrained model.
Arguments:
- model_type: The model type.
Returns:
The local path to the model.
def
get_model( model_type: str, device: Union[torch.device, str, NoneType] = None) -> torch.nn.modules.module.Module:
93def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module: 94 """Get the model for a specific segmentation type. 95 96 Args: 97 model_type: The model for one of the following segmentation tasks: 98 'vesicles_3d', 'active_zone', 'compartments', 'mitochondria', 'ribbon', 'vesicles_2d', 'vesicles_cryo'. 99 device: The device to use. 100 101 Returns: 102 The model. 103 """ 104 if device is None: 105 device = get_device(device) 106 model_path = get_model_path(model_type) 107 model = torch.load(model_path, weights_only=False) 108 model.to(device) 109 return model
Get the model for a specific segmentation type.
Arguments:
- model_type: The model for one of the following segmentation tasks: 'vesicles_3d', 'active_zone', 'compartments', 'mitochondria', 'ribbon', 'vesicles_2d', 'vesicles_cryo'.
- device: The device to use.
Returns:
The model.
def
get_model_training_resolution(model_type: str) -> Dict[str, float]:
117def get_model_training_resolution(model_type: str) -> Dict[str, float]: 118 """Get the average resolution / voxel size of the training data for a given pretrained model. 119 120 Args: 121 model_type: The name of the pretrained model. 122 123 Returns: 124 Mapping of axis (x, y, z) to the voxel size (in nm) of that axis. 125 """ 126 resolutions = { 127 "active_zone": {"x": 1.38, "y": 1.38, "z": 1.38}, 128 "compartments": {"x": 3.47, "y": 3.47, "z": 3.47}, 129 "mitochondria": {"x": 2.07, "y": 2.07, "z": 2.07}, 130 "mitochondria2": {"x": 2.87, "y": 2.87, "z": 2.87}, 131 "cristae": {"x": 1.44, "y": 1.44, "z": 1.44}, 132 "cristae2": {"x": 1.44, "y": 1.44, "z": 1.44}, 133 "cristae3": {"x": 1.44, "y": 1.44, "z": 1.44}, 134 "ribbon": {"x": 1.188, "y": 1.188, "z": 1.188}, 135 "vesicles_2d": {"x": 1.35, "y": 1.35}, 136 "vesicles_3d": {"x": 1.35, "y": 1.35, "z": 1.35}, 137 "vesicles_cryo": {"x": 1.35, "y": 1.35, "z": 0.88}, 138 # TODO add the correct resolutions, these are the resolutions of the source models. 139 "vesicles_2d_maus": {"x": 1.35, "y": 1.35}, 140 "vesicles_3d_endbulb": {"x": 1.35, "y": 1.35, "z": 1.35}, 141 "vesicles_3d_innerear": {"x": 1.35, "y": 1.35, "z": 1.35}, 142 } 143 return resolutions[model_type]
Get the average resolution / voxel size of the training data for a given pretrained model.
Arguments:
- model_type: The name of the pretrained model.
Returns:
Mapping of axis (x, y, z) to the voxel size (in nm) of that axis.
def
compute_scale_from_voxel_size(voxel_size: Dict[str, float], model_type: str) -> List[float]:
146def compute_scale_from_voxel_size( 147 voxel_size: Dict[str, float], 148 model_type: str 149) -> List[float]: 150 """Compute the appropriate scale factor for inference with a given pretrained model. 151 152 Args: 153 voxel_size: The voxel size of the data for inference. 154 model_type: The name of the pretrained model. 155 156 Returns: 157 The scale factor, as a list in zyx order. 158 """ 159 training_voxel_size = get_model_training_resolution(model_type) 160 scale = [ 161 voxel_size["x"] / training_voxel_size["x"], 162 voxel_size["y"] / training_voxel_size["y"], 163 ] 164 if len(voxel_size) == 3 and len(training_voxel_size) == 3: 165 scale.append( 166 voxel_size["z"] / training_voxel_size["z"] 167 ) 168 return scale
Compute the appropriate scale factor for inference with a given pretrained model.
Arguments:
- voxel_size: The voxel size of the data for inference.
- model_type: The name of the pretrained model.
Returns:
The scale factor, as a list in zyx order.
def
run_segmentation( image: numpy.ndarray, model: torch.nn.modules.module.Module, model_type: str, tiling: Optional[Dict[str, Dict[str, int]]] = None, scale: Optional[List[float]] = None, verbose: bool = False, **kwargs) -> Union[numpy.ndarray, Dict[str, numpy.ndarray]]:
230def run_segmentation( 231 image: np.ndarray, 232 model: torch.nn.Module, 233 model_type: str, 234 tiling: Optional[Dict[str, Dict[str, int]]] = None, 235 scale: Optional[List[float]] = None, 236 verbose: bool = False, 237 **kwargs, 238) -> np.ndarray | Dict[str, np.ndarray]: 239 """Run synaptic structure segmentation. 240 241 Args: 242 image: The input image or image volume. 243 model: The segmentation model. 244 model_type: The model type. This will determine which segmentation post-processing is used. 245 tiling: The tiling settings for inference. 246 scale: A scale factor for resizing the input before applying the model. 247 The output will be scaled back to the initial size. 248 verbose: Whether to print detailed information about the prediction and segmentation. 249 kwargs: Optional parameters for the segmentation function. 250 251 Returns: 252 The segmentation. For models that return multiple segmentations, this function returns a dictionary. 253 """ 254 if model_type.startswith("vesicles"): 255 segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) 256 elif model_type == "mitochondria" or model_type == "mitochondria2": 257 segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) 258 elif model_type == "active_zone": 259 segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) 260 elif model_type == "compartments": 261 segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) 262 elif model_type == "ribbon": 263 segmentation = _segment_ribbon_AZ(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) 264 elif model_type == "cristae" or model_type == "cristae2" or model_type == "cristae3": 265 segmentation = segment_cristae(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) 266 else: 267 raise 
ValueError(f"Unknown model type: {model_type}") 268 return segmentation
Run synaptic structure segmentation.
Arguments:
- image: The input image or image volume.
- model: The segmentation model.
- model_type: The model type. This will determine which segmentation post-processing is used.
- tiling: The tiling settings for inference.
- scale: A scale factor for resizing the input before applying the model. The output will be scaled back to the initial size.
- verbose: Whether to print detailed information about the prediction and segmentation.
- kwargs: Optional parameters for the segmentation function.
Returns:
The segmentation. For models that return multiple segmentations, this function returns a dictionary.