synapse_net.inference.inference

  1import os
  2from typing import Dict, List, Optional, Union
  3
  4import torch
  5import numpy as np
  6import pooch
  7
  8from .active_zone import segment_active_zone
  9from .compartments import segment_compartments
 10from .mitochondria import segment_mitochondria
 11from .ribbon_synapse import segment_ribbon_synapse_structures
 12from .vesicles import segment_vesicles
 13from .cristae import segment_cristae
 14from .util import get_device
 15from ..file_utils import get_cache_dir
 16
 17
 18#
 19# Functions to access SynapseNet's pretrained models.
 20#
 21
 22
 23def _get_model_registry():
 24    registry = {
 25        "active_zone": "c23652a8fe06daa113546af6d3200c4c1dcc79917056c6ed7357b8c93548372a",
 26        "compartments": "527983720f9eb215c45c4f4493851fd6551810361eda7b79f185a0d304274ee1",
 27        "mitochondria": "24625018a5968b36f39fa9d73b121a32e8f66d0f2c0540d3df2e1e39b3d58186",
 28        "mitochondria2": "553decafaff4838fff6cc8347f22c8db3dee5bcbeffc34ffaec152f8449af673",
 29        "cristae": "f96c90484f4ea92ac0515a06e389cc117580f02c2aacdc44b5828820cf38c3c3",
 30        "ribbon": "7c947f0ddfabe51a41d9d05c0a6ca7d6b238f43df2af8fffed5552d09bb075a9",
 31        "vesicles_2d": "eb0b74f7000a0e6a25b626078e76a9452019f2d1ea6cf2033073656f4f055df1",
 32        "vesicles_3d": "b329ec1f57f305099c984fbb3d7f6ae4b0ff51ec2fa0fa586df52dad6b84cf29",
 33        "vesicles_cryo": "782f5a21c3cda82c4e4eaeccc754774d5aaed5929f8496eb018aad7daf91661b",
 34        # Additional models that are only available in the CLI, not in the plugin model selection.
 35        "vesicles_2d_maus": "01506895df6343fc33ffc9c9eb3f975bf42eb4eaaaf4848bac83b57f1b46e460",
 36        "vesicles_3d_endbulb": "8582c7e3e5f16ef2bf34d6f9e34644862ca3c76835c9e7d44475c9dd7891d228",
 37        "vesicles_3d_innerear": "924f0f7cfb648a3a6931c1d48d8b1fdc6c0c0d2cb3330fe2cae49d13e7c3b69d",
 38    }
 39    urls = {
 40        "active_zone": "https://owncloud.gwdg.de/index.php/s/wpea9FH9waG4zJd/download",
 41        "compartments": "https://owncloud.gwdg.de/index.php/s/DnFDeTmDDmZrDDX/download",
 42        "mitochondria": "https://owncloud.gwdg.de/index.php/s/1T542uvzfuruahD/download",
 43        "mitochondria2": "https://owncloud.gwdg.de/index.php/s/GZghrXagc54FFXd/download",
 44        "cristae": "https://owncloud.gwdg.de/index.php/s/Df7OUOyQ1Kc2eEO/download",
 45        "ribbon": "https://owncloud.gwdg.de/index.php/s/S3b5l0liPP1XPYA/download",
 46        "vesicles_2d": "https://owncloud.gwdg.de/index.php/s/d72QIvdX6LsgXip/download",
 47        "vesicles_3d": "https://owncloud.gwdg.de/index.php/s/A425mkAOSqePDhx/download",
 48        "vesicles_cryo": "https://owncloud.gwdg.de/index.php/s/e2lVdxjCJuZkLJm/download",
 49        # Additional models that are only available in the CLI, not in the plugin model selection.
 50        "vesicles_2d_maus": "https://owncloud.gwdg.de/index.php/s/sZ8woLr0zs5zOpv/download",
 51        "vesicles_3d_endbulb": "https://owncloud.gwdg.de/index.php/s/16tmnWrEDpYIMzU/download",
 52        "vesicles_3d_innerear": "https://owncloud.gwdg.de/index.php/s/UFUCYivsCxrqISX/download",
 53    }
 54    cache_dir = get_cache_dir()
 55    models = pooch.create(
 56        path=os.path.join(cache_dir, "models"),
 57        base_url="",
 58        registry=registry,
 59        urls=urls,
 60    )
 61    return models
 62
 63
 64def get_model_path(model_type: str) -> str:
 65    """Get the local path to a pretrained model.
 66
 67    Args:
 68        The model type.
 69
 70    Returns:
 71        The local path to the model.
 72    """
 73    model_registry = _get_model_registry()
 74    model_path = model_registry.fetch(model_type)
 75    return model_path
 76
 77
 78def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
 79    """Get the model for a specific segmentation type.
 80
 81    Args:
 82        model_type: The model for one of the following segmentation tasks:
 83            'vesicles_3d', 'active_zone', 'compartments', 'mitochondria', 'ribbon', 'vesicles_2d', 'vesicles_cryo'.
 84        device: The device to use.
 85
 86    Returns:
 87        The model.
 88    """
 89    if device is None:
 90        device = get_device(device)
 91    model_path = get_model_path(model_type)
 92    model = torch.load(model_path, weights_only=False)
 93    model.to(device)
 94    return model
 95
 96
 97#
 98# Functions for training resolution / voxel size.
 99#
100
101
def get_model_training_resolution(model_type: str) -> Dict[str, float]:
    """Get the average resolution / voxel size of the training data for a given pretrained model.

    Args:
        model_type: The name of the pretrained model.

    Returns:
        Mapping of axis (x, y, z) to the voxel size (in nm) of that axis.
        2D models (e.g. 'vesicles_2d') only have entries for x and y.

    Raises:
        KeyError: If no training resolution is known for ``model_type``.
            The error message lists the model types that are available.
    """
    resolutions = {
        "active_zone": {"x": 1.38, "y": 1.38, "z": 1.38},
        "compartments": {"x": 3.47, "y": 3.47, "z": 3.47},
        "mitochondria": {"x": 2.07, "y": 2.07, "z": 2.07},
        "cristae": {"x": 1.44, "y": 1.44, "z": 1.44},
        "ribbon": {"x": 1.188, "y": 1.188, "z": 1.188},
        "vesicles_2d": {"x": 1.35, "y": 1.35},
        "vesicles_3d": {"x": 1.35, "y": 1.35, "z": 1.35},
        "vesicles_cryo": {"x": 1.35, "y": 1.35, "z": 0.88},
        # TODO add the correct resolutions, these are the resolutions of the source models.
        "vesicles_2d_maus": {"x": 1.35, "y": 1.35},
        "vesicles_3d_endbulb": {"x": 1.35, "y": 1.35, "z": 1.35},
        "vesicles_3d_innerear": {"x": 1.35, "y": 1.35, "z": 1.35},
    }
    # NOTE(review): 'mitochondria2' is in the model registry and handled by run_segmentation,
    # but has no entry here yet, so it ends up in the KeyError below.
    try:
        return resolutions[model_type]
    except KeyError:
        available = ", ".join(resolutions)
        # Keep the KeyError type so existing callers that catch it still work.
        raise KeyError(
            f"No training resolution is known for model type '{model_type}'. "
            f"Available model types: {available}."
        ) from None
126
127
def compute_scale_from_voxel_size(
    voxel_size: Dict[str, float],
    model_type: str
) -> List[float]:
    """Compute the appropriate scale factor for inference with a given pretrained model.

    The factor for each axis is the voxel size of the data divided by the voxel size
    of the model's training data on that axis.

    Args:
        voxel_size: The voxel size of the data for inference, as a mapping of axis name
            ('x', 'y' and optionally 'z') to the voxel size of that axis.
        model_type: The name of the pretrained model.

    Returns:
        The scale factor, as a list in zyx order (the axis order of the image array).
        If the data or the model has no 'z' entry, only [y, x] is returned.
    """
    training_voxel_size = get_model_training_resolution(model_type)
    # Previously the factors were returned in xyz order despite the documented zyx order,
    # which puts the z factor on the x axis for anisotropic data.
    axes = ["y", "x"]
    if "z" in voxel_size and "z" in training_voxel_size:
        axes.insert(0, "z")
    return [voxel_size[axis] / training_voxel_size[axis] for axis in axes]
151
152
153#
154# Convenience functions for segmentation.
155#
156
157
def _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons, resolution, min_membrane_size):
    """Post-process ribbon synapse predictions into ribbon, PD and membrane segmentations.

    Args:
        predictions: The predictions; entries 'ribbon', 'PD' and 'membrane' are read.
        vesicles: The vesicle segmentation used to segment the ribbon.
        n_slices_exclude: Passed to all three post-processing steps.
        n_ribbons: Passed to the ribbon segmentation.
        resolution: Passed to the membrane segmentation.
        min_membrane_size: Passed to the membrane segmentation as ``min_size``.

    Returns:
        Dictionary with the segmentations under the keys 'ribbon', 'PD' and 'membrane'.
    """
    from synapse_net.inference.postprocessing import (
        segment_membrane_distance_based, segment_presynaptic_density, segment_ribbon,
    )

    # NOTE(review): units of the hard-coded distances (40, 40, 500) are not visible here — confirm.
    ribbon_seg = segment_ribbon(
        predictions["ribbon"],
        vesicles,
        n_slices_exclude=n_slices_exclude,
        n_ribbons=n_ribbons,
        max_vesicle_distance=40,
    )
    pd_seg = segment_presynaptic_density(
        predictions["PD"],
        ribbon_seg,
        n_slices_exclude=n_slices_exclude,
        max_distance_to_ribbon=40,
    )

    # The membrane is searched relative to the PD; fall back to the ribbon if no PD was found.
    if pd_seg.sum() > 0:
        anchor = pd_seg
    else:
        anchor = ribbon_seg
    membrane_seg = segment_membrane_distance_based(
        predictions["membrane"],
        anchor,
        max_distance=500,
        n_slices_exclude=n_slices_exclude,
        resolution=resolution,
        min_size=min_membrane_size,
    )

    return {"ribbon": ribbon_seg, "PD": pd_seg, "membrane": membrane_seg}
178
179
def _segment_ribbon_AZ(image, model, tiling, scale, verbose, return_predictions=False, **kwargs):
    """Predict ribbon synapse structures and post-process them if a vesicle segmentation is given.

    Args:
        image: The input image or image volume.
        model: The segmentation model.
        tiling: The tiling settings for inference.
        scale: A scale factor for resizing the input before applying the model.
        verbose: Whether to print detailed information.
        return_predictions: Whether to also return the raw predictions.
        kwargs: Optional parameters. These keys are consumed here:
            'extra_segmentation' (vesicle segmentation; default None, which skips post-processing),
            'threshold' (default 0.5, passed to the prediction),
            'n_slices_exclude' (default 20), 'n_ribbons' (default 1), 'resolution' (default None)
            and 'min_membrane_size' (default 0) for the post-processing.
            All remaining keys are forwarded to ``segment_ribbon_synapse_structures``.

    Returns:
        The post-processed segmentations if a vesicle segmentation was passed, otherwise the
        raw predictions. If ``return_predictions`` is True, a tuple (segmentations, predictions).
    """
    # Parse additional keyword arguments from the kwargs.
    # A missing vesicle segmentation is treated like an explicit None (no post-processing).
    vesicles = kwargs.pop("extra_segmentation", None)
    threshold = kwargs.pop("threshold", 0.5)
    n_slices_exclude = kwargs.pop("n_slices_exclude", 20)
    # Previously this popped "n_slices_exclude" (already consumed above), so a user-given
    # n_ribbons was ignored and leaked into the kwargs forwarded to the prediction below.
    n_ribbons = kwargs.pop("n_ribbons", 1)
    resolution = kwargs.pop("resolution", None)
    min_membrane_size = kwargs.pop("min_membrane_size", 0)

    predictions = segment_ribbon_synapse_structures(
        image, model=model, tiling=tiling, scale=scale, verbose=verbose, threshold=threshold, **kwargs
    )

    if vesicles is None:
        # Post-processing needs the vesicles, so just return the predictions.
        if verbose:
            print("Vesicle segmentation was not passed, WILL NOT run post-processing.")
        segmentations = predictions
    else:
        # If the vesicles were passed then run additional post-processing.
        if verbose:
            print("Vesicle segmentation was passed, WILL run post-processing.")
        segmentations = _ribbon_AZ_postprocessing(
            predictions, vesicles, n_slices_exclude, n_ribbons, resolution, min_membrane_size
        )

    if return_predictions:
        return segmentations, predictions
    return segmentations
210
211
def run_segmentation(
    image: np.ndarray,
    model: torch.nn.Module,
    model_type: str,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    scale: Optional[List[float]] = None,
    verbose: bool = False,
    **kwargs,
) -> np.ndarray | Dict[str, np.ndarray]:
    """Run synaptic structure segmentation with the function matching the model type.

    Args:
        image: The input image or image volume.
        model: The segmentation model.
        model_type: The model type; selects the segmentation (and post-processing) function.
            Every type starting with 'vesicles' uses the vesicle segmentation; the other
            supported types are 'mitochondria', 'mitochondria2', 'active_zone', 'compartments',
            'ribbon' and 'cristae'.
        tiling: The tiling settings for inference.
        scale: A scale factor for resizing the input before applying the model.
            The output will be scaled back to the initial size.
        verbose: Whether to print detailed information about the prediction and segmentation.
        kwargs: Optional parameters, forwarded to the selected segmentation function.

    Returns:
        The segmentation. For models that return multiple segmentations, a dictionary.

    Raises:
        ValueError: If ``model_type`` is not supported.
    """
    if model_type.startswith("vesicles"):
        segment_function = segment_vesicles
    else:
        dispatch = {
            "mitochondria": segment_mitochondria,
            "mitochondria2": segment_mitochondria,
            "active_zone": segment_active_zone,
            "compartments": segment_compartments,
            "ribbon": _segment_ribbon_AZ,
            "cristae": segment_cristae,
        }
        segment_function = dispatch.get(model_type)
        if segment_function is None:
            raise ValueError(f"Unknown model type: {model_type}")

    return segment_function(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
def get_model_path(model_type: str) -> str:
    """Get the local path to a pretrained model, downloading it on first use.

    The file is fetched through the pooch registry built by ``_get_model_registry``.
    Pooch downloads it into the local cache if it is not present yet and checks it
    against the registered hash.

    Args:
        model_type: The name of the pretrained model, e.g. 'vesicles_3d'.
            Must be one of the names registered in the model registry.

    Returns:
        The local path to the model.

    Raises:
        ValueError: Raised by pooch if ``model_type`` is not in the registry.
    """
    model_registry = _get_model_registry()
    return model_registry.fetch(model_type)

Get the local path to a pretrained model.

Arguments:
  • model_type: The name of the pretrained model to fetch, e.g. 'vesicles_3d'.
Returns:

The local path to the model.

def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
    """Get the pretrained model for a specific segmentation type.

    Args:
        model_type: The name of the pretrained model. One of 'active_zone', 'compartments',
            'mitochondria', 'mitochondria2', 'cristae', 'ribbon', 'vesicles_2d', 'vesicles_3d',
            'vesicles_cryo', or one of the CLI-only models 'vesicles_2d_maus',
            'vesicles_3d_endbulb', 'vesicles_3d_innerear'.
        device: The device to move the model to. If None, the device returned by
            ``get_device`` is used.

    Returns:
        The model, moved to ``device``.
    """
    if device is None:
        device = get_device(device)
    model_path = get_model_path(model_type)
    # The checkpoint holds a full pickled module (hence weights_only=False), so it must
    # only ever be loaded from the hash-checked registry files.
    # Loading onto the CPU first avoids a failure when the checkpoint was saved on a GPU
    # that is not available on this machine; the model is moved to `device` right after.
    model = torch.load(model_path, weights_only=False, map_location="cpu")
    model.to(device)
    return model

Get the model for a specific segmentation type.

Arguments:
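  • model_type: The name of the pretrained model. One of 'active_zone', 'compartments', 'mitochondria', 'mitochondria2', 'cristae', 'ribbon', 'vesicles_2d', 'vesicles_3d', 'vesicles_cryo', or one of the CLI-only models 'vesicles_2d_maus', 'vesicles_3d_endbulb', 'vesicles_3d_innerear'.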
  • device: The device to use.
Returns:

The model.

def get_model_training_resolution(model_type: str) -> Dict[str, float]:
    """Get the average resolution / voxel size of the training data for a given pretrained model.

    Args:
        model_type: The name of the pretrained model.

    Returns:
        Mapping of axis (x, y, z) to the voxel size (in nm) of that axis.
        2D models (e.g. 'vesicles_2d') only have entries for x and y.

    Raises:
        KeyError: If no training resolution is known for ``model_type``.
            The error message lists the model types that are available.
    """
    resolutions = {
        "active_zone": {"x": 1.38, "y": 1.38, "z": 1.38},
        "compartments": {"x": 3.47, "y": 3.47, "z": 3.47},
        "mitochondria": {"x": 2.07, "y": 2.07, "z": 2.07},
        "cristae": {"x": 1.44, "y": 1.44, "z": 1.44},
        "ribbon": {"x": 1.188, "y": 1.188, "z": 1.188},
        "vesicles_2d": {"x": 1.35, "y": 1.35},
        "vesicles_3d": {"x": 1.35, "y": 1.35, "z": 1.35},
        "vesicles_cryo": {"x": 1.35, "y": 1.35, "z": 0.88},
        # TODO add the correct resolutions, these are the resolutions of the source models.
        "vesicles_2d_maus": {"x": 1.35, "y": 1.35},
        "vesicles_3d_endbulb": {"x": 1.35, "y": 1.35, "z": 1.35},
        "vesicles_3d_innerear": {"x": 1.35, "y": 1.35, "z": 1.35},
    }
    # NOTE(review): 'mitochondria2' is in the model registry and handled by run_segmentation,
    # but has no entry here yet, so it ends up in the KeyError below.
    try:
        return resolutions[model_type]
    except KeyError:
        available = ", ".join(resolutions)
        # Keep the KeyError type so existing callers that catch it still work.
        raise KeyError(
            f"No training resolution is known for model type '{model_type}'. "
            f"Available model types: {available}."
        ) from None

Get the average resolution / voxel size of the training data for a given pretrained model.

Arguments:
  • model_type: The name of the pretrained model.
Returns:

Mapping of axis (x, y, z) to the voxel size (in nm) of that axis.

def compute_scale_from_voxel_size(
    voxel_size: Dict[str, float],
    model_type: str
) -> List[float]:
    """Compute the appropriate scale factor for inference with a given pretrained model.

    The factor for each axis is the voxel size of the data divided by the voxel size
    of the model's training data on that axis.

    Args:
        voxel_size: The voxel size of the data for inference, as a mapping of axis name
            ('x', 'y' and optionally 'z') to the voxel size of that axis.
        model_type: The name of the pretrained model.

    Returns:
        The scale factor, as a list in zyx order (the axis order of the image array).
        If the data or the model has no 'z' entry, only [y, x] is returned.
    """
    training_voxel_size = get_model_training_resolution(model_type)
    # Previously the factors were returned in xyz order despite the documented zyx order,
    # which puts the z factor on the x axis for anisotropic data.
    axes = ["y", "x"]
    if "z" in voxel_size and "z" in training_voxel_size:
        axes.insert(0, "z")
    return [voxel_size[axis] / training_voxel_size[axis] for axis in axes]

Compute the appropriate scale factor for inference with a given pretrained model.

Arguments:
  • voxel_size: The voxel size of the data for inference.
  • model_type: The name of the pretrained model.
Returns:

The scale factor, as a list in zyx order.

def run_segmentation(
    image: np.ndarray,
    model: torch.nn.Module,
    model_type: str,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    scale: Optional[List[float]] = None,
    verbose: bool = False,
    **kwargs,
) -> np.ndarray | Dict[str, np.ndarray]:
    """Run synaptic structure segmentation with the function matching the model type.

    Args:
        image: The input image or image volume.
        model: The segmentation model.
        model_type: The model type; selects the segmentation (and post-processing) function.
            Every type starting with 'vesicles' uses the vesicle segmentation; the other
            supported types are 'mitochondria', 'mitochondria2', 'active_zone', 'compartments',
            'ribbon' and 'cristae'.
        tiling: The tiling settings for inference.
        scale: A scale factor for resizing the input before applying the model.
            The output will be scaled back to the initial size.
        verbose: Whether to print detailed information about the prediction and segmentation.
        kwargs: Optional parameters, forwarded to the selected segmentation function.

    Returns:
        The segmentation. For models that return multiple segmentations, a dictionary.

    Raises:
        ValueError: If ``model_type`` is not supported.
    """
    if model_type.startswith("vesicles"):
        segment_function = segment_vesicles
    else:
        dispatch = {
            "mitochondria": segment_mitochondria,
            "mitochondria2": segment_mitochondria,
            "active_zone": segment_active_zone,
            "compartments": segment_compartments,
            "ribbon": _segment_ribbon_AZ,
            "cristae": segment_cristae,
        }
        segment_function = dispatch.get(model_type)
        if segment_function is None:
            raise ValueError(f"Unknown model type: {model_type}")

    return segment_function(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)

Run synaptic structure segmentation.

Arguments:
  • image: The input image or image volume.
  • model: The segmentation model.
  • model_type: The model type. This will determine which segmentation post-processing is used.
  • tiling: The tiling settings for inference.
  • scale: A scale factor for resizing the input before applying the model. The output will be scaled back to the initial size.
  • verbose: Whether to print detailed information about the prediction and segmentation.
  • kwargs: Optional parameters for the segmentation function.
Returns:

The segmentation. For models that return multiple segmentations, this function returns a dictionary.