synapse_net.inference.inference
import os
from typing import Dict, List, Optional, Union

import torch
import numpy as np
import pooch

from .active_zone import segment_active_zone
from .compartments import segment_compartments
from .mitochondria import segment_mitochondria
from .ribbon_synapse import segment_ribbon_synapse_structures
from .vesicles import segment_vesicles
from .cristae import segment_cristae
from .util import get_device
from ..file_utils import get_cache_dir


#
# Functions to access SynapseNet's pretrained models.
#


def _get_model_registry():
    registry = {
        "active_zone": "c23652a8fe06daa113546af6d3200c4c1dcc79917056c6ed7357b8c93548372a",
        "compartments": "527983720f9eb215c45c4f4493851fd6551810361eda7b79f185a0d304274ee1",
        "mitochondria": "24625018a5968b36f39fa9d73b121a32e8f66d0f2c0540d3df2e1e39b3d58186",
        "mitochondria2": "553decafaff4838fff6cc8347f22c8db3dee5bcbeffc34ffaec152f8449af673",
        "cristae": "f96c90484f4ea92ac0515a06e389cc117580f02c2aacdc44b5828820cf38c3c3",
        "ribbon": "7c947f0ddfabe51a41d9d05c0a6ca7d6b238f43df2af8fffed5552d09bb075a9",
        "vesicles_2d": "eb0b74f7000a0e6a25b626078e76a9452019f2d1ea6cf2033073656f4f055df1",
        "vesicles_3d": "b329ec1f57f305099c984fbb3d7f6ae4b0ff51ec2fa0fa586df52dad6b84cf29",
        "vesicles_cryo": "782f5a21c3cda82c4e4eaeccc754774d5aaed5929f8496eb018aad7daf91661b",
        # Additional models that are only available in the CLI, not in the plugin model selection.
        "vesicles_2d_maus": "01506895df6343fc33ffc9c9eb3f975bf42eb4eaaaf4848bac83b57f1b46e460",
        "vesicles_3d_endbulb": "8582c7e3e5f16ef2bf34d6f9e34644862ca3c76835c9e7d44475c9dd7891d228",
        "vesicles_3d_innerear": "924f0f7cfb648a3a6931c1d48d8b1fdc6c0c0d2cb3330fe2cae49d13e7c3b69d",
    }
    urls = {
        "active_zone": "https://owncloud.gwdg.de/index.php/s/wpea9FH9waG4zJd/download",
        "compartments": "https://owncloud.gwdg.de/index.php/s/DnFDeTmDDmZrDDX/download",
        "mitochondria": "https://owncloud.gwdg.de/index.php/s/1T542uvzfuruahD/download",
        "mitochondria2": "https://owncloud.gwdg.de/index.php/s/GZghrXagc54FFXd/download",
        "cristae": "https://owncloud.gwdg.de/index.php/s/Df7OUOyQ1Kc2eEO/download",
        "ribbon": "https://owncloud.gwdg.de/index.php/s/S3b5l0liPP1XPYA/download",
        "vesicles_2d": "https://owncloud.gwdg.de/index.php/s/d72QIvdX6LsgXip/download",
        "vesicles_3d": "https://owncloud.gwdg.de/index.php/s/A425mkAOSqePDhx/download",
        "vesicles_cryo": "https://owncloud.gwdg.de/index.php/s/e2lVdxjCJuZkLJm/download",
        # Additional models that are only available in the CLI, not in the plugin model selection.
        "vesicles_2d_maus": "https://owncloud.gwdg.de/index.php/s/sZ8woLr0zs5zOpv/download",
        "vesicles_3d_endbulb": "https://owncloud.gwdg.de/index.php/s/16tmnWrEDpYIMzU/download",
        "vesicles_3d_innerear": "https://owncloud.gwdg.de/index.php/s/UFUCYivsCxrqISX/download",
    }
    cache_dir = get_cache_dir()
    models = pooch.create(
        path=os.path.join(cache_dir, "models"),
        base_url="",
        registry=registry,
        urls=urls,
    )
    return models


def get_model_path(model_type: str) -> str:
    """Get the local path to a pretrained model.

    Args:
        model_type: The model type.

    Returns:
        The local path to the model.
    """
    model_registry = _get_model_registry()
    model_path = model_registry.fetch(model_type)
    return model_path


def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
    """Get the model for a specific segmentation type.

    Args:
        model_type: The model for one of the following segmentation tasks:
            'vesicles_3d', 'active_zone', 'compartments', 'mitochondria', 'ribbon', 'vesicles_2d', 'vesicles_cryo'.
        device: The device to use.

    Returns:
        The model.
    """
    if device is None:
        device = get_device(device)
    model_path = get_model_path(model_type)
    model = torch.load(model_path, weights_only=False)
    model.to(device)
    return model


#
# Functions for training resolution / voxel size.
#


def get_model_training_resolution(model_type: str) -> Dict[str, float]:
    """Get the average resolution / voxel size of the training data for a given pretrained model.

    Args:
        model_type: The name of the pretrained model.

    Returns:
        Mapping of axis (x, y, z) to the voxel size (in nm) of that axis.
    """
    resolutions = {
        "active_zone": {"x": 1.38, "y": 1.38, "z": 1.38},
        "compartments": {"x": 3.47, "y": 3.47, "z": 3.47},
        "mitochondria": {"x": 2.07, "y": 2.07, "z": 2.07},
        "cristae": {"x": 1.44, "y": 1.44, "z": 1.44},
        "ribbon": {"x": 1.188, "y": 1.188, "z": 1.188},
        "vesicles_2d": {"x": 1.35, "y": 1.35},
        "vesicles_3d": {"x": 1.35, "y": 1.35, "z": 1.35},
        "vesicles_cryo": {"x": 1.35, "y": 1.35, "z": 0.88},
        # TODO add the correct resolutions, these are the resolutions of the source models.
        "vesicles_2d_maus": {"x": 1.35, "y": 1.35},
        "vesicles_3d_endbulb": {"x": 1.35, "y": 1.35, "z": 1.35},
        "vesicles_3d_innerear": {"x": 1.35, "y": 1.35, "z": 1.35},
    }
    return resolutions[model_type]


def compute_scale_from_voxel_size(
    voxel_size: Dict[str, float],
    model_type: str
) -> List[float]:
    """Compute the appropriate scale factor for inference with a given pretrained model.

    Args:
        voxel_size: The voxel size of the data for inference.
        model_type: The name of the pretrained model.

    Returns:
        The scale factor, as a list in zyx order.
    """
    training_voxel_size = get_model_training_resolution(model_type)
    scale = [
        voxel_size["x"] / training_voxel_size["x"],
        voxel_size["y"] / training_voxel_size["y"],
    ]
    if len(voxel_size) == 3 and len(training_voxel_size) == 3:
        scale.append(
            voxel_size["z"] / training_voxel_size["z"]
        )
    return scale


#
# Convenience functions for segmentation.
#


def _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons, resolution, min_membrane_size):
    from synapse_net.inference.postprocessing import (
        segment_ribbon, segment_presynaptic_density, segment_membrane_distance_based,
    )

    ribbon = segment_ribbon(
        predictions["ribbon"], vesicles, n_slices_exclude=n_slices_exclude, n_ribbons=n_ribbons,
        max_vesicle_distance=40,
    )
    PD = segment_presynaptic_density(
        predictions["PD"], ribbon, n_slices_exclude=n_slices_exclude, max_distance_to_ribbon=40,
    )
    ref_segmentation = PD if PD.sum() > 0 else ribbon
    membrane = segment_membrane_distance_based(
        predictions["membrane"], ref_segmentation, max_distance=500, n_slices_exclude=n_slices_exclude,
        resolution=resolution, min_size=min_membrane_size,
    )

    segmentations = {"ribbon": ribbon, "PD": PD, "membrane": membrane}
    return segmentations


def _segment_ribbon_AZ(image, model, tiling, scale, verbose, return_predictions=False, **kwargs):
    # Parse additional keyword arguments from the kwargs.
    vesicles = kwargs.pop("extra_segmentation")
    threshold = kwargs.pop("threshold", 0.5)
    n_slices_exclude = kwargs.pop("n_slices_exclude", 20)
    n_ribbons = kwargs.pop("n_ribbons", 1)
    resolution = kwargs.pop("resolution", None)
    min_membrane_size = kwargs.pop("min_membrane_size", 0)

    predictions = segment_ribbon_synapse_structures(
        image, model=model, tiling=tiling, scale=scale, verbose=verbose, threshold=threshold, **kwargs
    )

    # If no vesicle segmentation was passed, just return the predictions.
    if vesicles is None:
        if verbose:
            print("Vesicle segmentation was not passed, WILL NOT run post-processing.")
        segmentations = predictions

    # If the vesicles were passed then run additional post-processing.
    else:
        if verbose:
            print("Vesicle segmentation was passed, WILL run post-processing.")
        segmentations = _ribbon_AZ_postprocessing(
            predictions, vesicles, n_slices_exclude, n_ribbons, resolution, min_membrane_size
        )

    if return_predictions:
        return segmentations, predictions
    return segmentations


def run_segmentation(
    image: np.ndarray,
    model: torch.nn.Module,
    model_type: str,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    scale: Optional[List[float]] = None,
    verbose: bool = False,
    **kwargs,
) -> np.ndarray | Dict[str, np.ndarray]:
    """Run synaptic structure segmentation.

    Args:
        image: The input image or image volume.
        model: The segmentation model.
        model_type: The model type. This will determine which segmentation post-processing is used.
        tiling: The tiling settings for inference.
        scale: A scale factor for resizing the input before applying the model.
            The output will be scaled back to the initial size.
        verbose: Whether to print detailed information about the prediction and segmentation.
        kwargs: Optional parameters for the segmentation function.

    Returns:
        The segmentation. For models that return multiple segmentations, this function returns a dictionary.
    """
    if model_type.startswith("vesicles"):
        segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "mitochondria" or model_type == "mitochondria2":
        segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "active_zone":
        segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "compartments":
        segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "ribbon":
        segmentation = _segment_ribbon_AZ(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "cristae":
        segmentation = segment_cristae(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    return segmentation
def get_model_path(model_type: str) -> str:
Get the local path to a pretrained model.
Arguments:
- model_type: The model type.
Returns:
The local path to the model.
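A minimal usage sketch: the first call downloads the checkpoint via pooch, later calls return the cached copy (the exact path depends on the local cache directory, so the printed value below is only illustrative).

    from synapse_net.inference.inference import get_model_path

    # Downloads the "vesicles_2d" checkpoint on first use, then reuses the cache.
    checkpoint = get_model_path("vesicles_2d")
    print(checkpoint)  # e.g. <cache_dir>/models/vesicles_2d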
def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
Get the model for a specific segmentation type.
Arguments:
- model_type: The model for one of the following segmentation tasks: 'vesicles_3d', 'active_zone', 'compartments', 'mitochondria', 'ribbon', 'vesicles_2d', 'vesicles_cryo'.
- device: The device to use.
Returns:
The model.
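For illustration, a short sketch that loads the 2D vesicle model; when device is omitted it is resolved by get_device, which picks a GPU if one is available.

    from synapse_net.inference.inference import get_model

    model = get_model("vesicles_2d")                     # downloads the checkpoint if needed
    print(next(model.parameters()).device)               # device chosen by get_device
    model_cpu = get_model("vesicles_2d", device="cpu")   # or pin the device explicitly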
def get_model_training_resolution(model_type: str) -> Dict[str, float]:
Get the average resolution / voxel size of the training data for a given pretrained model.
Arguments:
- model_type: The name of the pretrained model.
Returns:
Mapping of axis (x, y, z) to the voxel size (in nm) of that axis.
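For example, the cryo vesicle model was trained on anisotropic data, and 2D models omit the z axis:

    from synapse_net.inference.inference import get_model_training_resolution

    print(get_model_training_resolution("vesicles_cryo"))  # {'x': 1.35, 'y': 1.35, 'z': 0.88}, in nm
    print(get_model_training_resolution("vesicles_2d"))    # {'x': 1.35, 'y': 1.35}, no z axis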
def compute_scale_from_voxel_size(voxel_size: Dict[str, float], model_type: str) -> List[float]:
Compute the appropriate scale factor for inference with a given pretrained model.
Arguments:
- voxel_size: The voxel size of the data for inference.
- model_type: The name of the pretrained model.
Returns:
The scale factor, as a list in zyx order.
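A small worked example, assuming hypothetical isotropic data at 2.7 nm per voxel and the 'vesicles_3d' model (trained at 1.35 nm), so every axis is rescaled by a factor of two:

    from synapse_net.inference.inference import compute_scale_from_voxel_size

    voxel_size = {"x": 2.7, "y": 2.7, "z": 2.7}  # assumed acquisition voxel size in nm
    scale = compute_scale_from_voxel_size(voxel_size, "vesicles_3d")
    print(scale)  # [2.0, 2.0, 2.0]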
def run_segmentation(image: np.ndarray, model: torch.nn.Module, model_type: str, tiling: Optional[Dict[str, Dict[str, int]]] = None, scale: Optional[List[float]] = None, verbose: bool = False, **kwargs) -> Union[np.ndarray, Dict[str, np.ndarray]]:
Run synaptic structure segmentation.
Arguments:
- image: The input image or image volume.
- model: The segmentation model.
- model_type: The model type. This will determine which segmentation post-processing is used.
- tiling: The tiling settings for inference.
- scale: A scale factor for resizing the input before applying the model. The output will be scaled back to the initial size.
- verbose: Whether to print detailed information about the prediction and segmentation.
- kwargs: Optional parameters for the segmentation function.
Returns:
The segmentation. For models that return multiple segmentations, this function returns a dictionary.
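A minimal end-to-end sketch tying the functions above together. The random volume and the 2.7 nm voxel size are placeholders; substitute a real tomogram loaded with your I/O library of choice.

    import numpy as np
    from synapse_net.inference.inference import (
        compute_scale_from_voxel_size, get_model, run_segmentation,
    )

    # Placeholder volume standing in for real tomogram data.
    tomogram = np.random.rand(32, 256, 256).astype("float32")
    voxel_size = {"x": 2.7, "y": 2.7, "z": 2.7}  # assumed voxel size in nm

    model_type = "vesicles_3d"
    model = get_model(model_type)
    scale = compute_scale_from_voxel_size(voxel_size, model_type)
    segmentation = run_segmentation(tomogram, model, model_type, scale=scale, verbose=True)
    print(segmentation.shape, segmentation.dtype)  # instance segmentation at the input size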