synapse_net.inference.inference
import os
from typing import Dict, List, Optional, Union

import torch
import numpy as np
import pooch

from .active_zone import segment_active_zone
from .compartments import segment_compartments
from .mitochondria import segment_mitochondria
from .ribbon_synapse import segment_ribbon_synapse_structures
from .vesicles import segment_vesicles
from .util import get_device
from ..file_utils import get_cache_dir


#
# Functions to access SynapseNet's pretrained models.
#


def _get_model_registry():
    """Build the pooch registry for SynapseNet's pretrained models.

    Every model name maps to the SHA256 hash of its file (`registry`) and to its
    download URL (`urls`). Files are cached in the `models` sub-directory of the
    SynapseNet cache directory.

    Returns:
        The pooch instance used to fetch the model files.
    """
    registry = {
        "active_zone": "a18f29168aed72edec0f5c2cb1aa9a4baa227812db6082a6538fd38d9f43afb0",
        "compartments": "527983720f9eb215c45c4f4493851fd6551810361eda7b79f185a0d304274ee1",
        "mitochondria": "24625018a5968b36f39fa9d73b121a32e8f66d0f2c0540d3df2e1e39b3d58186",
        "mitochondria2": "553decafaff4838fff6cc8347f22c8db3dee5bcbeffc34ffaec152f8449af673",
        "ribbon": "7c947f0ddfabe51a41d9d05c0a6ca7d6b238f43df2af8fffed5552d09bb075a9",
        "vesicles_2d": "eb0b74f7000a0e6a25b626078e76a9452019f2d1ea6cf2033073656f4f055df1",
        "vesicles_3d": "b329ec1f57f305099c984fbb3d7f6ae4b0ff51ec2fa0fa586df52dad6b84cf29",
        "vesicles_cryo": "782f5a21c3cda82c4e4eaeccc754774d5aaed5929f8496eb018aad7daf91661b",
    }
    urls = {
        "active_zone": "https://owncloud.gwdg.de/index.php/s/zvuY342CyQebPsX/download",
        "compartments": "https://owncloud.gwdg.de/index.php/s/DnFDeTmDDmZrDDX/download",
        "mitochondria": "https://owncloud.gwdg.de/index.php/s/1T542uvzfuruahD/download",
        "mitochondria2": "https://owncloud.gwdg.de/index.php/s/GZghrXagc54FFXd/download",
        "ribbon": "https://owncloud.gwdg.de/index.php/s/S3b5l0liPP1XPYA/download",
        "vesicles_2d": "https://owncloud.gwdg.de/index.php/s/d72QIvdX6LsgXip/download",
        "vesicles_3d": "https://owncloud.gwdg.de/index.php/s/A425mkAOSqePDhx/download",
        "vesicles_cryo": "https://owncloud.gwdg.de/index.php/s/e2lVdxjCJuZkLJm/download",
    }
    cache_dir = get_cache_dir()
    models = pooch.create(
        path=os.path.join(cache_dir, "models"),
        # Every file has an explicit entry in `urls`, so no base URL is needed.
        base_url="",
        registry=registry,
        urls=urls,
    )
    return models


def get_model_path(model_type: str) -> str:
    """Get the local path to a pretrained model.

    The file is fetched through the pooch registry: it is downloaded to the cache
    on first use and verified against its registered hash.

    Args:
        model_type: The model type, i.e. one of the registry keys
            'active_zone', 'compartments', 'mitochondria', 'mitochondria2', 'ribbon',
            'vesicles_2d', 'vesicles_3d', 'vesicles_cryo'.

    Returns:
        The local path to the model.
    """
    model_registry = _get_model_registry()
    model_path = model_registry.fetch(model_type)
    return model_path


def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
    """Get the model for a specific segmentation type.

    Args:
        model_type: The model for one of the following segmentation tasks:
            'vesicles_3d', 'active_zone', 'compartments', 'mitochondria', 'mitochondria2',
            'ribbon', 'vesicles_2d', 'vesicles_cryo'.
        device: The device to use. If None, the device returned by `get_device` is used.

    Returns:
        The model.
    """
    if device is None:
        device = get_device(device)
    model_path = get_model_path(model_type)
    # The file stores the whole pickled module, which requires weights_only=False.
    # This is only acceptable because the file's hash was verified by the registry.
    model = torch.load(model_path, weights_only=False)
    model.to(device)
    return model


#
# Functions for training resolution / voxel size.
#


def get_model_training_resolution(model_type: str) -> Dict[str, float]:
    """Get the average resolution / voxel size of the training data for a given pretrained model.

    Args:
        model_type: The name of the pretrained model.

    Returns:
        Mapping of axis (x, y, z) to the voxel size (in nm) of that axis.
        2D models only have the axes x and y.

    Raises:
        KeyError: If no training resolution is recorded for the model.
    """
    # NOTE(review): 'mitochondria2' is in the model registry but has no entry here.
    resolutions = {
        "active_zone": {"x": 1.44, "y": 1.44, "z": 1.44},
        "compartments": {"x": 3.47, "y": 3.47, "z": 3.47},
        "mitochondria": {"x": 2.07, "y": 2.07, "z": 2.07},
        "ribbon": {"x": 1.188, "y": 1.188, "z": 1.188},
        "vesicles_2d": {"x": 1.35, "y": 1.35},
        "vesicles_3d": {"x": 1.35, "y": 1.35, "z": 1.35},
        "vesicles_cryo": {"x": 1.35, "y": 1.35, "z": 0.88},
    }
    return resolutions[model_type]


def compute_scale_from_voxel_size(
    voxel_size: Dict[str, float],
    model_type: str
) -> List[float]:
    """Compute the appropriate scale factor for inference with a given pretrained model.

    Args:
        voxel_size: The voxel size of the data for inference, keyed by axis ("x", "y", optionally "z").
        model_type: The name of the pretrained model.

    Returns:
        The scale factor, as a list in zyx order (yx if the data or the model is 2D).
    """
    training_voxel_size = get_model_training_resolution(model_type)
    # Per-axis ratio of data voxel size to training voxel size, collected in xyz order.
    scale = [
        voxel_size["x"] / training_voxel_size["x"],
        voxel_size["y"] / training_voxel_size["y"],
    ]
    # z is only scaled if both the data and the model are 3D.
    if "z" in voxel_size and "z" in training_voxel_size:
        scale.append(voxel_size["z"] / training_voxel_size["z"])
    # Image arrays are indexed zyx, so reverse to match the documented axis order.
    return scale[::-1]


#
# Convenience functions for segmentation.
139# 140 141 142def _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons): 143 from synapse_net.inference.postprocessing import ( 144 segment_ribbon, segment_presynaptic_density, segment_membrane_distance_based, 145 ) 146 147 ribbon = segment_ribbon( 148 predictions["ribbon"], vesicles, n_slices_exclude=n_slices_exclude, n_ribbons=n_ribbons, 149 max_vesicle_distance=40, 150 ) 151 PD = segment_presynaptic_density( 152 predictions["PD"], ribbon, n_slices_exclude=n_slices_exclude, max_distance_to_ribbon=40, 153 ) 154 ref_segmentation = PD if PD.sum() > 0 else ribbon 155 membrane = segment_membrane_distance_based( 156 predictions["membrane"], ref_segmentation, max_distance=500, n_slices_exclude=n_slices_exclude, 157 ) 158 159 segmentations = {"ribbon": ribbon, "PD": PD, "membrane": membrane} 160 return segmentations 161 162 163def _segment_ribbon_AZ(image, model, tiling, scale, verbose, return_predictions=False, **kwargs): 164 # Parse additional keyword arguments from the kwargs. 165 vesicles = kwargs.pop("extra_segmentation") 166 threshold = kwargs.pop("threshold", 0.5) 167 n_slices_exclude = kwargs.pop("n_slices_exclude", 20) 168 n_ribbons = kwargs.pop("n_slices_exclude", 1) 169 170 predictions = segment_ribbon_synapse_structures( 171 image, model=model, tiling=tiling, scale=scale, verbose=verbose, threshold=threshold, **kwargs 172 ) 173 174 # Otherwise, just return the predictions. 175 if vesicles is None: 176 if verbose: 177 print("Vesicle segmentation was not passed, WILL NOT run post-processing.") 178 segmentations = predictions 179 180 # If the vesicles were passed then run additional post-processing. 
181 else: 182 if verbose: 183 print("Vesicle segmentation was passed, WILL run post-processing.") 184 segmentations = _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons) 185 186 if return_predictions: 187 return segmentations, predictions 188 return segmentations 189 190 191def run_segmentation( 192 image: np.ndarray, 193 model: torch.nn.Module, 194 model_type: str, 195 tiling: Optional[Dict[str, Dict[str, int]]] = None, 196 scale: Optional[List[float]] = None, 197 verbose: bool = False, 198 **kwargs, 199) -> np.ndarray | Dict[str, np.ndarray]: 200 """Run synaptic structure segmentation. 201 202 Args: 203 image: The input image or image volume. 204 model: The segmentation model. 205 model_type: The model type. This will determine which segmentation post-processing is used. 206 tiling: The tiling settings for inference. 207 scale: A scale factor for resizing the input before applying the model. 208 The output will be scaled back to the initial size. 209 verbose: Whether to print detailed information about the prediction and segmentation. 210 kwargs: Optional parameters for the segmentation function. 211 212 Returns: 213 The segmentation. For models that return multiple segmentations, this function returns a dictionary. 
214 """ 215 if model_type.startswith("vesicles"): 216 segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) 217 elif model_type == "mitochondria": 218 segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) 219 elif model_type == "active_zone": 220 segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) 221 elif model_type == "compartments": 222 segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) 223 elif model_type == "ribbon": 224 segmentation = _segment_ribbon_AZ(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) 225 else: 226 raise ValueError(f"Unknown model type: {model_type}") 227 return segmentation
def
get_model_path(model_type: str) -> str:
def get_model_path(model_type: str) -> str:
    """Get the local path to a pretrained model.

    The file is fetched through the pooch registry: it is downloaded to the cache
    on first use and verified against its registered hash.

    Args:
        model_type: The model type, i.e. one of the registry keys
            'active_zone', 'compartments', 'mitochondria', 'mitochondria2', 'ribbon',
            'vesicles_2d', 'vesicles_3d', 'vesicles_cryo'.

    Returns:
        The local path to the model.
    """
    model_registry = _get_model_registry()
    model_path = model_registry.fetch(model_type)
    return model_path
Get the local path to a pretrained model.
Arguments:
- model_type: The model type.
Returns:
The local path to the model.
def
get_model( model_type: str, device: Union[str, torch.device, NoneType] = None) -> torch.nn.modules.module.Module:
def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
    """Get the model for a specific segmentation type.

    Args:
        model_type: The model for one of the following segmentation tasks:
            'vesicles_3d', 'active_zone', 'compartments', 'mitochondria', 'mitochondria2',
            'ribbon', 'vesicles_2d', 'vesicles_cryo'.
        device: The device to use. If None, the device returned by `get_device` is used.

    Returns:
        The model.
    """
    if device is None:
        device = get_device(device)
    model_path = get_model_path(model_type)
    # The file stores the whole pickled module, which requires weights_only=False.
    # This is only acceptable because the file's hash was verified by the registry.
    model = torch.load(model_path, weights_only=False)
    model.to(device)
    return model
Get the model for a specific segmentation type.
Arguments:
- model_type: The model for one of the following segmentation tasks: 'vesicles_3d', 'active_zone', 'compartments', 'mitochondria', 'mitochondria2', 'ribbon', 'vesicles_2d', 'vesicles_cryo'.
- device: The device to use.
Returns:
The model.
def
get_model_training_resolution(model_type: str) -> Dict[str, float]:
def get_model_training_resolution(model_type: str) -> Dict[str, float]:
    """Look up the voxel size of the data a pretrained model was trained on.

    Args:
        model_type: The name of the pretrained model.

    Returns:
        Mapping of axis name ("x", "y" and, for 3D models, "z") to the average
        training voxel size in nm along that axis.

    Raises:
        KeyError: If no training resolution is recorded for `model_type`.
    """
    def _isotropic(size):
        # Same voxel size along all three axes.
        return {"x": size, "y": size, "z": size}

    # NOTE(review): 'mitochondria2' from the model registry has no entry here.
    training_voxel_sizes = {
        "active_zone": _isotropic(1.44),
        "compartments": _isotropic(3.47),
        "mitochondria": _isotropic(2.07),
        "ribbon": _isotropic(1.188),
        # 2D model: there is no z-axis.
        "vesicles_2d": {"x": 1.35, "y": 1.35},
        "vesicles_3d": _isotropic(1.35),
        # Anisotropic: finer sampling along z than in-plane.
        "vesicles_cryo": {"x": 1.35, "y": 1.35, "z": 0.88},
    }
    return training_voxel_sizes[model_type]
Get the average resolution / voxel size of the training data for a given pretrained model.
Arguments:
- model_type: The name of the pretrained model.
Returns:
Mapping of axis (x, y, z) to the voxel size (in nm) of that axis.
def
compute_scale_from_voxel_size(voxel_size: Dict[str, float], model_type: str) -> List[float]:
def compute_scale_from_voxel_size(
    voxel_size: Dict[str, float],
    model_type: str
) -> List[float]:
    """Compute the appropriate scale factor for inference with a given pretrained model.

    Args:
        voxel_size: The voxel size of the data for inference, keyed by axis ("x", "y", optionally "z").
        model_type: The name of the pretrained model.

    Returns:
        The scale factor, as a list in zyx order (yx if the data or the model is 2D).
    """
    training_voxel_size = get_model_training_resolution(model_type)
    # Per-axis ratio of data voxel size to training voxel size, collected in xyz order.
    scale = [
        voxel_size["x"] / training_voxel_size["x"],
        voxel_size["y"] / training_voxel_size["y"],
    ]
    # z is only scaled if both the data and the model are 3D.
    if "z" in voxel_size and "z" in training_voxel_size:
        scale.append(voxel_size["z"] / training_voxel_size["z"])
    # Image arrays are indexed zyx, so reverse to match the documented axis order.
    return scale[::-1]
Compute the appropriate scale factor for inference with a given pretrained model.
Arguments:
- voxel_size: The voxel size of the data for inference.
- model_type: The name of the pretrained model.
Returns:
The scale factor, as a list in zyx order.
def
run_segmentation( image: numpy.ndarray, model: torch.nn.modules.module.Module, model_type: str, tiling: Optional[Dict[str, Dict[str, int]]] = None, scale: Optional[List[float]] = None, verbose: bool = False, **kwargs) -> Union[numpy.ndarray, Dict[str, numpy.ndarray]]:
192def run_segmentation( 193 image: np.ndarray, 194 model: torch.nn.Module, 195 model_type: str, 196 tiling: Optional[Dict[str, Dict[str, int]]] = None, 197 scale: Optional[List[float]] = None, 198 verbose: bool = False, 199 **kwargs, 200) -> np.ndarray | Dict[str, np.ndarray]: 201 """Run synaptic structure segmentation. 202 203 Args: 204 image: The input image or image volume. 205 model: The segmentation model. 206 model_type: The model type. This will determine which segmentation post-processing is used. 207 tiling: The tiling settings for inference. 208 scale: A scale factor for resizing the input before applying the model. 209 The output will be scaled back to the initial size. 210 verbose: Whether to print detailed information about the prediction and segmentation. 211 kwargs: Optional parameters for the segmentation function. 212 213 Returns: 214 The segmentation. For models that return multiple segmentations, this function returns a dictionary. 215 """ 216 if model_type.startswith("vesicles"): 217 segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) 218 elif model_type == "mitochondria": 219 segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) 220 elif model_type == "active_zone": 221 segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) 222 elif model_type == "compartments": 223 segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) 224 elif model_type == "ribbon": 225 segmentation = _segment_ribbon_AZ(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) 226 else: 227 raise ValueError(f"Unknown model type: {model_type}") 228 return segmentation
Run synaptic structure segmentation.
Arguments:
- image: The input image or image volume.
- model: The segmentation model.
- model_type: The model type. This will determine which segmentation post-processing is used.
- tiling: The tiling settings for inference.
- scale: A scale factor for resizing the input before applying the model. The output will be scaled back to the initial size.
- verbose: Whether to print detailed information about the prediction and segmentation.
- kwargs: Optional parameters for the segmentation function.
Returns:
The segmentation. For models that return multiple segmentations, this function returns a dictionary.