synapse_net.inference.inference

  1import os
  2from typing import Dict, List, Optional, Union
  3
  4import torch
  5import numpy as np
  6import pooch
  7
  8from .active_zone import segment_active_zone
  9from .compartments import segment_compartments
 10from .mitochondria import segment_mitochondria
 11from .ribbon_synapse import segment_ribbon_synapse_structures
 12from .vesicles import segment_vesicles
 13from .util import get_device
 14from ..file_utils import get_cache_dir
 15
 16
 17#
 18# Functions to access SynapseNet's pretrained models.
 19#
 20
 21
 22def _get_model_registry():
 23    registry = {
 24        "active_zone": "a18f29168aed72edec0f5c2cb1aa9a4baa227812db6082a6538fd38d9f43afb0",
 25        "compartments": "527983720f9eb215c45c4f4493851fd6551810361eda7b79f185a0d304274ee1",
 26        "mitochondria": "24625018a5968b36f39fa9d73b121a32e8f66d0f2c0540d3df2e1e39b3d58186",
 27        "mitochondria2": "553decafaff4838fff6cc8347f22c8db3dee5bcbeffc34ffaec152f8449af673",
 28        "ribbon": "7c947f0ddfabe51a41d9d05c0a6ca7d6b238f43df2af8fffed5552d09bb075a9",
 29        "vesicles_2d": "eb0b74f7000a0e6a25b626078e76a9452019f2d1ea6cf2033073656f4f055df1",
 30        "vesicles_3d": "b329ec1f57f305099c984fbb3d7f6ae4b0ff51ec2fa0fa586df52dad6b84cf29",
 31        "vesicles_cryo": "782f5a21c3cda82c4e4eaeccc754774d5aaed5929f8496eb018aad7daf91661b",
 32    }
 33    urls = {
 34        "active_zone": "https://owncloud.gwdg.de/index.php/s/zvuY342CyQebPsX/download",
 35        "compartments": "https://owncloud.gwdg.de/index.php/s/DnFDeTmDDmZrDDX/download",
 36        "mitochondria": "https://owncloud.gwdg.de/index.php/s/1T542uvzfuruahD/download",
 37        "mitochondria2": "https://owncloud.gwdg.de/index.php/s/GZghrXagc54FFXd/download",
 38        "ribbon": "https://owncloud.gwdg.de/index.php/s/S3b5l0liPP1XPYA/download",
 39        "vesicles_2d": "https://owncloud.gwdg.de/index.php/s/d72QIvdX6LsgXip/download",
 40        "vesicles_3d": "https://owncloud.gwdg.de/index.php/s/A425mkAOSqePDhx/download",
 41        "vesicles_cryo": "https://owncloud.gwdg.de/index.php/s/e2lVdxjCJuZkLJm/download",
 42    }
 43    cache_dir = get_cache_dir()
 44    models = pooch.create(
 45        path=os.path.join(cache_dir, "models"),
 46        base_url="",
 47        registry=registry,
 48        urls=urls,
 49    )
 50    return models
 51
 52
 53def get_model_path(model_type: str) -> str:
 54    """Get the local path to a pretrained model.
 55
 56    Args:
 57        The model type.
 58
 59    Returns:
 60        The local path to the model.
 61    """
 62    model_registry = _get_model_registry()
 63    model_path = model_registry.fetch(model_type)
 64    return model_path
 65
 66
 67def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
 68    """Get the model for a specific segmentation type.
 69
 70    Args:
 71        model_type: The model for one of the following segmentation tasks:
 72            'vesicles_3d', 'active_zone', 'compartments', 'mitochondria', 'ribbon', 'vesicles_2d', 'vesicles_cryo'.
 73        device: The device to use.
 74
 75    Returns:
 76        The model.
 77    """
 78    if device is None:
 79        device = get_device(device)
 80    model_path = get_model_path(model_type)
 81    model = torch.load(model_path, weights_only=False)
 82    model.to(device)
 83    return model
 84
 85
 86#
 87# Functions for training resolution / voxel size.
 88#
 89
 90
 91def get_model_training_resolution(model_type: str) -> Dict[str, float]:
 92    """Get the average resolution / voxel size of the training data for a given pretrained model.
 93
 94    Args:
 95        model_type: The name of the pretrained model.
 96
 97    Returns:
 98        Mapping of axis (x, y, z) to the voxel size (in nm) of that axis.
 99    """
100    resolutions = {
101        "active_zone": {"x": 1.44, "y": 1.44, "z": 1.44},
102        "compartments": {"x": 3.47, "y": 3.47, "z": 3.47},
103        "mitochondria": {"x": 2.07, "y": 2.07, "z": 2.07},
104        "ribbon": {"x": 1.188, "y": 1.188, "z": 1.188},
105        "vesicles_2d": {"x": 1.35, "y": 1.35},
106        "vesicles_3d": {"x": 1.35, "y": 1.35, "z": 1.35},
107        "vesicles_cryo": {"x": 1.35, "y": 1.35, "z": 0.88},
108    }
109    return resolutions[model_type]
110
111
def compute_scale_from_voxel_size(
    voxel_size: Dict[str, float],
    model_type: str
) -> List[float]:
    """Compute the appropriate scale factor for inference with a given pretrained model.

    The scale factor of each axis is the ratio of the data's voxel size to the voxel
    size of the model's training data.

    Args:
        voxel_size: The voxel size of the data for inference, as a mapping of axis name
            ('x', 'y' and optionally 'z') to voxel size in nm (the unit of the training resolution).
        model_type: The name of the pretrained model.

    Returns:
        The scale factor, as a list in zyx order. The z factor is only included if both
        `voxel_size` and the model's training resolution have a 'z' axis; otherwise the
        list is [y, x].
    """
    training_voxel_size = get_model_training_resolution(model_type)
    axes = ["y", "x"]
    # Only scale along z if both the data and the model are 3D.
    if "z" in voxel_size and "z" in training_voxel_size:
        axes.insert(0, "z")
    # Emit the factors in zyx order, matching the axis order documented above, so that
    # the factor at index i belongs to array axis i.
    return [voxel_size[axis] / training_voxel_size[axis] for axis in axes]
135
136
137#
138# Convenience functions for segmentation.
139#
140
141
def _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons):
    """Post-process the ribbon synapse predictions into ribbon, PD and membrane segmentations.

    Args:
        predictions: The network predictions; the entries "ribbon", "PD" and "membrane" are read.
        vesicles: The vesicle segmentation, passed to `segment_ribbon`.
        n_slices_exclude: Passed as `n_slices_exclude` to each post-processing function.
        n_ribbons: Passed as `n_ribbons` to `segment_ribbon`.

    Returns:
        Dictionary with the "ribbon", "PD" and "membrane" segmentations.
    """
    from synapse_net.inference.postprocessing import (
        segment_membrane_distance_based,
        segment_presynaptic_density,
        segment_ribbon,
    )

    ribbon_seg = segment_ribbon(
        predictions["ribbon"],
        vesicles,
        n_slices_exclude=n_slices_exclude,
        n_ribbons=n_ribbons,
        max_vesicle_distance=40,
    )
    pd_seg = segment_presynaptic_density(
        predictions["PD"],
        ribbon_seg,
        n_slices_exclude=n_slices_exclude,
        max_distance_to_ribbon=40,
    )
    # The membrane is segmented relative to the PD; if the PD segmentation is empty
    # (sums to zero) the ribbon is used as the reference instead.
    reference_seg = pd_seg if pd_seg.sum() > 0 else ribbon_seg
    membrane_seg = segment_membrane_distance_based(
        predictions["membrane"],
        reference_seg,
        max_distance=500,
        n_slices_exclude=n_slices_exclude,
    )
    return {"ribbon": ribbon_seg, "PD": pd_seg, "membrane": membrane_seg}
161
162
def _segment_ribbon_AZ(image, model, tiling, scale, verbose, return_predictions=False, **kwargs):
    """Segment ribbon synapse structures and, if vesicles are given, post-process them.

    Args:
        image: The input image or image volume.
        model: The segmentation model.
        tiling: The tiling settings for inference.
        scale: A scale factor for resizing the input before applying the model.
        verbose: Whether to print detailed information.
        return_predictions: Whether to also return the raw predictions.
        kwargs: Additional parameters. These keys are consumed here:
            extra_segmentation: The vesicle segmentation (default None). Post-processing
                is only run if it is not None.
            threshold: Passed to `segment_ribbon_synapse_structures` (default 0.5).
            n_slices_exclude: Passed to the post-processing (default 20).
            n_ribbons: Passed to the post-processing (default 1).
            All remaining kwargs are passed to `segment_ribbon_synapse_structures`.

    Returns:
        The segmentations: the predictions themselves if no vesicle segmentation was given,
        otherwise the post-processed "ribbon", "PD" and "membrane" segmentations.
        If `return_predictions` is True, a tuple (segmentations, predictions) is returned.
    """
    # Parse additional keyword arguments from the kwargs. Each key is popped so that
    # only the remaining ones are forwarded to the prediction function.
    vesicles = kwargs.pop("extra_segmentation", None)
    threshold = kwargs.pop("threshold", 0.5)
    n_slices_exclude = kwargs.pop("n_slices_exclude", 20)
    n_ribbons = kwargs.pop("n_ribbons", 1)

    predictions = segment_ribbon_synapse_structures(
        image, model=model, tiling=tiling, scale=scale, verbose=verbose, threshold=threshold, **kwargs
    )

    if vesicles is None:
        # Without a vesicle segmentation no post-processing is run; just return the predictions.
        if verbose:
            print("Vesicle segmentation was not passed, WILL NOT run post-processing.")
        segmentations = predictions
    else:
        # If the vesicles were passed then run additional post-processing.
        if verbose:
            print("Vesicle segmentation was passed, WILL run post-processing.")
        segmentations = _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons)

    if return_predictions:
        return segmentations, predictions
    return segmentations
189
190
def run_segmentation(
    image: np.ndarray,
    model: torch.nn.Module,
    model_type: str,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    scale: Optional[List[float]] = None,
    verbose: bool = False,
    **kwargs,
) -> np.ndarray | Dict[str, np.ndarray]:
    """Run synaptic structure segmentation.

    Args:
        image: The input image or image volume.
        model: The segmentation model.
        model_type: The model type. This will determine which segmentation post-processing is used.
        tiling: The tiling settings for inference.
        scale: A scale factor for resizing the input before applying the model.
            The output will be scaled back to the initial size.
        verbose: Whether to print detailed information about the prediction and segmentation.
        kwargs: Optional parameters for the segmentation function.

    Returns:
        The segmentation. For models that return multiple segmentations, this function returns a dictionary.

    Raises:
        ValueError: If the model type is not supported.
    """
    # Every model type starting with "vesicles" uses the same segmentation function.
    if model_type.startswith("vesicles"):
        segment_function = segment_vesicles
    else:
        # NOTE(review): 'mitochondria2' is in the model registry but has no entry here.
        segment_functions = {
            "mitochondria": segment_mitochondria,
            "active_zone": segment_active_zone,
            "compartments": segment_compartments,
            "ribbon": _segment_ribbon_AZ,
        }
        segment_function = segment_functions.get(model_type)
        if segment_function is None:
            raise ValueError(f"Unknown model type: {model_type}")
    return segment_function(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
def get_model_path(model_type: str) -> str:
def get_model_path(model_type: str) -> str:
    """Get the local path to a pretrained model.

    The file is fetched through the pooch registry built by `_get_model_registry`,
    which downloads it into the local cache if it is not there yet.

    Args:
        model_type: The model type, i.e. the name of the model in the registry.

    Returns:
        The local path to the model.
    """
    return _get_model_registry().fetch(model_type)

Get the local path to a pretrained model.

Arguments:
  • model_type: The model type.
Returns:

The local path to the model.

def get_model( model_type: str, device: Union[str, torch.device, NoneType] = None) -> torch.nn.modules.module.Module:
def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
    """Get the model for a specific segmentation type.

    Args:
        model_type: The model for one of the following segmentation tasks:
            'vesicles_3d', 'active_zone', 'compartments', 'mitochondria', 'mitochondria2',
            'ribbon', 'vesicles_2d', 'vesicles_cryo'.
        device: The device to use. If None, the device returned by `get_device` is used.

    Returns:
        The model, moved to `device`.
    """
    if device is None:
        device = get_device(device)
    model_path = get_model_path(model_type)
    # The checkpoint holds a full pickled model object (it is used directly as an
    # nn.Module below), which requires weights_only=False. Unpickling executes code,
    # so this must only be used with trusted files; the registry pins each file's hash.
    # Deserialize onto the CPU first so that loading works independently of the device
    # the checkpoint was saved from (e.g. a GPU checkpoint on a CPU-only machine).
    model = torch.load(model_path, map_location="cpu", weights_only=False)
    model.to(device)
    return model

Get the model for a specific segmentation type.

Arguments:
  • model_type: The model for one of the following segmentation tasks: 'vesicles_3d', 'active_zone', 'compartments', 'mitochondria', 'ribbon', 'vesicles_2d', 'vesicles_cryo'.
  • device: The device to use.
Returns:

The model.

def get_model_training_resolution(model_type: str) -> Dict[str, float]:
 92def get_model_training_resolution(model_type: str) -> Dict[str, float]:
 93    """Get the average resolution / voxel size of the training data for a given pretrained model.
 94
 95    Args:
 96        model_type: The name of the pretrained model.
 97
 98    Returns:
 99        Mapping of axis (x, y, z) to the voxel size (in nm) of that axis.
100    """
101    resolutions = {
102        "active_zone": {"x": 1.44, "y": 1.44, "z": 1.44},
103        "compartments": {"x": 3.47, "y": 3.47, "z": 3.47},
104        "mitochondria": {"x": 2.07, "y": 2.07, "z": 2.07},
105        "ribbon": {"x": 1.188, "y": 1.188, "z": 1.188},
106        "vesicles_2d": {"x": 1.35, "y": 1.35},
107        "vesicles_3d": {"x": 1.35, "y": 1.35, "z": 1.35},
108        "vesicles_cryo": {"x": 1.35, "y": 1.35, "z": 0.88},
109    }
110    return resolutions[model_type]

Get the average resolution / voxel size of the training data for a given pretrained model.

Arguments:
  • model_type: The name of the pretrained model.
Returns:

Mapping of axis (x, y, z) to the voxel size (in nm) of that axis.

def compute_scale_from_voxel_size(voxel_size: Dict[str, float], model_type: str) -> List[float]:
def compute_scale_from_voxel_size(
    voxel_size: Dict[str, float],
    model_type: str
) -> List[float]:
    """Compute the appropriate scale factor for inference with a given pretrained model.

    The scale factor of each axis is the ratio of the data's voxel size to the voxel
    size of the model's training data.

    Args:
        voxel_size: The voxel size of the data for inference, as a mapping of axis name
            ('x', 'y' and optionally 'z') to voxel size in nm (the unit of the training resolution).
        model_type: The name of the pretrained model.

    Returns:
        The scale factor, as a list in zyx order. The z factor is only included if both
        `voxel_size` and the model's training resolution have a 'z' axis; otherwise the
        list is [y, x].
    """
    training_voxel_size = get_model_training_resolution(model_type)
    axes = ["y", "x"]
    # Only scale along z if both the data and the model are 3D.
    if "z" in voxel_size and "z" in training_voxel_size:
        axes.insert(0, "z")
    # Emit the factors in zyx order, matching the axis order documented above, so that
    # the factor at index i belongs to array axis i.
    return [voxel_size[axis] / training_voxel_size[axis] for axis in axes]

Compute the appropriate scale factor for inference with a given pretrained model.

Arguments:
  • voxel_size: The voxel size of the data for inference.
  • model_type: The name of the pretrained model.
Returns:

The scale factor, as a list in zyx order.

def run_segmentation( image: numpy.ndarray, model: torch.nn.modules.module.Module, model_type: str, tiling: Optional[Dict[str, Dict[str, int]]] = None, scale: Optional[List[float]] = None, verbose: bool = False, **kwargs) -> Union[numpy.ndarray, Dict[str, numpy.ndarray]]:
def run_segmentation(
    image: np.ndarray,
    model: torch.nn.Module,
    model_type: str,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    scale: Optional[List[float]] = None,
    verbose: bool = False,
    **kwargs,
) -> np.ndarray | Dict[str, np.ndarray]:
    """Run synaptic structure segmentation.

    Args:
        image: The input image or image volume.
        model: The segmentation model.
        model_type: The model type. This will determine which segmentation post-processing is used.
        tiling: The tiling settings for inference.
        scale: A scale factor for resizing the input before applying the model.
            The output will be scaled back to the initial size.
        verbose: Whether to print detailed information about the prediction and segmentation.
        kwargs: Optional parameters for the segmentation function.

    Returns:
        The segmentation. For models that return multiple segmentations, this function returns a dictionary.

    Raises:
        ValueError: If the model type is not supported.
    """
    # Every model type starting with "vesicles" uses the same segmentation function.
    if model_type.startswith("vesicles"):
        segment_function = segment_vesicles
    else:
        # NOTE(review): 'mitochondria2' is in the model registry but has no entry here.
        segment_functions = {
            "mitochondria": segment_mitochondria,
            "active_zone": segment_active_zone,
            "compartments": segment_compartments,
            "ribbon": _segment_ribbon_AZ,
        }
        segment_function = segment_functions.get(model_type)
        if segment_function is None:
            raise ValueError(f"Unknown model type: {model_type}")
    return segment_function(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)

Run synaptic structure segmentation.

Arguments:
  • image: The input image or image volume.
  • model: The segmentation model.
  • model_type: The model type. This will determine which segmentation post-processing is used.
  • tiling: The tiling settings for inference.
  • scale: A scale factor for resizing the input before applying the model. The output will be scaled back to the initial size.
  • verbose: Whether to print detailed information about the prediction and segmentation.
  • kwargs: Optional parameters for the segmentation function.
Returns:

The segmentation. For models that return multiple segmentations, this function returns a dictionary.