synapse_net.inference.inference
import os
from typing import Dict, List, Optional, Union

import torch
import numpy as np
import pooch

from .active_zone import segment_active_zone
from .compartments import segment_compartments
from .mitochondria import segment_mitochondria
from .ribbon_synapse import segment_ribbon_synapse_structures
from .vesicles import segment_vesicles
from .cristae import segment_cristae
from .util import get_device
from ..file_utils import get_cache_dir


#
# Functions to access SynapseNet's pretrained models.
#


def _get_model_registry():
    """Create the pooch registry for SynapseNet's pretrained models.

    Each model name is mapped to the hash of its checkpoint file and to the URL it is
    downloaded from. Files are stored in the "models" subfolder of the SynapseNet cache directory.

    Returns:
        The pooch registry object.
    """
    registry = {
        "active_zone": "c23652a8fe06daa113546af6d3200c4c1dcc79917056c6ed7357b8c93548372a",
        "compartments": "527983720f9eb215c45c4f4493851fd6551810361eda7b79f185a0d304274ee1",
        "mitochondria": "24625018a5968b36f39fa9d73b121a32e8f66d0f2c0540d3df2e1e39b3d58186",
        "mitochondria2": "553decafaff4838fff6cc8347f22c8db3dee5bcbeffc34ffaec152f8449af673",
        "cristae": "f96c90484f4ea92ac0515a06e389cc117580f02c2aacdc44b5828820cf38c3c3",
        "ribbon": "7c947f0ddfabe51a41d9d05c0a6ca7d6b238f43df2af8fffed5552d09bb075a9",
        "vesicles_2d": "eb0b74f7000a0e6a25b626078e76a9452019f2d1ea6cf2033073656f4f055df1",
        "vesicles_3d": "b329ec1f57f305099c984fbb3d7f6ae4b0ff51ec2fa0fa586df52dad6b84cf29",
        "vesicles_cryo": "782f5a21c3cda82c4e4eaeccc754774d5aaed5929f8496eb018aad7daf91661b",
        # Additional models that are only available in the CLI, not in the plugin model selection.
        "vesicles_2d_maus": "01506895df6343fc33ffc9c9eb3f975bf42eb4eaaaf4848bac83b57f1b46e460",
        "vesicles_3d_endbulb": "8582c7e3e5f16ef2bf34d6f9e34644862ca3c76835c9e7d44475c9dd7891d228",
        "vesicles_3d_innerear": "924f0f7cfb648a3a6931c1d48d8b1fdc6c0c0d2cb3330fe2cae49d13e7c3b69d",
    }
    urls = {
        "active_zone": "https://owncloud.gwdg.de/index.php/s/wpea9FH9waG4zJd/download",
        "compartments": "https://owncloud.gwdg.de/index.php/s/DnFDeTmDDmZrDDX/download",
        "mitochondria": "https://owncloud.gwdg.de/index.php/s/1T542uvzfuruahD/download",
        "mitochondria2": "https://owncloud.gwdg.de/index.php/s/GZghrXagc54FFXd/download",
        "cristae": "https://owncloud.gwdg.de/index.php/s/Df7OUOyQ1Kc2eEO/download",
        "ribbon": "https://owncloud.gwdg.de/index.php/s/S3b5l0liPP1XPYA/download",
        "vesicles_2d": "https://owncloud.gwdg.de/index.php/s/d72QIvdX6LsgXip/download",
        "vesicles_3d": "https://owncloud.gwdg.de/index.php/s/A425mkAOSqePDhx/download",
        "vesicles_cryo": "https://owncloud.gwdg.de/index.php/s/e2lVdxjCJuZkLJm/download",
        # Additional models that are only available in the CLI, not in the plugin model selection.
        "vesicles_2d_maus": "https://owncloud.gwdg.de/index.php/s/sZ8woLr0zs5zOpv/download",
        "vesicles_3d_endbulb": "https://owncloud.gwdg.de/index.php/s/16tmnWrEDpYIMzU/download",
        "vesicles_3d_innerear": "https://owncloud.gwdg.de/index.php/s/UFUCYivsCxrqISX/download",
    }
    cache_dir = get_cache_dir()
    models = pooch.create(
        path=os.path.join(cache_dir, "models"),
        # Every file has an explicit entry in `urls`, so no base URL is needed.
        base_url="",
        registry=registry,
        urls=urls,
    )
    return models


def get_model_path(model_type: str) -> str:
    """Get the local path to a pretrained model.

    The checkpoint is fetched through the pooch registry, which downloads it if it is
    not yet in the local cache.

    Args:
        model_type: The model type, i.e. one of the model names in the model registry.

    Returns:
        The local path to the model.
    """
    return _get_model_registry().fetch(model_type)


def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
    """Get the model for a specific segmentation type.

    Args:
        model_type: The model for one of the following segmentation tasks:
            'active_zone', 'compartments', 'mitochondria', 'mitochondria2', 'cristae', 'ribbon',
            'vesicles_2d', 'vesicles_3d', 'vesicles_cryo', 'vesicles_2d_maus', 'vesicles_3d_endbulb',
            'vesicles_3d_innerear'.
        device: The device to use. If None, `get_device` is used to select one.

    Returns:
        The model, moved to `device`.
    """
    if device is None:
        device = get_device(device)
    model_path = get_model_path(model_type)
    # The checkpoint holds a full pickled module, hence weights_only=False. Only files that
    # pooch has checked against the registry hash are loaded here.
    # Deserialize onto the CPU first so that checkpoints saved from a GPU can also be loaded
    # on machines without CUDA; the model is then moved to the requested device.
    model = torch.load(model_path, map_location="cpu", weights_only=False)
    model.to(device)
    return model


#
# Functions for training resolution / voxel size.
#


def get_model_training_resolution(model_type: str) -> Dict[str, float]:
    """Get the average resolution / voxel size of the training data for a given pretrained model.

    Args:
        model_type: The name of the pretrained model.

    Returns:
        Mapping of axis (x, y, z) to the voxel size (in nm) of that axis.
        2D models only have entries for x and y.

    Raises:
        KeyError: If there is no pretrained model with the given name.
    """
    # Voxel sizes per model, listed in (x, y, z) order; 2D models only list (x, y).
    voxel_sizes = {
        "active_zone": (1.38, 1.38, 1.38),
        "compartments": (3.47, 3.47, 3.47),
        "mitochondria": (2.07, 2.07, 2.07),
        # TODO: this is just copied from the previous mito model, it may be necessary to update this.
        "mitochondria2": (1.45, 1.45, 1.45),
        "cristae": (1.44, 1.44, 1.44),
        "ribbon": (1.188, 1.188, 1.188),
        "vesicles_2d": (1.35, 1.35),
        "vesicles_3d": (1.35, 1.35, 1.35),
        "vesicles_cryo": (1.35, 1.35, 0.88),
        # TODO add the correct resolutions, these are the resolutions of the source models.
        "vesicles_2d_maus": (1.35, 1.35),
        "vesicles_3d_endbulb": (1.35, 1.35, 1.35),
        "vesicles_3d_innerear": (1.35, 1.35, 1.35),
    }
    axis_names = ("x", "y", "z")
    return dict(zip(axis_names, voxel_sizes[model_type]))


def compute_scale_from_voxel_size(
    voxel_size: Dict[str, float],
    model_type: str
) -> List[float]:
    """Compute the appropriate scale factor for inference with a given pretrained model.

    The factor for each axis is the data's voxel size divided by the model's training
    voxel size along that axis.

    Args:
        voxel_size: The voxel size of the data for inference, mapping the axis names
            'x', 'y' and optionally 'z' to the voxel size along that axis.
        model_type: The name of the pretrained model.

    Returns:
        The scale factor, as a list in zyx order. The z entry is only present if both
        `voxel_size` and the model's training resolution have a 'z' axis.
    """
    training_voxel_size = get_model_training_resolution(model_type)
    # Build the factors in the documented (z,) y, x order; previously they were
    # returned as x, y, z, which swapped the z and x factors for anisotropic models.
    scale = [
        voxel_size["y"] / training_voxel_size["y"],
        voxel_size["x"] / training_voxel_size["x"],
    ]
    if "z" in voxel_size and "z" in training_voxel_size:
        scale.insert(0, voxel_size["z"] / training_voxel_size["z"])
    return scale


#
# Convenience functions for segmentation.
#


def _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons, resolution, min_membrane_size):
    """Derive the ribbon, PD and membrane segmentations from the ribbon synapse predictions.

    Args:
        predictions: The predictions, accessed with the keys "ribbon", "PD" and "membrane".
        vesicles: The vesicle segmentation, used for segmenting the ribbon.
        n_slices_exclude: The number of slices passed as `n_slices_exclude` to all post-processing steps.
        n_ribbons: The number of ribbons to segment.
        resolution: The resolution passed to the distance-based membrane segmentation.
        min_membrane_size: The minimal size passed to the membrane segmentation.

    Returns:
        Dictionary with the segmentations for "ribbon", "PD" and "membrane".
    """
    from synapse_net.inference.postprocessing import (
        segment_ribbon, segment_presynaptic_density, segment_membrane_distance_based,
    )

    ribbon = segment_ribbon(
        predictions["ribbon"], vesicles, n_slices_exclude=n_slices_exclude, n_ribbons=n_ribbons,
        max_vesicle_distance=40,
    )
    PD = segment_presynaptic_density(
        predictions["PD"], ribbon, n_slices_exclude=n_slices_exclude, max_distance_to_ribbon=40,
    )
    # Measure the membrane distance from the PD if one was found, otherwise from the ribbon.
    ref_segmentation = PD if PD.sum() > 0 else ribbon
    membrane = segment_membrane_distance_based(
        predictions["membrane"], ref_segmentation, max_distance=500, n_slices_exclude=n_slices_exclude,
        resolution=resolution, min_size=min_membrane_size,
    )

    segmentations = {"ribbon": ribbon, "PD": PD, "membrane": membrane}
    return segmentations


def _pop_ribbon_kwargs(kwargs):
    """Remove the ribbon-specific options from `kwargs` (in place) and return them with their defaults."""
    return {
        "vesicles": kwargs.pop("extra_segmentation", None),
        "threshold": kwargs.pop("threshold", 0.5),
        "n_slices_exclude": kwargs.pop("n_slices_exclude", 20),
        "n_ribbons": kwargs.pop("n_ribbons", 1),
        "resolution": kwargs.pop("resolution", None),
        "min_membrane_size": kwargs.pop("min_membrane_size", 0),
    }


def _segment_ribbon_AZ(image, model, tiling, scale, verbose, return_predictions=False, **kwargs):
    """Predict ribbon synapse structures and post-process them if a vesicle segmentation is given.

    The vesicle segmentation is passed as `extra_segmentation`. Ribbon-specific options
    ('threshold', 'n_slices_exclude', 'n_ribbons', 'resolution', 'min_membrane_size') are
    consumed here; all other kwargs are forwarded to `segment_ribbon_synapse_structures`.

    Returns:
        The segmentations (or the raw predictions if no vesicle segmentation was passed).
        If `return_predictions` is True, a tuple of (segmentations, predictions).
    """
    options = _pop_ribbon_kwargs(kwargs)
    vesicles = options["vesicles"]

    predictions = segment_ribbon_synapse_structures(
        image, model=model, tiling=tiling, scale=scale, verbose=verbose, threshold=options["threshold"], **kwargs
    )

    if vesicles is None:
        # No vesicle segmentation: skip post-processing and return the predictions as they are.
        if verbose:
            print("Vesicle segmentation was not passed, WILL NOT run post-processing.")
        segmentations = predictions
    else:
        if verbose:
            print("Vesicle segmentation was passed, WILL run post-processing.")
        segmentations = _ribbon_AZ_postprocessing(
            predictions, vesicles, options["n_slices_exclude"], options["n_ribbons"],
            options["resolution"], options["min_membrane_size"],
        )

    if return_predictions:
        return segmentations, predictions
    return segmentations


def run_segmentation(
    image: np.ndarray,
    model: torch.nn.Module,
    model_type: str,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    scale: Optional[List[float]] = None,
    verbose: bool = False,
    **kwargs,
) -> Union[np.ndarray, Dict[str, np.ndarray]]:
    """Run synaptic structure segmentation.

    Args:
        image: The input image or image volume.
        model: The segmentation model.
        model_type: The model type. This will determine which segmentation post-processing is used.
            All types starting with 'vesicles' use the vesicle segmentation, and 'mitochondria' and
            'mitochondria2' both use the mitochondria segmentation.
        tiling: The tiling settings for inference.
        scale: A scale factor for resizing the input before applying the model.
            The output will be scaled back to the initial size.
        verbose: Whether to print detailed information about the prediction and segmentation.
        kwargs: Optional parameters for the segmentation function. For 'ribbon', pass the vesicle
            segmentation as `extra_segmentation` to enable post-processing, and
            `return_predictions=True` to additionally get the predictions as a tuple
            (segmentations, predictions).

    Returns:
        The segmentation. For models that return multiple segmentations, this function returns a dictionary.

    Raises:
        ValueError: If the model type is not supported.
    """
    if model_type.startswith("vesicles"):
        segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type in ("mitochondria", "mitochondria2"):
        segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "active_zone":
        segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "compartments":
        segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "ribbon":
        segmentation = _segment_ribbon_AZ(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "cristae":
        segmentation = segment_cristae(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    return segmentation
def
get_model_path(model_type: str) -> str:
def get_model_path(model_type: str) -> str:
    """Get the local path to a pretrained model.

    The checkpoint is fetched through the pooch registry, which downloads it if it is
    not yet in the local cache.

    Args:
        model_type: The model type, i.e. one of the model names in the model registry.

    Returns:
        The local path to the model.
    """
    return _get_model_registry().fetch(model_type)
Get the local path to a pretrained model.
Arguments:
- model_type: The model type.
Returns:
The local path to the model.
def
get_model( model_type: str, device: Union[torch.device, str, NoneType] = None) -> torch.nn.modules.module.Module:
def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
    """Get the model for a specific segmentation type.

    Args:
        model_type: The model for one of the following segmentation tasks:
            'active_zone', 'compartments', 'mitochondria', 'mitochondria2', 'cristae', 'ribbon',
            'vesicles_2d', 'vesicles_3d', 'vesicles_cryo', 'vesicles_2d_maus', 'vesicles_3d_endbulb',
            'vesicles_3d_innerear'.
        device: The device to use. If None, `get_device` is used to select one.

    Returns:
        The model, moved to `device`.
    """
    if device is None:
        device = get_device(device)
    model_path = get_model_path(model_type)
    # The checkpoint holds a full pickled module, hence weights_only=False. Only files that
    # pooch has checked against the registry hash are loaded here.
    # Deserialize onto the CPU first so that checkpoints saved from a GPU can also be loaded
    # on machines without CUDA; the model is then moved to the requested device.
    model = torch.load(model_path, map_location="cpu", weights_only=False)
    model.to(device)
    return model
Get the model for a specific segmentation type.
Arguments:
- model_type: The model for one of the following segmentation tasks: 'active_zone', 'compartments', 'mitochondria', 'mitochondria2', 'cristae', 'ribbon', 'vesicles_2d', 'vesicles_3d', 'vesicles_cryo', 'vesicles_2d_maus', 'vesicles_3d_endbulb', 'vesicles_3d_innerear'.
- device: The device to use.
Returns:
The model.
def
get_model_training_resolution(model_type: str) -> Dict[str, float]:
def get_model_training_resolution(model_type: str) -> Dict[str, float]:
    """Get the average resolution / voxel size of the training data for a given pretrained model.

    Args:
        model_type: The name of the pretrained model.

    Returns:
        Mapping of axis (x, y, z) to the voxel size (in nm) of that axis.
        2D models only have entries for x and y.

    Raises:
        KeyError: If there is no pretrained model with the given name.
    """
    # Voxel sizes per model, listed in (x, y, z) order; 2D models only list (x, y).
    voxel_sizes = {
        "active_zone": (1.38, 1.38, 1.38),
        "compartments": (3.47, 3.47, 3.47),
        "mitochondria": (2.07, 2.07, 2.07),
        # TODO: this is just copied from the previous mito model, it may be necessary to update this.
        "mitochondria2": (1.45, 1.45, 1.45),
        "cristae": (1.44, 1.44, 1.44),
        "ribbon": (1.188, 1.188, 1.188),
        "vesicles_2d": (1.35, 1.35),
        "vesicles_3d": (1.35, 1.35, 1.35),
        "vesicles_cryo": (1.35, 1.35, 0.88),
        # TODO add the correct resolutions, these are the resolutions of the source models.
        "vesicles_2d_maus": (1.35, 1.35),
        "vesicles_3d_endbulb": (1.35, 1.35, 1.35),
        "vesicles_3d_innerear": (1.35, 1.35, 1.35),
    }
    axis_names = ("x", "y", "z")
    return dict(zip(axis_names, voxel_sizes[model_type]))
Get the average resolution / voxel size of the training data for a given pretrained model.
Arguments:
- model_type: The name of the pretrained model.
Returns:
Mapping of axis (x, y, z) to the voxel size (in nm) of that axis.
def
compute_scale_from_voxel_size(voxel_size: Dict[str, float], model_type: str) -> List[float]:
def compute_scale_from_voxel_size(
    voxel_size: Dict[str, float],
    model_type: str
) -> List[float]:
    """Compute the appropriate scale factor for inference with a given pretrained model.

    The factor for each axis is the data's voxel size divided by the model's training
    voxel size along that axis.

    Args:
        voxel_size: The voxel size of the data for inference, mapping the axis names
            'x', 'y' and optionally 'z' to the voxel size along that axis.
        model_type: The name of the pretrained model.

    Returns:
        The scale factor, as a list in zyx order. The z entry is only present if both
        `voxel_size` and the model's training resolution have a 'z' axis.
    """
    training_voxel_size = get_model_training_resolution(model_type)
    # Build the factors in the documented (z,) y, x order; previously they were
    # returned as x, y, z, which swapped the z and x factors for anisotropic models.
    scale = [
        voxel_size["y"] / training_voxel_size["y"],
        voxel_size["x"] / training_voxel_size["x"],
    ]
    if "z" in voxel_size and "z" in training_voxel_size:
        scale.insert(0, voxel_size["z"] / training_voxel_size["z"])
    return scale
Compute the appropriate scale factor for inference with a given pretrained model.
Arguments:
- voxel_size: The voxel size of the data for inference.
- model_type: The name of the pretrained model.
Returns:
The scale factor, as a list in zyx order.
def
run_segmentation( image: numpy.ndarray, model: torch.nn.modules.module.Module, model_type: str, tiling: Optional[Dict[str, Dict[str, int]]] = None, scale: Optional[List[float]] = None, verbose: bool = False, **kwargs) -> Union[numpy.ndarray, Dict[str, numpy.ndarray]]:
def run_segmentation(
    image: np.ndarray,
    model: torch.nn.Module,
    model_type: str,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    scale: Optional[List[float]] = None,
    verbose: bool = False,
    **kwargs,
) -> Union[np.ndarray, Dict[str, np.ndarray]]:
    """Run synaptic structure segmentation.

    Args:
        image: The input image or image volume.
        model: The segmentation model.
        model_type: The model type. This will determine which segmentation post-processing is used.
            All types starting with 'vesicles' use the vesicle segmentation, and 'mitochondria' and
            'mitochondria2' both use the mitochondria segmentation.
        tiling: The tiling settings for inference.
        scale: A scale factor for resizing the input before applying the model.
            The output will be scaled back to the initial size.
        verbose: Whether to print detailed information about the prediction and segmentation.
        kwargs: Optional parameters for the segmentation function. For 'ribbon', pass the vesicle
            segmentation as `extra_segmentation` to enable post-processing, and
            `return_predictions=True` to additionally get the predictions as a tuple
            (segmentations, predictions).

    Returns:
        The segmentation. For models that return multiple segmentations, this function returns a dictionary.

    Raises:
        ValueError: If the model type is not supported.
    """
    if model_type.startswith("vesicles"):
        segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type in ("mitochondria", "mitochondria2"):
        segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "active_zone":
        segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "compartments":
        segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "ribbon":
        segmentation = _segment_ribbon_AZ(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "cristae":
        segmentation = segment_cristae(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    return segmentation
Run synaptic structure segmentation.
Arguments:
- image: The input image or image volume.
- model: The segmentation model.
- model_type: The model type. This will determine which segmentation post-processing is used.
- tiling: The tiling settings for inference.
- scale: A scale factor for resizing the input before applying the model. The output will be scaled back to the initial size.
- verbose: Whether to print detailed information about the prediction and segmentation.
- kwargs: Optional parameters for the segmentation function.
Returns:
The segmentation. For models that return multiple segmentations, this function returns a dictionary.