synapse_net.inference.inference
import os
from typing import Dict, List, Optional, Union

import torch
import numpy as np
import pooch

from .active_zone import segment_active_zone
from .compartments import segment_compartments
from .mitochondria import segment_mitochondria
from .ribbon_synapse import segment_ribbon_synapse_structures
from .vesicles import segment_vesicles
from .cristae import segment_cristae
from .util import get_device
from ..file_utils import get_cache_dir


#
# Functions to access SynapseNet's pretrained models.
#


def _get_model_registry():
    registry = {
        "active_zone": "c23652a8fe06daa113546af6d3200c4c1dcc79917056c6ed7357b8c93548372a",
        "compartments": "527983720f9eb215c45c4f4493851fd6551810361eda7b79f185a0d304274ee1",
        "mitochondria": "24625018a5968b36f39fa9d73b121a32e8f66d0f2c0540d3df2e1e39b3d58186",
        "mitochondria2": "553decafaff4838fff6cc8347f22c8db3dee5bcbeffc34ffaec152f8449af673",
        "mitochondria3": "aa690c49017f84fd333436ddcb995ffeaf9b8174e9505586a4caf735a5805e7c",
        "cristae": "f96c90484f4ea92ac0515a06e389cc117580f02c2aacdc44b5828820cf38c3c3",
        "ribbon": "7c947f0ddfabe51a41d9d05c0a6ca7d6b238f43df2af8fffed5552d09bb075a9",
        "vesicles_2d": "eb0b74f7000a0e6a25b626078e76a9452019f2d1ea6cf2033073656f4f055df1",
        "vesicles_3d": "b329ec1f57f305099c984fbb3d7f6ae4b0ff51ec2fa0fa586df52dad6b84cf29",
        "vesicles_cryo": "782f5a21c3cda82c4e4eaeccc754774d5aaed5929f8496eb018aad7daf91661b",
        # Additional models that are only available in the CLI, not in the plugin model selection.
        "vesicles_2d_maus": "01506895df6343fc33ffc9c9eb3f975bf42eb4eaaaf4848bac83b57f1b46e460",
        "vesicles_3d_endbulb": "8582c7e3e5f16ef2bf34d6f9e34644862ca3c76835c9e7d44475c9dd7891d228",
        "vesicles_3d_innerear": "924f0f7cfb648a3a6931c1d48d8b1fdc6c0c0d2cb3330fe2cae49d13e7c3b69d",
    }
    urls = {
        "active_zone": "https://owncloud.gwdg.de/index.php/s/wpea9FH9waG4zJd/download",
        "compartments": "https://owncloud.gwdg.de/index.php/s/DnFDeTmDDmZrDDX/download",
        "mitochondria": "https://owncloud.gwdg.de/index.php/s/1T542uvzfuruahD/download",
        "mitochondria2": "https://owncloud.gwdg.de/index.php/s/GZghrXagc54FFXd/download",
        "mitochondria3": "https://owncloud.gwdg.de/index.php/s/FwLQRCV2HnXDxsT/download",
        "cristae": "https://owncloud.gwdg.de/index.php/s/Df7OUOyQ1Kc2eEO/download",
        "ribbon": "https://owncloud.gwdg.de/index.php/s/S3b5l0liPP1XPYA/download",
        "vesicles_2d": "https://owncloud.gwdg.de/index.php/s/d72QIvdX6LsgXip/download",
        "vesicles_3d": "https://owncloud.gwdg.de/index.php/s/A425mkAOSqePDhx/download",
        "vesicles_cryo": "https://owncloud.gwdg.de/index.php/s/e2lVdxjCJuZkLJm/download",
        # Additional models that are only available in the CLI, not in the plugin model selection.
        "vesicles_2d_maus": "https://owncloud.gwdg.de/index.php/s/sZ8woLr0zs5zOpv/download",
        "vesicles_3d_endbulb": "https://owncloud.gwdg.de/index.php/s/16tmnWrEDpYIMzU/download",
        "vesicles_3d_innerear": "https://owncloud.gwdg.de/index.php/s/UFUCYivsCxrqISX/download",
    }
    cache_dir = get_cache_dir()
    models = pooch.create(
        path=os.path.join(cache_dir, "models"),
        base_url="",
        registry=registry,
        urls=urls,
    )
    return models


def get_available_models() -> List[str]:
    """Get the names of all available pretrained models.

    Returns:
        The list of available model names.
    """
    model_registry = _get_model_registry()
    return list(model_registry.urls.keys())


def get_model_path(model_type: str) -> str:
    """Get the local path to a pretrained model.

    Args:
        model_type: The model type.

    Returns:
        The local path to the model.
    """
    model_registry = _get_model_registry()
    model_path = model_registry.fetch(model_type)
    return model_path


def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
    """Get the model for a specific segmentation type.

    Args:
        model_type: The model for one of the following segmentation tasks:
            'vesicles_3d', 'active_zone', 'compartments', 'mitochondria', 'ribbon', 'vesicles_2d', 'vesicles_cryo'.
        device: The device to use.

    Returns:
        The model.
    """
    if device is None:
        device = get_device(device)
    model_path = get_model_path(model_type)
    model = torch.load(model_path, weights_only=False)
    model.to(device)
    return model


#
# Functions for training resolution / voxel size.
#


def get_model_training_resolution(model_type: str) -> Dict[str, float]:
    """Get the average resolution / voxel size of the training data for a given pretrained model.

    Args:
        model_type: The name of the pretrained model.

    Returns:
        Mapping of axis (x, y, z) to the voxel size (in nm) of that axis.
    """
    resolutions = {
        "active_zone": {"x": 1.38, "y": 1.38, "z": 1.38},
        "compartments": {"x": 3.47, "y": 3.47, "z": 3.47},
        "mitochondria": {"x": 2.07, "y": 2.07, "z": 2.07},
        # TODO: this is just copied from the previous mito model, it may be necessary to update this.
        "mitochondria2": {"x": 1.45, "y": 1.45, "z": 1.45},
        "mitochondria3": {"x": 2.87, "y": 2.87, "z": 2.87},
        "cristae": {"x": 1.44, "y": 1.44, "z": 1.44},
        "ribbon": {"x": 1.188, "y": 1.188, "z": 1.188},
        "vesicles_2d": {"x": 1.35, "y": 1.35},
        "vesicles_3d": {"x": 1.35, "y": 1.35, "z": 1.35},
        "vesicles_cryo": {"x": 1.35, "y": 1.35, "z": 0.88},
        # TODO: add the correct resolutions, these are the resolutions of the source models.
        "vesicles_2d_maus": {"x": 1.35, "y": 1.35},
        "vesicles_3d_endbulb": {"x": 1.35, "y": 1.35, "z": 1.35},
        "vesicles_3d_innerear": {"x": 1.35, "y": 1.35, "z": 1.35},
    }
    return resolutions[model_type]


def compute_scale_from_voxel_size(
    voxel_size: Dict[str, float],
    model_type: str
) -> List[float]:
    """Compute the appropriate scale factor for inference with a given pretrained model.

    Args:
        voxel_size: The voxel size of the data for inference.
        model_type: The name of the pretrained model.

    Returns:
        The scale factor, as a list in zyx order.
    """
    training_voxel_size = get_model_training_resolution(model_type)
    scale = [
        voxel_size["x"] / training_voxel_size["x"],
        voxel_size["y"] / training_voxel_size["y"],
    ]
    if len(voxel_size) == 3 and len(training_voxel_size) == 3:
        scale.append(
            voxel_size["z"] / training_voxel_size["z"]
        )
    return scale


#
# Convenience functions for segmentation.
#


def _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons, resolution, min_membrane_size):
    from synapse_net.inference.postprocessing import (
        segment_ribbon, segment_presynaptic_density, segment_membrane_distance_based,
    )

    ribbon = segment_ribbon(
        predictions["ribbon"], vesicles, n_slices_exclude=n_slices_exclude, n_ribbons=n_ribbons,
        max_vesicle_distance=40,
    )
    PD = segment_presynaptic_density(
        predictions["PD"], ribbon, n_slices_exclude=n_slices_exclude, max_distance_to_ribbon=40,
    )
    ref_segmentation = PD if PD.sum() > 0 else ribbon
    membrane = segment_membrane_distance_based(
        predictions["membrane"], ref_segmentation, max_distance=500, n_slices_exclude=n_slices_exclude,
        resolution=resolution, min_size=min_membrane_size,
    )

    segmentations = {"ribbon": ribbon, "PD": PD, "membrane": membrane}
    return segmentations


def _segment_ribbon_AZ(image, model, tiling, scale, verbose, return_predictions=False, **kwargs):
    # Parse additional keyword arguments from the kwargs.
    vesicles = kwargs.pop("extra_segmentation")
    threshold = kwargs.pop("threshold", 0.5)
    n_slices_exclude = kwargs.pop("n_slices_exclude", 20)
    n_ribbons = kwargs.pop("n_ribbons", 1)
    resolution = kwargs.pop("resolution", None)
    min_membrane_size = kwargs.pop("min_membrane_size", 0)

    predictions = segment_ribbon_synapse_structures(
        image, model=model, tiling=tiling, scale=scale, verbose=verbose, threshold=threshold, **kwargs
    )

    # If no vesicle segmentation was passed, just return the predictions.
    if vesicles is None:
        if verbose:
            print("Vesicle segmentation was not passed, WILL NOT run post-processing.")
        segmentations = predictions

    # If the vesicles were passed then run additional post-processing.
    else:
        if verbose:
            print("Vesicle segmentation was passed, WILL run post-processing.")
        segmentations = _ribbon_AZ_postprocessing(
            predictions, vesicles, n_slices_exclude, n_ribbons, resolution, min_membrane_size
        )

    if return_predictions:
        return segmentations, predictions
    return segmentations


def run_segmentation(
    image: np.ndarray,
    model: torch.nn.Module,
    model_type: str,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    scale: Optional[List[float]] = None,
    verbose: bool = False,
    **kwargs,
) -> np.ndarray | Dict[str, np.ndarray]:
    """Run synaptic structure segmentation.

    Args:
        image: The input image or image volume.
        model: The segmentation model.
        model_type: The model type. This will determine which segmentation post-processing is used.
        tiling: The tiling settings for inference.
        scale: A scale factor for resizing the input before applying the model.
            The output will be scaled back to the initial size.
        verbose: Whether to print detailed information about the prediction and segmentation.
        kwargs: Optional parameters for the segmentation function.

    Returns:
        The segmentation. For models that return multiple segmentations, this function returns a dictionary.
    """
    if model_type.startswith("vesicles"):
        segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type in ("mitochondria", "mitochondria2", "mitochondria3"):
        segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "active_zone":
        segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "compartments":
        segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "ribbon":
        segmentation = _segment_ribbon_AZ(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "cristae":
        segmentation = segment_cristae(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    return segmentation
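The only branch with model-specific post-processing is the ribbon model: _segment_ribbon_AZ pops extra_segmentation, threshold, n_slices_exclude, n_ribbons, resolution, and min_membrane_size from the keyword arguments and only calls _ribbon_AZ_postprocessing when a vesicle segmentation is supplied. A minimal sketch of driving this through run_segmentation; the tomogram and vesicle arrays below are placeholders and not meaningful data:

    import numpy as np
    from synapse_net.inference.inference import get_model, run_segmentation

    # Placeholder inputs; in practice use a real tomogram and a vesicle segmentation
    # produced with one of the vesicle models.
    tomogram = np.random.rand(64, 256, 256).astype("float32")
    vesicles = np.zeros(tomogram.shape, dtype="uint32")

    model = get_model("ribbon")
    segmentations = run_segmentation(
        tomogram, model, model_type="ribbon",
        extra_segmentation=vesicles,  # enables the ribbon/PD/membrane post-processing
        n_slices_exclude=20,          # default shown in _segment_ribbon_AZ
        n_ribbons=1,
    )
    # With post-processing enabled the result is a dict with "ribbon", "PD" and "membrane".
    print(segmentations.keys())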
def get_available_models() -> List[str]:
Get the names of all available pretrained models.
Returns:
The list of available model names.
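A quick usage sketch; the returned names are the keys of the registry defined in the module source above:

    from synapse_net.inference.inference import get_available_models

    print(get_available_models())
    # Expected to contain the registry keys listed above, e.g. "vesicles_2d",
    # "vesicles_3d", "mitochondria", "ribbon", ...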
def get_model_path(model_type: str) -> str:
Get the local path to a pretrained model.
Arguments:
- model_type: The model type.
Returns:
The local path to the model.
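A minimal sketch: on the first call pooch downloads the checkpoint into the cache directory returned by get_cache_dir() and verifies it against the SHA256 hash from the registry; later calls reuse the cached file. The exact filename layout is pooch's default and shown here only as an illustration:

    from synapse_net.inference.inference import get_model_path

    path = get_model_path("vesicles_3d")  # downloads on first use, then cached
    print(path)  # e.g. <cache_dir>/models/vesicles_3d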
def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
Get the model for a specific segmentation type.
Arguments:
- model_type: The model for one of the following segmentation tasks: 'vesicles_3d', 'active_zone', 'compartments', 'mitochondria', 'ribbon', 'vesicles_2d', 'vesicles_cryo'.
- device: The device to use.
Returns:
The model.
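A minimal sketch of loading a model; choosing between "cuda" and "cpu" explicitly is an assumption for illustration, since get_device picks a sensible default when no device is given:

    import torch
    from synapse_net.inference.inference import get_model

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = get_model("vesicles_2d", device=device)
    model.eval()  # the loaded module is used for inference only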
def get_model_training_resolution(model_type: str) -> Dict[str, float]:
Get the average resolution / voxel size of the training data for a given pretrained model.
Arguments:
- model_type: The name of the pretrained model.
Returns:
Mapping of axis (x, y, z) to the voxel size (in nm) of that axis.
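For example, querying the training voxel sizes (the values come from the lookup table in the module source above):

    from synapse_net.inference.inference import get_model_training_resolution

    print(get_model_training_resolution("vesicles_3d"))
    # {'x': 1.35, 'y': 1.35, 'z': 1.35}  (in nm)

    print(get_model_training_resolution("vesicles_2d"))
    # {'x': 1.35, 'y': 1.35}  -- 2D models have no z entry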
def compute_scale_from_voxel_size(voxel_size: Dict[str, float], model_type: str) -> List[float]:
Compute the appropriate scale factor for inference with a given pretrained model.
Arguments:
- voxel_size: The voxel size of the data for inference.
- model_type: The name of the pretrained model.
Returns:
The scale factor, as a list in zyx order.
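As a worked sketch, assume isotropic data at 1.6 nm per voxel and the 'mitochondria' model, whose training voxel size is 2.07 nm (see the table above). Each entry of the scale factor is the data voxel size divided by the training voxel size, i.e. 1.6 / 2.07 ≈ 0.77, so the data is downscaled before inference and the result is scaled back afterwards:

    from synapse_net.inference.inference import compute_scale_from_voxel_size

    voxel_size = {"x": 1.6, "y": 1.6, "z": 1.6}  # assumed voxel size in nm
    scale = compute_scale_from_voxel_size(voxel_size, "mitochondria")
    print(scale)  # approximately [0.77, 0.77, 0.77]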
def run_segmentation(image: np.ndarray, model: torch.nn.Module, model_type: str, tiling: Optional[Dict[str, Dict[str, int]]] = None, scale: Optional[List[float]] = None, verbose: bool = False, **kwargs) -> np.ndarray | Dict[str, np.ndarray]:
Run synaptic structure segmentation.
Arguments:
- image: The input image or image volume.
- model: The segmentation model.
- model_type: The model type. This will determine which segmentation post-processing is used.
- tiling: The tiling settings for inference.
- scale: A scale factor for resizing the input before applying the model. The output will be scaled back to the initial size.
- verbose: Whether to print detailed information about the prediction and segmentation.
- kwargs: Optional parameters for the segmentation function.
Returns:
The segmentation. For models that return multiple segmentations, this function returns a dictionary.
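Putting the pieces together, a minimal end-to-end sketch; the volume and voxel size are placeholders, and loading the tomogram as a numpy array is assumed to happen elsewhere:

    import numpy as np
    from synapse_net.inference.inference import (
        get_model, compute_scale_from_voxel_size, run_segmentation,
    )

    # Placeholder volume; in practice load your tomogram as a numpy array.
    tomogram = np.random.rand(64, 512, 512).astype("float32")
    voxel_size = {"x": 1.1, "y": 1.1, "z": 1.1}  # assumed voxel size in nm

    model_type = "vesicles_3d"
    model = get_model(model_type)
    scale = compute_scale_from_voxel_size(voxel_size, model_type)

    segmentation = run_segmentation(tomogram, model, model_type, scale=scale, verbose=True)
    # For "vesicles_3d" this is a single label array with the same shape as the input;
    # for "ribbon" with post-processing a dictionary of segmentations is returned instead.
    print(segmentation.shape, segmentation.max())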