synapse_net.inference.inference
import os
from typing import Dict, List, Optional, Union

import torch
import numpy as np
import pooch

from .active_zone import segment_active_zone
from .compartments import segment_compartments
from .mitochondria import segment_mitochondria
from .ribbon_synapse import segment_ribbon_synapse_structures
from .vesicles import segment_vesicles
from .cristae import segment_cristae
from .util import get_device
from ..file_utils import get_cache_dir


#
# Functions to access SynapseNet's pretrained models.
#


def _get_model_registry():
    """Create the pooch registry for SynapseNet's pretrained models.

    Returns:
        A pooch object that downloads the models from their URLs into `<cache_dir>/models`
        and checks them against the registered hashes.
    """
    registry = {
        "active_zone": "c23652a8fe06daa113546af6d3200c4c1dcc79917056c6ed7357b8c93548372a",
        "compartments": "527983720f9eb215c45c4f4493851fd6551810361eda7b79f185a0d304274ee1",
        "mitochondria": "24625018a5968b36f39fa9d73b121a32e8f66d0f2c0540d3df2e1e39b3d58186",
        "mitochondria2": "553decafaff4838fff6cc8347f22c8db3dee5bcbeffc34ffaec152f8449af673",
        "mitochondria3": "aa690c49017f84fd333436ddcb995ffeaf9b8174e9505586a4caf735a5805e7c",
        "cristae": "f96c90484f4ea92ac0515a06e389cc117580f02c2aacdc44b5828820cf38c3c3",
        "ribbon": "7c947f0ddfabe51a41d9d05c0a6ca7d6b238f43df2af8fffed5552d09bb075a9",
        "vesicles_2d": "eb0b74f7000a0e6a25b626078e76a9452019f2d1ea6cf2033073656f4f055df1",
        "vesicles_3d": "b329ec1f57f305099c984fbb3d7f6ae4b0ff51ec2fa0fa586df52dad6b84cf29",
        "vesicles_cryo": "782f5a21c3cda82c4e4eaeccc754774d5aaed5929f8496eb018aad7daf91661b",
        # Additional models that are only available in the CLI, not in the plugin model selection.
        "vesicles_2d_maus": "01506895df6343fc33ffc9c9eb3f975bf42eb4eaaaf4848bac83b57f1b46e460",
        "vesicles_3d_endbulb": "8582c7e3e5f16ef2bf34d6f9e34644862ca3c76835c9e7d44475c9dd7891d228",
        "vesicles_3d_innerear": "924f0f7cfb648a3a6931c1d48d8b1fdc6c0c0d2cb3330fe2cae49d13e7c3b69d",
    }
    urls = {
        "active_zone": "https://owncloud.gwdg.de/index.php/s/wpea9FH9waG4zJd/download",
        "compartments": "https://owncloud.gwdg.de/index.php/s/DnFDeTmDDmZrDDX/download",
        "mitochondria": "https://owncloud.gwdg.de/index.php/s/1T542uvzfuruahD/download",
        "mitochondria2": "https://owncloud.gwdg.de/index.php/s/GZghrXagc54FFXd/download",
        "mitochondria3": "https://owncloud.gwdg.de/index.php/s/FwLQRCV2HnXDxsT/download",
        "cristae": "https://owncloud.gwdg.de/index.php/s/Df7OUOyQ1Kc2eEO/download",
        "ribbon": "https://owncloud.gwdg.de/index.php/s/S3b5l0liPP1XPYA/download",
        "vesicles_2d": "https://owncloud.gwdg.de/index.php/s/d72QIvdX6LsgXip/download",
        "vesicles_3d": "https://owncloud.gwdg.de/index.php/s/A425mkAOSqePDhx/download",
        "vesicles_cryo": "https://owncloud.gwdg.de/index.php/s/e2lVdxjCJuZkLJm/download",
        # Additional models that are only available in the CLI, not in the plugin model selection.
        "vesicles_2d_maus": "https://owncloud.gwdg.de/index.php/s/sZ8woLr0zs5zOpv/download",
        "vesicles_3d_endbulb": "https://owncloud.gwdg.de/index.php/s/16tmnWrEDpYIMzU/download",
        "vesicles_3d_innerear": "https://owncloud.gwdg.de/index.php/s/UFUCYivsCxrqISX/download",
    }
    cache_dir = get_cache_dir()
    models = pooch.create(
        path=os.path.join(cache_dir, "models"),
        # Every model has an explicit URL, so no base URL is needed.
        base_url="",
        registry=registry,
        urls=urls,
    )
    return models


def get_model_path(model_type: str) -> str:
    """Get the local path to a pretrained model.

    The model file is fetched via pooch, which downloads it into the cache directory
    if it is not present yet.

    Args:
        model_type: The model type, i.e. one of the keys of the model registry
            (e.g. 'vesicles_3d', 'mitochondria', 'ribbon').

    Returns:
        The local path to the model.
    """
    model_registry = _get_model_registry()
    model_path = model_registry.fetch(model_type)
    return model_path


def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
    """Get the model for a specific segmentation type.

    Args:
        model_type: The model for one of the following segmentation tasks:
            'active_zone', 'compartments', 'mitochondria', 'mitochondria2', 'mitochondria3',
            'cristae', 'ribbon', 'vesicles_2d', 'vesicles_3d', 'vesicles_cryo',
            'vesicles_2d_maus', 'vesicles_3d_endbulb', 'vesicles_3d_innerear'.
        device: The device to use. If None, the device is chosen by `get_device`.

    Returns:
        The model, moved to `device`.
    """
    if device is None:
        device = get_device(device)
    model_path = get_model_path(model_type)
    # Deserialize onto the CPU first so that checkpoints that were saved on a GPU can also be
    # loaded on machines without one; the model is moved to the requested device afterwards.
    # NOTE: weights_only=False unpickles arbitrary objects, so only trusted files may be loaded here
    # (the registry pins the expected hash of every model file).
    model = torch.load(model_path, weights_only=False, map_location="cpu")
    model.to(device)
    return model


#
# Functions for training resolution / voxel size.
#


def get_model_training_resolution(model_type: str) -> Dict[str, float]:
    """Get the average resolution / voxel size of the training data for a given pretrained model.

    Args:
        model_type: The name of the pretrained model.

    Returns:
        Mapping of axis (x, y, z) to the voxel size (in nm) of that axis.
        2D models only have the entries for x and y.

    Raises:
        KeyError: If `model_type` is not a known pretrained model.
    """
    resolutions = {
        "active_zone": {"x": 1.38, "y": 1.38, "z": 1.38},
        "compartments": {"x": 3.47, "y": 3.47, "z": 3.47},
        "mitochondria": {"x": 2.07, "y": 2.07, "z": 2.07},
        # TODO: this is just copied from the previous mito model, it may be necessary to update this.
        "mitochondria2": {"x": 1.45, "y": 1.45, "z": 1.45},
        "mitochondria3": {"x": 2.87, "y": 2.87, "z": 2.87},
        "cristae": {"x": 1.44, "y": 1.44, "z": 1.44},
        "ribbon": {"x": 1.188, "y": 1.188, "z": 1.188},
        "vesicles_2d": {"x": 1.35, "y": 1.35},
        "vesicles_3d": {"x": 1.35, "y": 1.35, "z": 1.35},
        "vesicles_cryo": {"x": 1.35, "y": 1.35, "z": 0.88},
        # TODO add the correct resolutions, these are the resolutions of the source models.
        "vesicles_2d_maus": {"x": 1.35, "y": 1.35},
        "vesicles_3d_endbulb": {"x": 1.35, "y": 1.35, "z": 1.35},
        "vesicles_3d_innerear": {"x": 1.35, "y": 1.35, "z": 1.35},
    }
    return resolutions[model_type]


def compute_scale_from_voxel_size(
    voxel_size: Dict[str, float],
    model_type: str
) -> List[float]:
    """Compute the appropriate scale factor for inference with a given pretrained model.

    The scale factor per axis is the ratio of the data voxel size to the training voxel size.

    Args:
        voxel_size: The voxel size of the data for inference, as a mapping from axis ("x", "y", "z") to size.
        model_type: The name of the pretrained model.

    Returns:
        The scale factor, as a list in zyx order: [z, y, x] if both the data and the model's
        training resolution have a z entry, otherwise [y, x].
    """
    training_voxel_size = get_model_training_resolution(model_type)
    # Build the factors in array axis order (zyx), which is the documented contract.
    scale = [
        voxel_size["y"] / training_voxel_size["y"],
        voxel_size["x"] / training_voxel_size["x"],
    ]
    # Only rescale along z if both the data and the model are 3D.
    if "z" in voxel_size and "z" in training_voxel_size:
        scale.insert(0, voxel_size["z"] / training_voxel_size["z"])
    return scale


#
# Convenience functions for segmentation.
#


def _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons, resolution, min_membrane_size):
    """Turn ribbon synapse predictions into ribbon, presynaptic density (PD) and membrane segmentations.

    Args:
        predictions: The predictions from `segment_ribbon_synapse_structures`;
            the entries "ribbon", "PD" and "membrane" are used.
        vesicles: The vesicle segmentation, used as reference for segmenting the ribbon.
        n_slices_exclude: Passed to each of the post-processing functions.
        n_ribbons: Passed to `segment_ribbon`.
        resolution: Passed to `segment_membrane_distance_based`.
        min_membrane_size: Passed as `min_size` to `segment_membrane_distance_based`.

    Returns:
        Dictionary with the segmentations for "ribbon", "PD" and "membrane".
    """
    # Imported locally; presumably to avoid an import cycle / import-time cost — confirm.
    from synapse_net.inference.postprocessing import (
        segment_ribbon, segment_presynaptic_density, segment_membrane_distance_based,
    )

    ribbon = segment_ribbon(
        predictions["ribbon"], vesicles, n_slices_exclude=n_slices_exclude, n_ribbons=n_ribbons,
        max_vesicle_distance=40,
    )
    PD = segment_presynaptic_density(
        predictions["PD"], ribbon, n_slices_exclude=n_slices_exclude, max_distance_to_ribbon=40,
    )
    # The membrane is segmented relative to the PD; fall back to the ribbon if no PD was found.
    ref_segmentation = PD if PD.sum() > 0 else ribbon
    membrane = segment_membrane_distance_based(
        predictions["membrane"], ref_segmentation, max_distance=500, n_slices_exclude=n_slices_exclude,
        resolution=resolution, min_size=min_membrane_size,
    )

    segmentations = {"ribbon": ribbon, "PD": PD, "membrane": membrane}
    return segmentations


def _segment_ribbon_AZ(image, model, tiling, scale, verbose, return_predictions=False, **kwargs):
    """Predict ribbon synapse structures and post-process them if a vesicle segmentation is given.

    Args:
        image: The input image or image volume.
        model: The segmentation model.
        tiling: The tiling settings for inference.
        scale: A scale factor for resizing the input before applying the model.
        verbose: Whether to print detailed information about the prediction and segmentation.
        return_predictions: Whether to also return the raw predictions.
        kwargs: Optional parameters. The following ones are consumed here:
            extra_segmentation: The vesicle segmentation. If it is missing or None, no post-processing
                is run and the predictions are returned as the segmentation.
            threshold: Passed to `segment_ribbon_synapse_structures` (default: 0.5).
            n_slices_exclude: Passed to the post-processing (default: 20).
            n_ribbons: Passed to the post-processing (default: 1).
            resolution: Passed to the membrane post-processing (default: None).
            min_membrane_size: Passed to the membrane post-processing (default: 0).
            All remaining kwargs are forwarded to `segment_ribbon_synapse_structures`.

    Returns:
        The segmentations, or the tuple (segmentations, predictions) if `return_predictions` is True.
    """
    # Parse additional keyword arguments from the kwargs. Each one has to be popped so that it is
    # not forwarded to segment_ribbon_synapse_structures.
    vesicles = kwargs.pop("extra_segmentation", None)
    threshold = kwargs.pop("threshold", 0.5)
    n_slices_exclude = kwargs.pop("n_slices_exclude", 20)
    n_ribbons = kwargs.pop("n_ribbons", 1)
    resolution = kwargs.pop("resolution", None)
    min_membrane_size = kwargs.pop("min_membrane_size", 0)

    predictions = segment_ribbon_synapse_structures(
        image, model=model, tiling=tiling, scale=scale, verbose=verbose, threshold=threshold, **kwargs
    )

    if vesicles is None:
        # Without a vesicle segmentation, just return the predictions.
        if verbose:
            print("Vesicle segmentation was not passed, WILL NOT run post-processing.")
        segmentations = predictions
    else:
        # If the vesicles were passed then run additional post-processing.
        if verbose:
            print("Vesicle segmentation was passed, WILL run post-processing.")
        segmentations = _ribbon_AZ_postprocessing(
            predictions, vesicles, n_slices_exclude, n_ribbons, resolution, min_membrane_size
        )

    if return_predictions:
        return segmentations, predictions
    return segmentations


def run_segmentation(
    image: np.ndarray,
    model: torch.nn.Module,
    model_type: str,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    scale: Optional[List[float]] = None,
    verbose: bool = False,
    **kwargs,
) -> Union[np.ndarray, Dict[str, np.ndarray]]:
    """Run synaptic structure segmentation.

    Args:
        image: The input image or image volume.
        model: The segmentation model.
        model_type: The model type. This will determine which segmentation post-processing is used.
        tiling: The tiling settings for inference.
        scale: A scale factor for resizing the input before applying the model.
            The output will be scaled back to the initial size.
        verbose: Whether to print detailed information about the prediction and segmentation.
        kwargs: Optional parameters for the segmentation function.

    Returns:
        The segmentation. For models that return multiple segmentations, this function returns a dictionary.

    Raises:
        ValueError: If `model_type` is not supported.
    """
    if model_type.startswith("vesicles"):
        segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type in ("mitochondria", "mitochondria2", "mitochondria3"):
        segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "active_zone":
        segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "compartments":
        segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "ribbon":
        segmentation = _segment_ribbon_AZ(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "cristae":
        segmentation = segment_cristae(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    return segmentation
def
get_model_path(model_type: str) -> str:
def get_model_path(model_type: str) -> str:
    """Get the local path to a pretrained model.

    The model file is fetched via pooch, which downloads it into the cache directory
    if it is not present yet.

    Args:
        model_type: The model type, i.e. one of the keys of the model registry
            (e.g. 'vesicles_3d', 'mitochondria', 'ribbon').

    Returns:
        The local path to the model.
    """
    model_registry = _get_model_registry()
    model_path = model_registry.fetch(model_type)
    return model_path
Get the local path to a pretrained model.
Arguments:
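- model_type: The model type, i.e. one of the keys of the model registry (e.g. 'vesicles_3d', 'mitochondria', 'ribbon').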
Returns:
The local path to the model.
def
get_model( model_type: str, device: Union[torch.device, str, NoneType] = None) -> torch.nn.modules.module.Module:
def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
    """Get the model for a specific segmentation type.

    Args:
        model_type: The model for one of the following segmentation tasks:
            'active_zone', 'compartments', 'mitochondria', 'mitochondria2', 'mitochondria3',
            'cristae', 'ribbon', 'vesicles_2d', 'vesicles_3d', 'vesicles_cryo',
            'vesicles_2d_maus', 'vesicles_3d_endbulb', 'vesicles_3d_innerear'.
        device: The device to use. If None, the device is chosen by `get_device`.

    Returns:
        The model, moved to `device`.
    """
    if device is None:
        device = get_device(device)
    model_path = get_model_path(model_type)
    # Deserialize onto the CPU first so that checkpoints that were saved on a GPU can also be
    # loaded on machines without one; the model is moved to the requested device afterwards.
    # NOTE: weights_only=False unpickles arbitrary objects, so only trusted files may be loaded here
    # (the registry pins the expected hash of every model file).
    model = torch.load(model_path, weights_only=False, map_location="cpu")
    model.to(device)
    return model
Get the model for a specific segmentation type.
Arguments:
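- model_type: The model for one of the following segmentation tasks: 'active_zone', 'compartments', 'mitochondria', 'mitochondria2', 'mitochondria3', 'cristae', 'ribbon', 'vesicles_2d', 'vesicles_3d', 'vesicles_cryo', 'vesicles_2d_maus', 'vesicles_3d_endbulb', 'vesicles_3d_innerear'.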
- device: The device to use.
Returns:
The model.
def
get_model_training_resolution(model_type: str) -> Dict[str, float]:
def get_model_training_resolution(model_type: str) -> Dict[str, float]:
    """Look up the mean voxel size of the data a pretrained model was trained on.

    Args:
        model_type: The name of the pretrained model.

    Returns:
        A fresh dict mapping each axis name ("x", "y" and, for 3D models, "z")
        to its voxel size in nm.

    Raises:
        KeyError: If `model_type` is not a known pretrained model.
    """
    def _isotropic(size):
        return {"x": size, "y": size, "z": size}

    def _planar(size):
        return {"x": size, "y": size}

    table = {
        "active_zone": _isotropic(1.38),
        "compartments": _isotropic(3.47),
        "mitochondria": _isotropic(2.07),
        # TODO: this is just copied from the previous mito model, it may be necessary to update this.
        "mitochondria2": _isotropic(1.45),
        "mitochondria3": _isotropic(2.87),
        "cristae": _isotropic(1.44),
        "ribbon": _isotropic(1.188),
        "vesicles_2d": _planar(1.35),
        "vesicles_3d": _isotropic(1.35),
        # The cryo model is the only one with an anisotropic training resolution.
        "vesicles_cryo": {**_planar(1.35), "z": 0.88},
        # TODO add the correct resolutions, these are the resolutions of the source models.
        "vesicles_2d_maus": _planar(1.35),
        "vesicles_3d_endbulb": _isotropic(1.35),
        "vesicles_3d_innerear": _isotropic(1.35),
    }
    return table[model_type]
Get the average resolution / voxel size of the training data for a given pretrained model.
Arguments:
- model_type: The name of the pretrained model.
Returns:
Mapping of axis (x, y, z) to the voxel size (in nm) of that axis.
def
compute_scale_from_voxel_size(voxel_size: Dict[str, float], model_type: str) -> List[float]:
def compute_scale_from_voxel_size(
    voxel_size: Dict[str, float],
    model_type: str
) -> List[float]:
    """Compute the appropriate scale factor for inference with a given pretrained model.

    The scale factor per axis is the ratio of the data voxel size to the training voxel size.

    Args:
        voxel_size: The voxel size of the data for inference, as a mapping from axis ("x", "y", "z") to size.
        model_type: The name of the pretrained model.

    Returns:
        The scale factor, as a list in zyx order: [z, y, x] if both the data and the model's
        training resolution have a z entry, otherwise [y, x].
    """
    training_voxel_size = get_model_training_resolution(model_type)
    # Build the factors in array axis order (zyx), which is the documented contract.
    scale = [
        voxel_size["y"] / training_voxel_size["y"],
        voxel_size["x"] / training_voxel_size["x"],
    ]
    # Only rescale along z if both the data and the model are 3D.
    if "z" in voxel_size and "z" in training_voxel_size:
        scale.insert(0, voxel_size["z"] / training_voxel_size["z"])
    return scale
Compute the appropriate scale factor for inference with a given pretrained model.
Arguments:
- voxel_size: The voxel size of the data for inference.
- model_type: The name of the pretrained model.
Returns:
The scale factor, as a list in zyx order.
def
run_segmentation( image: numpy.ndarray, model: torch.nn.modules.module.Module, model_type: str, tiling: Optional[Dict[str, Dict[str, int]]] = None, scale: Optional[List[float]] = None, verbose: bool = False, **kwargs) -> Union[numpy.ndarray, Dict[str, numpy.ndarray]]:
def run_segmentation(
    image: np.ndarray,
    model: torch.nn.Module,
    model_type: str,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    scale: Optional[List[float]] = None,
    verbose: bool = False,
    **kwargs,
) -> Union[np.ndarray, Dict[str, np.ndarray]]:
    """Run synaptic structure segmentation.

    Args:
        image: The input image or image volume.
        model: The segmentation model.
        model_type: The model type. This will determine which segmentation post-processing is used.
        tiling: The tiling settings for inference.
        scale: A scale factor for resizing the input before applying the model.
            The output will be scaled back to the initial size.
        verbose: Whether to print detailed information about the prediction and segmentation.
        kwargs: Optional parameters for the segmentation function.

    Returns:
        The segmentation. For models that return multiple segmentations, this function returns a dictionary.

    Raises:
        ValueError: If `model_type` is not supported.
    """
    if model_type.startswith("vesicles"):
        segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type in ("mitochondria", "mitochondria2", "mitochondria3"):
        segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "active_zone":
        segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "compartments":
        segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "ribbon":
        segmentation = _segment_ribbon_AZ(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "cristae":
        segmentation = segment_cristae(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    return segmentation
Run synaptic structure segmentation.
Arguments:
- image: The input image or image volume.
- model: The segmentation model.
- model_type: The model type. This will determine which segmentation post-processing is used.
- tiling: The tiling settings for inference.
- scale: A scale factor for resizing the input before applying the model. The output will be scaled back to the initial size.
- verbose: Whether to print detailed information about the prediction and segmentation.
- kwargs: Optional parameters for the segmentation function.
Returns:
The segmentation. For models that return multiple segmentations, this function returns a dictionary.