synapse_net.inference.inference

  1import os
  2from typing import Dict, List, Optional, Union
  3
  4import torch
  5import numpy as np
  6import pooch
  7
  8from .active_zone import segment_active_zone
  9from .compartments import segment_compartments
 10from .mitochondria import segment_mitochondria
 11from .ribbon_synapse import segment_ribbon_synapse_structures
 12from .vesicles import segment_vesicles
 13from .cristae import segment_cristae
 14from .util import get_device
 15from ..file_utils import get_cache_dir
 16
 17
 18#
 19# Functions to access SynapseNet's pretrained models.
 20#
 21
 22
 23def _get_model_registry():
 24    registry = {
 25        "active_zone": "c23652a8fe06daa113546af6d3200c4c1dcc79917056c6ed7357b8c93548372a",
 26        "compartments": "527983720f9eb215c45c4f4493851fd6551810361eda7b79f185a0d304274ee1",
 27        "mitochondria": "24625018a5968b36f39fa9d73b121a32e8f66d0f2c0540d3df2e1e39b3d58186",
 28        "mitochondria2": "553decafaff4838fff6cc8347f22c8db3dee5bcbeffc34ffaec152f8449af673",
 29        "mitochondria3": "aa690c49017f84fd333436ddcb995ffeaf9b8174e9505586a4caf735a5805e7c",
 30        "cristae": "f96c90484f4ea92ac0515a06e389cc117580f02c2aacdc44b5828820cf38c3c3",
 31        "ribbon": "7c947f0ddfabe51a41d9d05c0a6ca7d6b238f43df2af8fffed5552d09bb075a9",
 32        "vesicles_2d": "eb0b74f7000a0e6a25b626078e76a9452019f2d1ea6cf2033073656f4f055df1",
 33        "vesicles_3d": "b329ec1f57f305099c984fbb3d7f6ae4b0ff51ec2fa0fa586df52dad6b84cf29",
 34        "vesicles_cryo": "782f5a21c3cda82c4e4eaeccc754774d5aaed5929f8496eb018aad7daf91661b",
 35        # Additional models that are only available in the CLI, not in the plugin model selection.
 36        "vesicles_2d_maus": "01506895df6343fc33ffc9c9eb3f975bf42eb4eaaaf4848bac83b57f1b46e460",
 37        "vesicles_3d_endbulb": "8582c7e3e5f16ef2bf34d6f9e34644862ca3c76835c9e7d44475c9dd7891d228",
 38        "vesicles_3d_innerear": "924f0f7cfb648a3a6931c1d48d8b1fdc6c0c0d2cb3330fe2cae49d13e7c3b69d",
 39    }
 40    urls = {
 41        "active_zone": "https://owncloud.gwdg.de/index.php/s/wpea9FH9waG4zJd/download",
 42        "compartments": "https://owncloud.gwdg.de/index.php/s/DnFDeTmDDmZrDDX/download",
 43        "mitochondria": "https://owncloud.gwdg.de/index.php/s/1T542uvzfuruahD/download",
 44        "mitochondria2": "https://owncloud.gwdg.de/index.php/s/GZghrXagc54FFXd/download",
 45        "mitochondria3": "https://owncloud.gwdg.de/index.php/s/FwLQRCV2HnXDxsT/download",
 46        "cristae": "https://owncloud.gwdg.de/index.php/s/Df7OUOyQ1Kc2eEO/download",
 47        "ribbon": "https://owncloud.gwdg.de/index.php/s/S3b5l0liPP1XPYA/download",
 48        "vesicles_2d": "https://owncloud.gwdg.de/index.php/s/d72QIvdX6LsgXip/download",
 49        "vesicles_3d": "https://owncloud.gwdg.de/index.php/s/A425mkAOSqePDhx/download",
 50        "vesicles_cryo": "https://owncloud.gwdg.de/index.php/s/e2lVdxjCJuZkLJm/download",
 51        # Additional models that are only available in the CLI, not in the plugin model selection.
 52        "vesicles_2d_maus": "https://owncloud.gwdg.de/index.php/s/sZ8woLr0zs5zOpv/download",
 53        "vesicles_3d_endbulb": "https://owncloud.gwdg.de/index.php/s/16tmnWrEDpYIMzU/download",
 54        "vesicles_3d_innerear": "https://owncloud.gwdg.de/index.php/s/UFUCYivsCxrqISX/download",
 55    }
 56    cache_dir = get_cache_dir()
 57    models = pooch.create(
 58        path=os.path.join(cache_dir, "models"),
 59        base_url="",
 60        registry=registry,
 61        urls=urls,
 62    )
 63    return models
 64
 65
 66def get_model_path(model_type: str) -> str:
 67    """Get the local path to a pretrained model.
 68
 69    Args:
 70        The model type.
 71
 72    Returns:
 73        The local path to the model.
 74    """
 75    model_registry = _get_model_registry()
 76    model_path = model_registry.fetch(model_type)
 77    return model_path
 78
 79
 80def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
 81    """Get the model for a specific segmentation type.
 82
 83    Args:
 84        model_type: The model for one of the following segmentation tasks:
 85            'vesicles_3d', 'active_zone', 'compartments', 'mitochondria', 'ribbon', 'vesicles_2d', 'vesicles_cryo'.
 86        device: The device to use.
 87
 88    Returns:
 89        The model.
 90    """
 91    if device is None:
 92        device = get_device(device)
 93    model_path = get_model_path(model_type)
 94    model = torch.load(model_path, weights_only=False)
 95    model.to(device)
 96    return model
 97
 98
 99#
100# Functions for training resolution / voxel size.
101#
102
103
def get_model_training_resolution(model_type: str) -> Dict[str, float]:
    """Get the average resolution / voxel size of the training data for a given pretrained model.

    Args:
        model_type: The name of the pretrained model.

    Returns:
        Mapping of axis to the voxel size (in nm) of that axis. 3D models have the
        axes x, y and z; 2D models only have x and y.

    Raises:
        KeyError: If `model_type` is not a known pretrained model.
    """
    def isotropic(size):
        # Same voxel size along x, y and z.
        return {"x": size, "y": size, "z": size}

    def planar(size):
        # 2D models only have an in-plane voxel size.
        return {"x": size, "y": size}

    resolutions = {
        "active_zone": isotropic(1.38),
        "compartments": isotropic(3.47),
        "mitochondria": isotropic(2.07),
        # TODO: this is just copied from the previous mito model, it may be necessary to update this.
        "mitochondria2": isotropic(1.45),
        "mitochondria3": isotropic(2.87),
        "cristae": isotropic(1.44),
        "ribbon": isotropic(1.188),
        "vesicles_2d": planar(1.35),
        "vesicles_3d": isotropic(1.35),
        "vesicles_cryo": {"x": 1.35, "y": 1.35, "z": 0.88},
        # TODO add the correct resolutions, these are the resolutions of the source models.
        "vesicles_2d_maus": planar(1.35),
        "vesicles_3d_endbulb": isotropic(1.35),
        "vesicles_3d_innerear": isotropic(1.35),
    }
    return resolutions[model_type]
131
132
def compute_scale_from_voxel_size(
    voxel_size: Dict[str, float],
    model_type: str
) -> List[float]:
    """Compute the appropriate scale factor for inference with a given pretrained model.

    The factor for each axis is the voxel size of the data divided by the (average)
    voxel size of the model's training data along that axis.

    Args:
        voxel_size: The voxel size of the data for inference, as a mapping of axis
            ("x", "y" and optionally "z") to voxel size, in the same unit (nm) as
            `get_model_training_resolution`.
        model_type: The name of the pretrained model.

    Returns:
        The scale factor, as a list in zyx order: [z, y, x] if both the data and the
        training voxel size have a "z" entry, otherwise [y, x].
    """
    training_voxel_size = get_model_training_resolution(model_type)
    # The contract is zyx order, so the per-axis factors are built in that order
    # (building them as x, y, z would swap the z and x factors for 3D data).
    axes = ["y", "x"]
    if "z" in voxel_size and "z" in training_voxel_size:
        axes.insert(0, "z")
    return [voxel_size[axis] / training_voxel_size[axis] for axis in axes]
156
157
158#
159# Convenience functions for segmentation.
160#
161
162
def _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons, resolution, min_membrane_size):
    """Derive ribbon, PD and membrane segmentations from the ribbon synapse predictions.

    Args:
        predictions: The predictions, indexed by "ribbon", "PD" and "membrane".
        vesicles: The vesicle segmentation, used for segmenting the ribbon.
        n_slices_exclude: Forwarded to all three post-processing functions.
        n_ribbons: Forwarded to `segment_ribbon`.
        resolution: Forwarded to `segment_membrane_distance_based`.
        min_membrane_size: Forwarded as `min_size` to `segment_membrane_distance_based`.

    Returns:
        Dictionary with the segmentations under the keys "ribbon", "PD" and "membrane".
    """
    from synapse_net.inference.postprocessing import (
        segment_membrane_distance_based, segment_presynaptic_density, segment_ribbon,
    )

    ribbon = segment_ribbon(
        predictions["ribbon"], vesicles,
        n_slices_exclude=n_slices_exclude, n_ribbons=n_ribbons, max_vesicle_distance=40,
    )
    presynaptic_density = segment_presynaptic_density(
        predictions["PD"], ribbon,
        n_slices_exclude=n_slices_exclude, max_distance_to_ribbon=40,
    )
    # The membrane is segmented relative to the PD; fall back to the ribbon
    # when the PD segmentation sums to zero (i.e. is empty).
    reference = presynaptic_density if presynaptic_density.sum() > 0 else ribbon
    membrane = segment_membrane_distance_based(
        predictions["membrane"], reference,
        max_distance=500, n_slices_exclude=n_slices_exclude,
        resolution=resolution, min_size=min_membrane_size,
    )
    return {"ribbon": ribbon, "PD": presynaptic_density, "membrane": membrane}
183
184
def _segment_ribbon_AZ(image, model, tiling, scale, verbose, return_predictions=False, **kwargs):
    """Segment ribbon synapse structures, optionally refined by post-processing.

    Args:
        image: The input image or image volume.
        model: The segmentation model.
        tiling: The tiling settings for inference.
        scale: A scale factor for resizing the input before applying the model.
        verbose: Whether to print detailed information.
        return_predictions: Whether to also return the raw predictions.
        kwargs: Additional keyword arguments. These are consumed here:
            extra_segmentation: The vesicle segmentation (default None). Post-processing
                only runs if it is given.
            threshold: Passed to `segment_ribbon_synapse_structures` (default 0.5).
            n_slices_exclude: Passed to the post-processing (default 20).
            n_ribbons: Passed to the post-processing (default 1).
            resolution: Passed to the post-processing (default None).
            min_membrane_size: Passed to the post-processing (default 0).
            All remaining keyword arguments are forwarded to `segment_ribbon_synapse_structures`.

    Returns:
        The predictions if no vesicle segmentation was passed, otherwise a dictionary with the
        "ribbon", "PD" and "membrane" segmentations. If `return_predictions` is True, a tuple of
        these segmentations and the raw predictions is returned.
    """
    # Pop the arguments that are handled here, so they are not forwarded to the prediction.
    # extra_segmentation is optional: without it, the post-processing is skipped below.
    vesicles = kwargs.pop("extra_segmentation", None)
    threshold = kwargs.pop("threshold", 0.5)
    n_slices_exclude = kwargs.pop("n_slices_exclude", 20)
    n_ribbons = kwargs.pop("n_ribbons", 1)
    resolution = kwargs.pop("resolution", None)
    min_membrane_size = kwargs.pop("min_membrane_size", 0)

    predictions = segment_ribbon_synapse_structures(
        image, model=model, tiling=tiling, scale=scale, verbose=verbose, threshold=threshold, **kwargs
    )

    if vesicles is None:
        # The post-processing needs the vesicles, so just return the predictions.
        if verbose:
            print("Vesicle segmentation was not passed, WILL NOT run post-processing.")
        segmentations = predictions
    else:
        # The vesicles were passed, so run the additional post-processing.
        if verbose:
            print("Vesicle segmentation was passed, WILL run post-processing.")
        segmentations = _ribbon_AZ_postprocessing(
            predictions, vesicles, n_slices_exclude, n_ribbons, resolution, min_membrane_size
        )

    if return_predictions:
        return segmentations, predictions
    return segmentations
215
216
217def run_segmentation(
218    image: np.ndarray,
219    model: torch.nn.Module,
220    model_type: str,
221    tiling: Optional[Dict[str, Dict[str, int]]] = None,
222    scale: Optional[List[float]] = None,
223    verbose: bool = False,
224    **kwargs,
225) -> np.ndarray | Dict[str, np.ndarray]:
226    """Run synaptic structure segmentation.
227
228    Args:
229        image: The input image or image volume.
230        model: The segmentation model.
231        model_type: The model type. This will determine which segmentation post-processing is used.
232        tiling: The tiling settings for inference.
233        scale: A scale factor for resizing the input before applying the model.
234            The output will be scaled back to the initial size.
235        verbose: Whether to print detailed information about the prediction and segmentation.
236        kwargs: Optional parameters for the segmentation function.
237
238    Returns:
239        The segmentation. For models that return multiple segmentations, this function returns a dictionary.
240    """
241    if model_type.startswith("vesicles"):
242        segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
243    elif model_type == "mitochondria" or model_type == "mitochondria2" or model_type == "mitochondria3":
244        segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
245    elif model_type == "active_zone":
246        segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
247    elif model_type == "compartments":
248        segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
249    elif model_type == "ribbon":
250        segmentation = _segment_ribbon_AZ(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
251    elif model_type == "cristae":
252        segmentation = segment_cristae(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
253    else:
254        raise ValueError(f"Unknown model type: {model_type}")
255    return segmentation
def get_model_path(model_type: str) -> str:
def get_model_path(model_type: str) -> str:
    """Get the local path to a pretrained model.

    The path is resolved through the pooch model registry, which downloads the
    model file into the cache if it is not available locally yet.

    Args:
        model_type: The model type, i.e. the name of a model in the registry.

    Returns:
        The local path to the model.
    """
    return _get_model_registry().fetch(model_type)

Get the local path to a pretrained model.

Arguments:
  • model_type: The model type.
Returns:

The local path to the model.

def get_model( model_type: str, device: Union[torch.device, str, NoneType] = None) -> torch.nn.modules.module.Module:
def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
    """Get the model for a specific segmentation type.

    Args:
        model_type: The name of the pretrained model, one of
            'vesicles_3d', 'vesicles_2d', 'vesicles_cryo', 'active_zone', 'compartments',
            'mitochondria', 'mitochondria2', 'mitochondria3', 'cristae', 'ribbon',
            'vesicles_2d_maus', 'vesicles_3d_endbulb', 'vesicles_3d_innerear'.
        device: The device to use. If None, it is determined via `get_device`.

    Returns:
        The model, moved to the selected device.
    """
    target_device = get_device(device) if device is None else device
    # weights_only=False: the file is unpickled as a full model object (arbitrary pickle
    # loading), so only files obtained through the model registry should be passed here.
    network = torch.load(get_model_path(model_type), weights_only=False)
    network.to(target_device)
    return network

Get the model for a specific segmentation type.

Arguments:
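  • model_type: The model for one of the following segmentation tasks: 'vesicles_3d', 'vesicles_2d', 'vesicles_cryo', 'active_zone', 'compartments', 'mitochondria', 'mitochondria2', 'mitochondria3', 'cristae', 'ribbon', 'vesicles_2d_maus', 'vesicles_3d_endbulb', 'vesicles_3d_innerear'.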
  • device: The device to use.
Returns:

The model.

def get_model_training_resolution(model_type: str) -> Dict[str, float]:
def get_model_training_resolution(model_type: str) -> Dict[str, float]:
    """Get the average resolution / voxel size of the training data for a given pretrained model.

    Args:
        model_type: The name of the pretrained model.

    Returns:
        Mapping of axis to the voxel size (in nm) of that axis. 3D models have the
        axes x, y and z; 2D models only have x and y.

    Raises:
        KeyError: If `model_type` is not a known pretrained model.
    """
    def isotropic(size):
        # Same voxel size along x, y and z.
        return {"x": size, "y": size, "z": size}

    def planar(size):
        # 2D models only have an in-plane voxel size.
        return {"x": size, "y": size}

    resolutions = {
        "active_zone": isotropic(1.38),
        "compartments": isotropic(3.47),
        "mitochondria": isotropic(2.07),
        # TODO: this is just copied from the previous mito model, it may be necessary to update this.
        "mitochondria2": isotropic(1.45),
        "mitochondria3": isotropic(2.87),
        "cristae": isotropic(1.44),
        "ribbon": isotropic(1.188),
        "vesicles_2d": planar(1.35),
        "vesicles_3d": isotropic(1.35),
        "vesicles_cryo": {"x": 1.35, "y": 1.35, "z": 0.88},
        # TODO add the correct resolutions, these are the resolutions of the source models.
        "vesicles_2d_maus": planar(1.35),
        "vesicles_3d_endbulb": isotropic(1.35),
        "vesicles_3d_innerear": isotropic(1.35),
    }
    return resolutions[model_type]

Get the average resolution / voxel size of the training data for a given pretrained model.

Arguments:
  • model_type: The name of the pretrained model.
Returns:

Mapping of axis (x, y, z) to the voxel size (in nm) of that axis.

def compute_scale_from_voxel_size(voxel_size: Dict[str, float], model_type: str) -> List[float]:
def compute_scale_from_voxel_size(
    voxel_size: Dict[str, float],
    model_type: str
) -> List[float]:
    """Compute the appropriate scale factor for inference with a given pretrained model.

    The factor for each axis is the voxel size of the data divided by the (average)
    voxel size of the model's training data along that axis.

    Args:
        voxel_size: The voxel size of the data for inference, as a mapping of axis
            ("x", "y" and optionally "z") to voxel size, in the same unit (nm) as
            `get_model_training_resolution`.
        model_type: The name of the pretrained model.

    Returns:
        The scale factor, as a list in zyx order: [z, y, x] if both the data and the
        training voxel size have a "z" entry, otherwise [y, x].
    """
    training_voxel_size = get_model_training_resolution(model_type)
    # The contract is zyx order, so the per-axis factors are built in that order
    # (building them as x, y, z would swap the z and x factors for 3D data).
    axes = ["y", "x"]
    if "z" in voxel_size and "z" in training_voxel_size:
        axes.insert(0, "z")
    return [voxel_size[axis] / training_voxel_size[axis] for axis in axes]

Compute the appropriate scale factor for inference with a given pretrained model.

Arguments:
  • voxel_size: The voxel size of the data for inference.
  • model_type: The name of the pretrained model.
Returns:

The scale factor, as a list in zyx order.

def run_segmentation( image: numpy.ndarray, model: torch.nn.modules.module.Module, model_type: str, tiling: Optional[Dict[str, Dict[str, int]]] = None, scale: Optional[List[float]] = None, verbose: bool = False, **kwargs) -> Union[numpy.ndarray, Dict[str, numpy.ndarray]]:
218def run_segmentation(
219    image: np.ndarray,
220    model: torch.nn.Module,
221    model_type: str,
222    tiling: Optional[Dict[str, Dict[str, int]]] = None,
223    scale: Optional[List[float]] = None,
224    verbose: bool = False,
225    **kwargs,
226) -> np.ndarray | Dict[str, np.ndarray]:
227    """Run synaptic structure segmentation.
228
229    Args:
230        image: The input image or image volume.
231        model: The segmentation model.
232        model_type: The model type. This will determine which segmentation post-processing is used.
233        tiling: The tiling settings for inference.
234        scale: A scale factor for resizing the input before applying the model.
235            The output will be scaled back to the initial size.
236        verbose: Whether to print detailed information about the prediction and segmentation.
237        kwargs: Optional parameters for the segmentation function.
238
239    Returns:
240        The segmentation. For models that return multiple segmentations, this function returns a dictionary.
241    """
242    if model_type.startswith("vesicles"):
243        segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
244    elif model_type == "mitochondria" or model_type == "mitochondria2" or model_type == "mitochondria3":
245        segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
246    elif model_type == "active_zone":
247        segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
248    elif model_type == "compartments":
249        segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
250    elif model_type == "ribbon":
251        segmentation = _segment_ribbon_AZ(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
252    elif model_type == "cristae":
253        segmentation = segment_cristae(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
254    else:
255        raise ValueError(f"Unknown model type: {model_type}")
256    return segmentation

Run synaptic structure segmentation.

Arguments:
  • image: The input image or image volume.
  • model: The segmentation model.
  • model_type: The model type. This will determine which segmentation post-processing is used.
  • tiling: The tiling settings for inference.
  • scale: A scale factor for resizing the input before applying the model. The output will be scaled back to the initial size.
  • verbose: Whether to print detailed information about the prediction and segmentation.
  • kwargs: Optional parameters for the segmentation function.
Returns:

The segmentation. For models that return multiple segmentations, this function returns a dictionary.