synapse_net.inference.inference

  1import os
  2from typing import Dict, List, Optional, Union
  3
  4import torch
  5import numpy as np
  6import pooch
  7
  8from .active_zone import segment_active_zone
  9from .compartments import segment_compartments
 10from .mitochondria import segment_mitochondria
 11from .ribbon_synapse import segment_ribbon_synapse_structures
 12from .vesicles import segment_vesicles
 13from .cristae import segment_cristae
 14from .util import get_device
 15from ..file_utils import get_cache_dir
 16
 17
 18#
 19# Functions to access SynapseNet's pretrained models.
 20#
 21
 22
 23def _get_model_registry():
 24    registry = {
 25        "active_zone": "c23652a8fe06daa113546af6d3200c4c1dcc79917056c6ed7357b8c93548372a",
 26        "compartments": "527983720f9eb215c45c4f4493851fd6551810361eda7b79f185a0d304274ee1",
 27        "mitochondria": "24625018a5968b36f39fa9d73b121a32e8f66d0f2c0540d3df2e1e39b3d58186",
 28        "mitochondria2": "553decafaff4838fff6cc8347f22c8db3dee5bcbeffc34ffaec152f8449af673",
 29        "mitochondria3": "aa690c49017f84fd333436ddcb995ffeaf9b8174e9505586a4caf735a5805e7c",
 30        "cristae": "f96c90484f4ea92ac0515a06e389cc117580f02c2aacdc44b5828820cf38c3c3",
 31        "ribbon": "7c947f0ddfabe51a41d9d05c0a6ca7d6b238f43df2af8fffed5552d09bb075a9",
 32        "vesicles_2d": "eb0b74f7000a0e6a25b626078e76a9452019f2d1ea6cf2033073656f4f055df1",
 33        "vesicles_3d": "b329ec1f57f305099c984fbb3d7f6ae4b0ff51ec2fa0fa586df52dad6b84cf29",
 34        "vesicles_cryo": "782f5a21c3cda82c4e4eaeccc754774d5aaed5929f8496eb018aad7daf91661b",
 35        # Additional models that are only available in the CLI, not in the plugin model selection.
 36        "vesicles_2d_maus": "01506895df6343fc33ffc9c9eb3f975bf42eb4eaaaf4848bac83b57f1b46e460",
 37        "vesicles_3d_endbulb": "8582c7e3e5f16ef2bf34d6f9e34644862ca3c76835c9e7d44475c9dd7891d228",
 38        "vesicles_3d_innerear": "924f0f7cfb648a3a6931c1d48d8b1fdc6c0c0d2cb3330fe2cae49d13e7c3b69d",
 39    }
 40    urls = {
 41        "active_zone": "https://owncloud.gwdg.de/index.php/s/wpea9FH9waG4zJd/download",
 42        "compartments": "https://owncloud.gwdg.de/index.php/s/DnFDeTmDDmZrDDX/download",
 43        "mitochondria": "https://owncloud.gwdg.de/index.php/s/1T542uvzfuruahD/download",
 44        "mitochondria2": "https://owncloud.gwdg.de/index.php/s/GZghrXagc54FFXd/download",
 45        "mitochondria3": "https://owncloud.gwdg.de/index.php/s/FwLQRCV2HnXDxsT/download",
 46        "cristae": "https://owncloud.gwdg.de/index.php/s/Df7OUOyQ1Kc2eEO/download",
 47        "ribbon": "https://owncloud.gwdg.de/index.php/s/S3b5l0liPP1XPYA/download",
 48        "vesicles_2d": "https://owncloud.gwdg.de/index.php/s/d72QIvdX6LsgXip/download",
 49        "vesicles_3d": "https://owncloud.gwdg.de/index.php/s/A425mkAOSqePDhx/download",
 50        "vesicles_cryo": "https://owncloud.gwdg.de/index.php/s/e2lVdxjCJuZkLJm/download",
 51        # Additional models that are only available in the CLI, not in the plugin model selection.
 52        "vesicles_2d_maus": "https://owncloud.gwdg.de/index.php/s/sZ8woLr0zs5zOpv/download",
 53        "vesicles_3d_endbulb": "https://owncloud.gwdg.de/index.php/s/16tmnWrEDpYIMzU/download",
 54        "vesicles_3d_innerear": "https://owncloud.gwdg.de/index.php/s/UFUCYivsCxrqISX/download",
 55    }
 56    cache_dir = get_cache_dir()
 57    models = pooch.create(
 58        path=os.path.join(cache_dir, "models"),
 59        base_url="",
 60        registry=registry,
 61        urls=urls,
 62    )
 63    return models
 64
 65
 66def get_available_models() -> List[str]:
 67    """Get the names of all available pretrained models.
 68
 69    Returns:
 70        The list of available model names.
 71    """
 72    model_registry = _get_model_registry()
 73    return list(model_registry.urls.keys())
 74
 75
 76def get_model_path(model_type: str) -> str:
 77    """Get the local path to a pretrained model.
 78
 79    Args:
 80        The model type.
 81
 82    Returns:
 83        The local path to the model.
 84    """
 85    model_registry = _get_model_registry()
 86    model_path = model_registry.fetch(model_type)
 87    return model_path
 88
 89
 90def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
 91    """Get the model for a specific segmentation type.
 92
 93    Args:
 94        model_type: The model for one of the following segmentation tasks:
 95            'vesicles_3d', 'active_zone', 'compartments', 'mitochondria', 'ribbon', 'vesicles_2d', 'vesicles_cryo'.
 96        device: The device to use.
 97
 98    Returns:
 99        The model.
100    """
101    if device is None:
102        device = get_device(device)
103    model_path = get_model_path(model_type)
104    model = torch.load(model_path, weights_only=False)
105    model.to(device)
106    return model
107
108
109#
110# Functions for training resolution / voxel size.
111#
112
113
def get_model_training_resolution(model_type: str) -> Dict[str, float]:
    """Look up the average voxel size of the data a pretrained model was trained on.

    Args:
        model_type: The name of the pretrained model.

    Returns:
        A dict mapping the axis names to the voxel size in nm. 3D models have the keys
        "x", "y" and "z"; 2D models only "x" and "y".

    Raises:
        KeyError: If no training resolution is registered for `model_type`.
    """
    def isotropic(size):
        return {"x": size, "y": size, "z": size}

    def planar(size):
        return {"x": size, "y": size}

    training_resolutions = {
        "active_zone": isotropic(1.38),
        "compartments": isotropic(3.47),
        "mitochondria": isotropic(2.07),
        # TODO: this is just copied from the previous mito model, it may be necessary to update this.
        "mitochondria2": isotropic(1.45),
        "mitochondria3": isotropic(2.87),
        "cristae": isotropic(1.44),
        "ribbon": isotropic(1.188),
        "vesicles_2d": planar(1.35),
        "vesicles_3d": isotropic(1.35),
        "vesicles_cryo": {"x": 1.35, "y": 1.35, "z": 0.88},
        # TODO add the correct resolutions, these are the resolutions of the source models.
        "vesicles_2d_maus": planar(1.35),
        "vesicles_3d_endbulb": isotropic(1.35),
        "vesicles_3d_innerear": isotropic(1.35),
    }
    return training_resolutions[model_type]
141
142
def compute_scale_from_voxel_size(
    voxel_size: Dict[str, float],
    model_type: str
) -> List[float]:
    """Compute the appropriate scale factor for inference with a given pretrained model.

    The factor for each axis is the voxel size of the data divided by the voxel size of the
    model's training data. A z factor is only computed if both the data and the model are 3D.

    Args:
        voxel_size: The voxel size of the data for inference, mapping the axis names
            "x", "y" (and "z" for 3D data) to the voxel size in nm.
        model_type: The name of the pretrained model.

    Returns:
        The scale factor, as a list in zyx order (yx order for 2D), i.e. following the
        (z, y, x) axis order of numpy image arrays.
    """
    training_voxel_size = get_model_training_resolution(model_type)
    # The factors must follow the array axis order (z, y, x), as documented above;
    # building them in x, y, z order would swap the x and z factors for anisotropic data.
    axes = ["y", "x"]
    if len(voxel_size) == 3 and len(training_voxel_size) == 3:
        axes = ["z"] + axes
    return [voxel_size[axis] / training_voxel_size[axis] for axis in axes]
166
167
168#
169# Convenience functions for segmentation.
170#
171
172
def _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons, resolution, min_membrane_size):
    """Derive the ribbon, PD and membrane segmentations from the ribbon synapse predictions.

    The ribbon is segmented from its prediction together with the vesicle segmentation.
    The PD is segmented from its prediction relative to the ribbon. The membrane is segmented
    relative to the PD, or relative to the ribbon if the PD segmentation is empty.

    Args:
        predictions: The predictions, indexed by "ribbon", "PD" and "membrane".
        vesicles: The vesicle segmentation.
        n_slices_exclude: Passed as `n_slices_exclude` to all three post-processing functions.
        n_ribbons: Passed as `n_ribbons` to `segment_ribbon`.
        resolution: Passed as `resolution` to `segment_membrane_distance_based`.
        min_membrane_size: Passed as `min_size` to `segment_membrane_distance_based`.

    Returns:
        Dictionary with the segmentations under the keys "ribbon", "PD" and "membrane".
    """
    from synapse_net.inference.postprocessing import (
        segment_membrane_distance_based, segment_presynaptic_density, segment_ribbon,
    )

    ribbon_seg = segment_ribbon(
        predictions["ribbon"],
        vesicles,
        n_slices_exclude=n_slices_exclude,
        n_ribbons=n_ribbons,
        max_vesicle_distance=40,
    )
    pd_seg = segment_presynaptic_density(
        predictions["PD"],
        ribbon_seg,
        n_slices_exclude=n_slices_exclude,
        max_distance_to_ribbon=40,
    )

    # Anchor the membrane on the PD; fall back to the ribbon when no PD was segmented.
    anchor = pd_seg if pd_seg.sum() > 0 else ribbon_seg
    membrane_seg = segment_membrane_distance_based(
        predictions["membrane"],
        anchor,
        max_distance=500,
        n_slices_exclude=n_slices_exclude,
        resolution=resolution,
        min_size=min_membrane_size,
    )

    return {"ribbon": ribbon_seg, "PD": pd_seg, "membrane": membrane_seg}
193
194
def _segment_ribbon_AZ(image, model, tiling, scale, verbose, return_predictions=False, **kwargs):
    """Segment ribbon synapse structures and post-process them if vesicles are available.

    Args:
        image: The input image volume.
        model: The ribbon synapse segmentation model.
        tiling: The tiling settings for inference.
        scale: A scale factor for resizing the input before applying the model.
        verbose: Whether to print information about the segmentation.
        return_predictions: Whether to also return the raw predictions.
        kwargs: Additional keyword arguments. The following ones are consumed here:
            `extra_segmentation` (the vesicle segmentation; post-processing only runs if it is given),
            `threshold` (default 0.5), `n_slices_exclude` (default 20), `n_ribbons` (default 1),
            `resolution` (default None) and `min_membrane_size` (default 0).
            All remaining ones are passed to `segment_ribbon_synapse_structures`.

    Returns:
        The segmentations, or the tuple (segmentations, predictions) if `return_predictions` is True.
        Without a vesicle segmentation, the segmentations are the predictions themselves.
    """
    # Parse additional keyword arguments from the kwargs.
    # The vesicle segmentation is optional; without it the post-processing is skipped.
    vesicles = kwargs.pop("extra_segmentation", None)
    threshold = kwargs.pop("threshold", 0.5)
    n_slices_exclude = kwargs.pop("n_slices_exclude", 20)
    n_ribbons = kwargs.pop("n_ribbons", 1)
    resolution = kwargs.pop("resolution", None)
    min_membrane_size = kwargs.pop("min_membrane_size", 0)

    predictions = segment_ribbon_synapse_structures(
        image, model=model, tiling=tiling, scale=scale, verbose=verbose, threshold=threshold, **kwargs
    )

    if vesicles is None:
        # Without vesicles the post-processing cannot run, so just return the predictions.
        if verbose:
            print("Vesicle segmentation was not passed, WILL NOT run post-processing.")
        segmentations = predictions
    else:
        # If the vesicles were passed then run additional post-processing.
        if verbose:
            print("Vesicle segmentation was passed, WILL run post-processing.")
        segmentations = _ribbon_AZ_postprocessing(
            predictions, vesicles, n_slices_exclude, n_ribbons, resolution, min_membrane_size
        )

    if return_predictions:
        return segmentations, predictions
    return segmentations
225
226
227def run_segmentation(
228    image: np.ndarray,
229    model: torch.nn.Module,
230    model_type: str,
231    tiling: Optional[Dict[str, Dict[str, int]]] = None,
232    scale: Optional[List[float]] = None,
233    verbose: bool = False,
234    **kwargs,
235) -> np.ndarray | Dict[str, np.ndarray]:
236    """Run synaptic structure segmentation.
237
238    Args:
239        image: The input image or image volume.
240        model: The segmentation model.
241        model_type: The model type. This will determine which segmentation post-processing is used.
242        tiling: The tiling settings for inference.
243        scale: A scale factor for resizing the input before applying the model.
244            The output will be scaled back to the initial size.
245        verbose: Whether to print detailed information about the prediction and segmentation.
246        kwargs: Optional parameters for the segmentation function.
247
248    Returns:
249        The segmentation. For models that return multiple segmentations, this function returns a dictionary.
250    """
251    if model_type.startswith("vesicles"):
252        segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
253    elif model_type == "mitochondria" or model_type == "mitochondria2" or model_type == "mitochondria3":
254        segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
255    elif model_type == "active_zone":
256        segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
257    elif model_type == "compartments":
258        segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
259    elif model_type == "ribbon":
260        segmentation = _segment_ribbon_AZ(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
261    elif model_type == "cristae":
262        segmentation = segment_cristae(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
263    else:
264        raise ValueError(f"Unknown model type: {model_type}")
265    return segmentation
def get_available_models() -> List[str]:
    """Return the names of all pretrained models registered for download.

    Returns:
        The list of available model names, in registration order.
    """
    registry = _get_model_registry()
    return [model_name for model_name in registry.urls]

Get the names of all available pretrained models.

Returns:

The list of available model names.

def get_model_path(model_type: str) -> str:
    """Get the local path to a pretrained model.

    The checkpoint is fetched through the pooch model registry, which downloads it into
    the SynapseNet cache directory if it is not present there yet.

    Args:
        model_type: The model type. Use `get_available_models` to list all valid names.

    Returns:
        The local path to the model.
    """
    model_registry = _get_model_registry()
    return model_registry.fetch(model_type)

Get the local path to a pretrained model.

Arguments:
  • model_type: The model type. Use get_available_models to list all valid names.
Returns:

The local path to the model.

def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
    """Get the model for a specific segmentation type.

    Args:
        model_type: The name of the pretrained model, e.g. 'vesicles_3d', 'active_zone', 'compartments',
            'mitochondria', 'cristae', 'ribbon', 'vesicles_2d' or 'vesicles_cryo'.
            Use `get_available_models` to list all valid names.
        device: The device to move the model to. If None, the device is selected with `get_device`.

    Returns:
        The model, moved to `device`.
    """
    if device is None:
        device = get_device(device)
    model_path = get_model_path(model_type)
    # NOTE: weights_only=False unpickles the full module object, which can execute arbitrary code.
    # This is only acceptable because the file comes from the checksum-verified model registry;
    # never point this at untrusted checkpoints.
    model = torch.load(model_path, weights_only=False)
    model.to(device)
    return model

Get the model for a specific segmentation type.

Arguments:
  • model_type: The name of the pretrained model, e.g. 'vesicles_3d', 'active_zone', 'compartments', 'mitochondria', 'cristae', 'ribbon', 'vesicles_2d' or 'vesicles_cryo'. Use get_available_models to list all valid names.
  • device: The device to use.
Returns:

The model.

def get_model_training_resolution(model_type: str) -> Dict[str, float]:
    """Look up the average voxel size of the data a pretrained model was trained on.

    Args:
        model_type: The name of the pretrained model.

    Returns:
        A dict mapping the axis names to the voxel size in nm. 3D models have the keys
        "x", "y" and "z"; 2D models only "x" and "y".

    Raises:
        KeyError: If no training resolution is registered for `model_type`.
    """
    def isotropic(size):
        return {"x": size, "y": size, "z": size}

    def planar(size):
        return {"x": size, "y": size}

    training_resolutions = {
        "active_zone": isotropic(1.38),
        "compartments": isotropic(3.47),
        "mitochondria": isotropic(2.07),
        # TODO: this is just copied from the previous mito model, it may be necessary to update this.
        "mitochondria2": isotropic(1.45),
        "mitochondria3": isotropic(2.87),
        "cristae": isotropic(1.44),
        "ribbon": isotropic(1.188),
        "vesicles_2d": planar(1.35),
        "vesicles_3d": isotropic(1.35),
        "vesicles_cryo": {"x": 1.35, "y": 1.35, "z": 0.88},
        # TODO add the correct resolutions, these are the resolutions of the source models.
        "vesicles_2d_maus": planar(1.35),
        "vesicles_3d_endbulb": isotropic(1.35),
        "vesicles_3d_innerear": isotropic(1.35),
    }
    return training_resolutions[model_type]

Get the average resolution / voxel size of the training data for a given pretrained model.

Arguments:
  • model_type: The name of the pretrained model.
Returns:

Mapping of axis (x, y, z) to the voxel size (in nm) of that axis.

def compute_scale_from_voxel_size(
    voxel_size: Dict[str, float],
    model_type: str
) -> List[float]:
    """Compute the appropriate scale factor for inference with a given pretrained model.

    The factor for each axis is the voxel size of the data divided by the voxel size of the
    model's training data. A z factor is only computed if both the data and the model are 3D.

    Args:
        voxel_size: The voxel size of the data for inference, mapping the axis names
            "x", "y" (and "z" for 3D data) to the voxel size in nm.
        model_type: The name of the pretrained model.

    Returns:
        The scale factor, as a list in zyx order (yx order for 2D), i.e. following the
        (z, y, x) axis order of numpy image arrays.
    """
    training_voxel_size = get_model_training_resolution(model_type)
    # The factors must follow the array axis order (z, y, x), as documented above;
    # building them in x, y, z order would swap the x and z factors for anisotropic data.
    axes = ["y", "x"]
    if len(voxel_size) == 3 and len(training_voxel_size) == 3:
        axes = ["z"] + axes
    return [voxel_size[axis] / training_voxel_size[axis] for axis in axes]

Compute the appropriate scale factor for inference with a given pretrained model.

Arguments:
  • voxel_size: The voxel size of the data for inference.
  • model_type: The name of the pretrained model.
Returns:

The scale factor, as a list in zyx order.

def run_segmentation( image: numpy.ndarray, model: torch.nn.modules.module.Module, model_type: str, tiling: Optional[Dict[str, Dict[str, int]]] = None, scale: Optional[List[float]] = None, verbose: bool = False, **kwargs) -> Union[numpy.ndarray, Dict[str, numpy.ndarray]]:
228def run_segmentation(
229    image: np.ndarray,
230    model: torch.nn.Module,
231    model_type: str,
232    tiling: Optional[Dict[str, Dict[str, int]]] = None,
233    scale: Optional[List[float]] = None,
234    verbose: bool = False,
235    **kwargs,
236) -> np.ndarray | Dict[str, np.ndarray]:
237    """Run synaptic structure segmentation.
238
239    Args:
240        image: The input image or image volume.
241        model: The segmentation model.
242        model_type: The model type. This will determine which segmentation post-processing is used.
243        tiling: The tiling settings for inference.
244        scale: A scale factor for resizing the input before applying the model.
245            The output will be scaled back to the initial size.
246        verbose: Whether to print detailed information about the prediction and segmentation.
247        kwargs: Optional parameters for the segmentation function.
248
249    Returns:
250        The segmentation. For models that return multiple segmentations, this function returns a dictionary.
251    """
252    if model_type.startswith("vesicles"):
253        segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
254    elif model_type == "mitochondria" or model_type == "mitochondria2" or model_type == "mitochondria3":
255        segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
256    elif model_type == "active_zone":
257        segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
258    elif model_type == "compartments":
259        segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
260    elif model_type == "ribbon":
261        segmentation = _segment_ribbon_AZ(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
262    elif model_type == "cristae":
263        segmentation = segment_cristae(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
264    else:
265        raise ValueError(f"Unknown model type: {model_type}")
266    return segmentation

Run synaptic structure segmentation.

Arguments:
  • image: The input image or image volume.
  • model: The segmentation model.
  • model_type: The model type. This will determine which segmentation post-processing is used.
  • tiling: The tiling settings for inference.
  • scale: A scale factor for resizing the input before applying the model. The output will be scaled back to the initial size.
  • verbose: Whether to print detailed information about the prediction and segmentation.
  • kwargs: Optional parameters for the segmentation function.
Returns:

The segmentation. For models that return multiple segmentations, this function returns a dictionary.