synapse_net.inference.inference

  1import os
  2from typing import Dict, List, Optional, Union
  3
  4import torch
  5import numpy as np
  6import pooch
  7
  8from .active_zone import segment_active_zone
  9from .compartments import segment_compartments
 10from .mitochondria import segment_mitochondria
 11from .ribbon_synapse import segment_ribbon_synapse_structures
 12from .vesicles import segment_vesicles
 13from .cristae import segment_cristae
 14from .util import get_device
 15from ..file_utils import get_cache_dir
 16
 17
 18#
 19# Functions to access SynapseNet's pretrained models.
 20#
 21
 22
 23def _get_model_registry():
 24    registry = {
 25        "active_zone": "c23652a8fe06daa113546af6d3200c4c1dcc79917056c6ed7357b8c93548372a",
 26        "compartments": "527983720f9eb215c45c4f4493851fd6551810361eda7b79f185a0d304274ee1",
 27        "mitochondria": "24625018a5968b36f39fa9d73b121a32e8f66d0f2c0540d3df2e1e39b3d58186",
 28        "mitochondria2": "0ec4c48fb67ebcdf1c2a86710e1d5e40519758b867e49a6999d155e8eb15d459",
 29        "cristae": "f96c90484f4ea92ac0515a06e389cc117580f02c2aacdc44b5828820cf38c3c3",
 30        "cristae2": "0864945698862df043adc51c0034289a579b0622a61164e5ebd00a24ee25d075",
 31        "cristae3": "5cb8699487bd21204071cbfb784a7f5c2bb6ab5f347a9d02a913aa27ae70eca4",
 32        "ribbon": "7c947f0ddfabe51a41d9d05c0a6ca7d6b238f43df2af8fffed5552d09bb075a9",
 33        "vesicles_2d": "eb0b74f7000a0e6a25b626078e76a9452019f2d1ea6cf2033073656f4f055df1",
 34        "vesicles_3d": "b329ec1f57f305099c984fbb3d7f6ae4b0ff51ec2fa0fa586df52dad6b84cf29",
 35        "vesicles_cryo": "782f5a21c3cda82c4e4eaeccc754774d5aaed5929f8496eb018aad7daf91661b",
 36        # Additional models that are only available in the CLI, not in the plugin model selection.
 37        "vesicles_2d_maus": "01506895df6343fc33ffc9c9eb3f975bf42eb4eaaaf4848bac83b57f1b46e460",
 38        "vesicles_3d_endbulb": "8582c7e3e5f16ef2bf34d6f9e34644862ca3c76835c9e7d44475c9dd7891d228",
 39        "vesicles_3d_innerear": "924f0f7cfb648a3a6931c1d48d8b1fdc6c0c0d2cb3330fe2cae49d13e7c3b69d",
 40    }
 41    urls = {
 42        "active_zone": "https://owncloud.gwdg.de/index.php/s/wpea9FH9waG4zJd/download",
 43        "compartments": "https://owncloud.gwdg.de/index.php/s/DnFDeTmDDmZrDDX/download",
 44        "mitochondria": "https://owncloud.gwdg.de/index.php/s/1T542uvzfuruahD/download",
 45        "mitochondria2": "https://owncloud.gwdg.de/index.php/s/jivHzhpsqXN3PoH/download",
 46        "cristae": "https://owncloud.gwdg.de/index.php/s/Df7OUOyQ1Kc2eEO/download",
 47        "cristae2": "https://owncloud.gwdg.de/index.php/s/qe0R5pRgH2m0pQ5/download",
 48        "cristae3": "https://owncloud.gwdg.de/index.php/s/51X1et8d7pOfdhU/download",
 49        "ribbon": "https://owncloud.gwdg.de/index.php/s/S3b5l0liPP1XPYA/download",
 50        "vesicles_2d": "https://owncloud.gwdg.de/index.php/s/d72QIvdX6LsgXip/download",
 51        "vesicles_3d": "https://owncloud.gwdg.de/index.php/s/A425mkAOSqePDhx/download",
 52        "vesicles_cryo": "https://owncloud.gwdg.de/index.php/s/e2lVdxjCJuZkLJm/download",
 53        # Additional models that are only available in the CLI, not in the plugin model selection.
 54        "vesicles_2d_maus": "https://owncloud.gwdg.de/index.php/s/sZ8woLr0zs5zOpv/download",
 55        "vesicles_3d_endbulb": "https://owncloud.gwdg.de/index.php/s/16tmnWrEDpYIMzU/download",
 56        "vesicles_3d_innerear": "https://owncloud.gwdg.de/index.php/s/UFUCYivsCxrqISX/download",
 57    }
 58    cache_dir = get_cache_dir()
 59    models = pooch.create(
 60        path=os.path.join(cache_dir, "models"),
 61        base_url="",
 62        registry=registry,
 63        urls=urls,
 64    )
 65    return models
 66
 67
 68def get_available_models() -> List[str]:
 69    """Get the names of all available pretrained models.
 70
 71    Returns:
 72        The list of available model names.
 73    """
 74    model_registry = _get_model_registry()
 75    return list(model_registry.urls.keys())
 76
 77
 78def get_model_path(model_type: str) -> str:
 79    """Get the local path to a pretrained model.
 80
 81    Args:
 82        The model type.
 83
 84    Returns:
 85        The local path to the model.
 86    """
 87    model_registry = _get_model_registry()
 88    model_path = model_registry.fetch(model_type)
 89    return model_path
 90
 91
 92def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
 93    """Get the model for a specific segmentation type.
 94
 95    Args:
 96        model_type: The model for one of the following segmentation tasks:
 97            'vesicles_3d', 'active_zone', 'compartments', 'mitochondria', 'ribbon', 'vesicles_2d', 'vesicles_cryo'.
 98        device: The device to use.
 99
100    Returns:
101        The model.
102    """
103    if device is None:
104        device = get_device(device)
105    model_path = get_model_path(model_type)
106    model = torch.load(model_path, weights_only=False)
107    model.to(device)
108    return model
109
110
111#
112# Functions for training resolution / voxel size.
113#
114
115
def get_model_training_resolution(model_type: str) -> Dict[str, float]:
    """Look up the average voxel size of the data a pretrained model was trained on.

    Args:
        model_type: The name of the pretrained model.

    Returns:
        Mapping from axis name ("x", "y", "z") to the voxel size (in nm) along that axis.
        The entries for 'vesicles_2d' and 'vesicles_2d_maus' only contain "x" and "y".
    """
    def isotropic(size):
        # Same voxel size along all three axes.
        return {"x": size, "y": size, "z": size}

    training_voxel_sizes = {
        "active_zone": isotropic(1.38),
        "compartments": isotropic(3.47),
        "mitochondria": isotropic(2.07),
        "mitochondria2": isotropic(2.87),
        "cristae": isotropic(1.44),
        "cristae2": isotropic(1.44),
        "cristae3": isotropic(1.44),
        "ribbon": isotropic(1.188),
        "vesicles_2d": {"x": 1.35, "y": 1.35},
        "vesicles_3d": isotropic(1.35),
        "vesicles_cryo": {"x": 1.35, "y": 1.35, "z": 0.88},
        # TODO add the correct resolutions, these are the resolutions of the source models.
        "vesicles_2d_maus": {"x": 1.35, "y": 1.35},
        "vesicles_3d_endbulb": isotropic(1.35),
        "vesicles_3d_innerear": isotropic(1.35),
    }
    return training_voxel_sizes[model_type]
143
144
def compute_scale_from_voxel_size(
    voxel_size: Dict[str, float],
    model_type: str
) -> List[float]:
    """Compute the appropriate scale factor for inference with a given pretrained model.

    The scale factor of an axis is the voxel size of the inference data divided by the
    voxel size of the model's training data along that axis.

    Args:
        voxel_size: The voxel size of the data for inference, as a mapping from axis name
            ("x", "y" and optionally "z") to voxel size (in nm).
        model_type: The name of the pretrained model.

    Returns:
        The scale factor, as a list in zyx order. The z entry is only included if both
        `voxel_size` and the model's training voxel size define "z"; otherwise the list
        is in yx order.
    """
    training_voxel_size = get_model_training_resolution(model_type)
    # Build the axes in array (zyx) order, as promised above. Previously the list was
    # assembled as [x, y, z], which scaled the wrong axes for anisotropic data.
    axes = ["y", "x"]
    if "z" in voxel_size and "z" in training_voxel_size:
        axes.insert(0, "z")
    return [voxel_size[axis] / training_voxel_size[axis] for axis in axes]
168
169
170#
171# Convenience functions for segmentation.
172#
173
174
def _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons, resolution, min_membrane_size):
    """Derive ribbon, presynaptic density (PD) and membrane segmentations from the predictions.

    The ribbon is segmented using the vesicle segmentation. The PD is segmented relative to the
    ribbon, and the membrane relative to the PD (or to the ribbon if the PD segmentation is empty).

    Args:
        predictions: The predictions, with the entries "ribbon", "PD" and "membrane".
        vesicles: The vesicle segmentation.
        n_slices_exclude: Passed on to all three post-processing functions.
        n_ribbons: Passed on to `segment_ribbon`.
        resolution: Passed on to `segment_membrane_distance_based`.
        min_membrane_size: Passed on as `min_size` to `segment_membrane_distance_based`.

    Returns:
        Dictionary with the segmentations for "ribbon", "PD" and "membrane".
    """
    from synapse_net.inference.postprocessing import (
        segment_membrane_distance_based,
        segment_presynaptic_density,
        segment_ribbon,
    )

    ribbon = segment_ribbon(
        predictions["ribbon"], vesicles,
        n_slices_exclude=n_slices_exclude, n_ribbons=n_ribbons, max_vesicle_distance=40,
    )
    presynaptic_density = segment_presynaptic_density(
        predictions["PD"], ribbon,
        n_slices_exclude=n_slices_exclude, max_distance_to_ribbon=40,
    )

    # Anchor the membrane on the PD when it was found, otherwise fall back to the ribbon.
    if presynaptic_density.sum() > 0:
        reference = presynaptic_density
    else:
        reference = ribbon

    membrane = segment_membrane_distance_based(
        predictions["membrane"], reference,
        max_distance=500, n_slices_exclude=n_slices_exclude,
        resolution=resolution, min_size=min_membrane_size,
    )
    return {"ribbon": ribbon, "PD": presynaptic_density, "membrane": membrane}
195
196
def _segment_ribbon_AZ(image, model, tiling, scale, verbose, return_predictions=False, **kwargs):
    """Segment ribbon synapse structures, optionally post-processed with a vesicle segmentation.

    Args:
        image: The input image or image volume.
        model: The segmentation model.
        tiling: The tiling settings for inference.
        scale: A scale factor for resizing the input before applying the model.
        verbose: Whether to print information about the segmentation.
        return_predictions: Whether to also return the predictions.
        kwargs: Additional keyword arguments. The following ones are consumed here:
            extra_segmentation: The vesicle segmentation. If it is not given (or None),
                no post-processing is run and the predictions are returned.
            threshold: Passed on to `segment_ribbon_synapse_structures` (default 0.5).
            n_slices_exclude: Passed on to the post-processing (default 20).
            n_ribbons: Passed on to the post-processing (default 1).
            resolution: Passed on to the post-processing (default None).
            min_membrane_size: Passed on to the post-processing (default 0).
            All remaining kwargs are passed on to `segment_ribbon_synapse_structures`.

    Returns:
        The segmentations (or the predictions, if no vesicle segmentation was given).
        If `return_predictions` is True, a tuple of (segmentations, predictions).
    """
    # Pop the arguments handled here, so that only the remaining ones reach the prediction function.
    vesicles = kwargs.pop("extra_segmentation", None)
    threshold = kwargs.pop("threshold", 0.5)
    n_slices_exclude = kwargs.pop("n_slices_exclude", 20)
    n_ribbons = kwargs.pop("n_ribbons", 1)
    resolution = kwargs.pop("resolution", None)
    min_membrane_size = kwargs.pop("min_membrane_size", 0)

    predictions = segment_ribbon_synapse_structures(
        image, model=model, tiling=tiling, scale=scale, verbose=verbose, threshold=threshold, **kwargs
    )

    if vesicles is None:
        # The post-processing needs the vesicles; without them just return the predictions.
        if verbose:
            print("Vesicle segmentation was not passed, WILL NOT run post-processing.")
        segmentations = predictions
    else:
        # If the vesicles were passed then run additional post-processing.
        if verbose:
            print("Vesicle segmentation was passed, WILL run post-processing.")
        segmentations = _ribbon_AZ_postprocessing(
            predictions, vesicles, n_slices_exclude, n_ribbons, resolution, min_membrane_size
        )

    if return_predictions:
        return segmentations, predictions
    return segmentations
227
228
def run_segmentation(
    image: np.ndarray,
    model: torch.nn.Module,
    model_type: str,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    scale: Optional[List[float]] = None,
    verbose: bool = False,
    **kwargs,
) -> Union[np.ndarray, Dict[str, np.ndarray]]:
    """Run synaptic structure segmentation.

    The segmentation function is chosen based on the model type: all types starting with
    'vesicles' use the vesicle segmentation; 'mitochondria' / 'mitochondria2',
    'cristae' / 'cristae2' / 'cristae3', 'active_zone', 'compartments' and 'ribbon'
    use their respective segmentation functions.

    Args:
        image: The input image or image volume.
        model: The segmentation model.
        model_type: The model type. This will determine which segmentation post-processing is used.
        tiling: The tiling settings for inference.
        scale: A scale factor for resizing the input before applying the model.
            The output will be scaled back to the initial size.
        verbose: Whether to print detailed information about the prediction and segmentation.
        kwargs: Optional parameters for the segmentation function. For 'ribbon', the vesicle
            segmentation can be passed as `extra_segmentation` to enable post-processing.

    Returns:
        The segmentation. For models that return multiple segmentations, this function returns a dictionary.

    Raises:
        ValueError: If the model type is not supported.
    """
    if model_type.startswith("vesicles"):
        segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type in ("mitochondria", "mitochondria2"):
        segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "active_zone":
        segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "compartments":
        segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "ribbon":
        segmentation = _segment_ribbon_AZ(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type in ("cristae", "cristae2", "cristae3"):
        segmentation = segment_cristae(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    return segmentation
def get_available_models() -> List[str]:
def get_available_models() -> List[str]:
    """List the names of all pretrained models that can be downloaded.

    Returns:
        The model names, in the order in which they are registered.
    """
    download_urls = _get_model_registry().urls
    return [name for name in download_urls]

Get the names of all available pretrained models.

Returns:

The list of available model names.

def get_model_path(model_type: str) -> str:
def get_model_path(model_type: str) -> str:
    """Get the local path to a pretrained model, downloading it if necessary.

    The model file is fetched through the pooch registry of pretrained models,
    which downloads the file into the local cache if it is not present yet.

    Args:
        model_type: The name of the pretrained model.
            See `get_available_models` for the supported names.

    Returns:
        The local path to the model file.
    """
    model_registry = _get_model_registry()
    model_path = model_registry.fetch(model_type)
    return model_path

Get the local path to a pretrained model.

Arguments:
  • model_type: The name of the pretrained model.
Returns:

The local path to the model.

def get_model( model_type: str, device: Union[torch.device, str, NoneType] = None) -> torch.nn.modules.module.Module:
 93def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
 94    """Get the model for a specific segmentation type.
 95
 96    Args:
 97        model_type: The model for one of the following segmentation tasks:
 98            'vesicles_3d', 'active_zone', 'compartments', 'mitochondria', 'ribbon', 'vesicles_2d', 'vesicles_cryo'.
 99        device: The device to use.
100
101    Returns:
102        The model.
103    """
104    if device is None:
105        device = get_device(device)
106    model_path = get_model_path(model_type)
107    model = torch.load(model_path, weights_only=False)
108    model.to(device)
109    return model

Get the model for a specific segmentation type.

Arguments:
  • model_type: The name of the pretrained model, for example 'vesicles_3d', 'active_zone', 'compartments', 'mitochondria', 'ribbon', 'vesicles_2d' or 'vesicles_cryo'. Use get_available_models() to list all available models.
  • device: The device to use.
Returns:

The model.

def get_model_training_resolution(model_type: str) -> Dict[str, float]:
def get_model_training_resolution(model_type: str) -> Dict[str, float]:
    """Look up the average voxel size of the data a pretrained model was trained on.

    Args:
        model_type: The name of the pretrained model.

    Returns:
        Mapping from axis name ("x", "y", "z") to the voxel size (in nm) along that axis.
        The entries for 'vesicles_2d' and 'vesicles_2d_maus' only contain "x" and "y".
    """
    def isotropic(size):
        # Same voxel size along all three axes.
        return {"x": size, "y": size, "z": size}

    training_voxel_sizes = {
        "active_zone": isotropic(1.38),
        "compartments": isotropic(3.47),
        "mitochondria": isotropic(2.07),
        "mitochondria2": isotropic(2.87),
        "cristae": isotropic(1.44),
        "cristae2": isotropic(1.44),
        "cristae3": isotropic(1.44),
        "ribbon": isotropic(1.188),
        "vesicles_2d": {"x": 1.35, "y": 1.35},
        "vesicles_3d": isotropic(1.35),
        "vesicles_cryo": {"x": 1.35, "y": 1.35, "z": 0.88},
        # TODO add the correct resolutions, these are the resolutions of the source models.
        "vesicles_2d_maus": {"x": 1.35, "y": 1.35},
        "vesicles_3d_endbulb": isotropic(1.35),
        "vesicles_3d_innerear": isotropic(1.35),
    }
    return training_voxel_sizes[model_type]

Get the average resolution / voxel size of the training data for a given pretrained model.

Arguments:
  • model_type: The name of the pretrained model.
Returns:

Mapping of axis (x, y, z) to the voxel size (in nm) of that axis.

def compute_scale_from_voxel_size(voxel_size: Dict[str, float], model_type: str) -> List[float]:
def compute_scale_from_voxel_size(
    voxel_size: Dict[str, float],
    model_type: str
) -> List[float]:
    """Compute the appropriate scale factor for inference with a given pretrained model.

    The scale factor of an axis is the voxel size of the inference data divided by the
    voxel size of the model's training data along that axis.

    Args:
        voxel_size: The voxel size of the data for inference, as a mapping from axis name
            ("x", "y" and optionally "z") to voxel size (in nm).
        model_type: The name of the pretrained model.

    Returns:
        The scale factor, as a list in zyx order. The z entry is only included if both
        `voxel_size` and the model's training voxel size define "z"; otherwise the list
        is in yx order.
    """
    training_voxel_size = get_model_training_resolution(model_type)
    # Build the axes in array (zyx) order, as promised above. Previously the list was
    # assembled as [x, y, z], which scaled the wrong axes for anisotropic data.
    axes = ["y", "x"]
    if "z" in voxel_size and "z" in training_voxel_size:
        axes.insert(0, "z")
    return [voxel_size[axis] / training_voxel_size[axis] for axis in axes]

Compute the appropriate scale factor for inference with a given pretrained model.

Arguments:
  • voxel_size: The voxel size of the data for inference.
  • model_type: The name of the pretrained model.
Returns:

The scale factor, as a list in zyx order.

def run_segmentation( image: numpy.ndarray, model: torch.nn.modules.module.Module, model_type: str, tiling: Optional[Dict[str, Dict[str, int]]] = None, scale: Optional[List[float]] = None, verbose: bool = False, **kwargs) -> Union[numpy.ndarray, Dict[str, numpy.ndarray]]:
def run_segmentation(
    image: np.ndarray,
    model: torch.nn.Module,
    model_type: str,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    scale: Optional[List[float]] = None,
    verbose: bool = False,
    **kwargs,
) -> Union[np.ndarray, Dict[str, np.ndarray]]:
    """Run synaptic structure segmentation.

    The segmentation function is chosen based on the model type: all types starting with
    'vesicles' use the vesicle segmentation; 'mitochondria' / 'mitochondria2',
    'cristae' / 'cristae2' / 'cristae3', 'active_zone', 'compartments' and 'ribbon'
    use their respective segmentation functions.

    Args:
        image: The input image or image volume.
        model: The segmentation model.
        model_type: The model type. This will determine which segmentation post-processing is used.
        tiling: The tiling settings for inference.
        scale: A scale factor for resizing the input before applying the model.
            The output will be scaled back to the initial size.
        verbose: Whether to print detailed information about the prediction and segmentation.
        kwargs: Optional parameters for the segmentation function. For 'ribbon', the vesicle
            segmentation can be passed as `extra_segmentation` to enable post-processing.

    Returns:
        The segmentation. For models that return multiple segmentations, this function returns a dictionary.

    Raises:
        ValueError: If the model type is not supported.
    """
    if model_type.startswith("vesicles"):
        segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type in ("mitochondria", "mitochondria2"):
        segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "active_zone":
        segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "compartments":
        segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type == "ribbon":
        segmentation = _segment_ribbon_AZ(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    elif model_type in ("cristae", "cristae2", "cristae3"):
        segmentation = segment_cristae(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    return segmentation

Run synaptic structure segmentation.

Arguments:
  • image: The input image or image volume.
  • model: The segmentation model.
  • model_type: The model type. This will determine which segmentation post-processing is used.
  • tiling: The tiling settings for inference.
  • scale: A scale factor for resizing the input before applying the model. The output will be scaled back to the initial size.
  • verbose: Whether to print detailed information about the prediction and segmentation.
  • kwargs: Optional parameters for the segmentation function.
Returns:

The segmentation. For models that return multiple segmentations, this function returns a dictionary.