synapse_net.inference.ribbon_synapse

 1from typing import Dict, Sequence, Optional, Union
 2
 3import numpy as np
 4import torch
 5
 6from synapse_net.inference.util import get_prediction, _Scaler
 7
 8
 9def segment_ribbon_synapse_structures(
10    input_volume: np.ndarray,
11    model_path: Optional[str] = None,
12    model: Optional[torch.nn.Module] = None,
13    structure_names: Sequence[str] = ("ribbon", "PD", "membrane"),
14    verbose: bool = False,
15    tiling: Optional[Dict[str, Dict[str, int]]] = None,
16    threshold: Optional[Union[float, Dict[str, float]]] = None,
17    scale: Optional[Sequence[float]] = None,
18    mask: Optional[np.ndarray] = None,
19) -> np.ndarray:
20    """Segment ribbon synapse structures.
21
22    Args:
23        input_volume: The input volume to segment.
24        model_path: The path to the model checkpoint if 'model' is not provided.
25        model: Pre-loaded model. Either model_path or model is required.
26        structure_names: Names of the structures to be segmented.
27            The default network segments the ribbon, presynaptic density (pd) an local memrane.
28        tiling: The tiling configuration for the prediction.
29        verbose: Whether to print timing information.
30        threshold: The threshold for binarizing predictions.
31        scale: The scale factor to use for rescaling the input volume before prediction.
32        mask: An optional mask that is used to restrict the segmentation.
33
34    Returns:
35        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
36        and the predictions if return_predictions is True.
37    """
38    if verbose:
39        print("Segmenting ribbon synapse structures in volume of shape", input_volume.shape)
40    # Create the scaler to handle prediction with a different scaling factor.
41    scaler = _Scaler(scale, verbose)
42    input_volume = scaler.scale_input(input_volume)
43
44    if mask is not None:
45        mask = scaler.scale_input(mask, is_segmentation=True)
46    predictions = get_prediction(
47        input_volume, model_path=model_path, model=model, tiling=tiling, mask=mask, verbose=verbose
48    )
49    assert len(structure_names) == predictions.shape[0]
50
51    predictions = {
52        name: scaler.rescale_output(predictions[i], is_segmentation=False) for i, name in enumerate(structure_names)
53    }
54    if threshold is not None:
55        for name in structure_names:
56            # We can either have a single threshold value or a threshold per structure
57            # that is given as a dictionary.
58            this_threshold = threshold if isinstance(threshold, float) else threshold[name]
59            predictions[name] = predictions[name] > this_threshold
60
61    return predictions
def segment_ribbon_synapse_structures( input_volume: numpy.ndarray, model_path: Optional[str] = None, model: Optional[torch.nn.modules.module.Module] = None, structure_names: Sequence[str] = ('ribbon', 'PD', 'membrane'), verbose: bool = False, tiling: Optional[Dict[str, Dict[str, int]]] = None, threshold: Union[float, Dict[str, float], NoneType] = None, scale: Optional[Sequence[float]] = None, mask: Optional[numpy.ndarray] = None) -> numpy.ndarray:
def segment_ribbon_synapse_structures(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    structure_names: Sequence[str] = ("ribbon", "PD", "membrane"),
    verbose: bool = False,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    threshold: Optional[Union[float, Dict[str, float]]] = None,
    scale: Optional[Sequence[float]] = None,
    mask: Optional[np.ndarray] = None,
) -> Dict[str, np.ndarray]:
    """Segment ribbon synapse structures.

    Args:
        input_volume: The input volume to segment.
        model_path: The path to the model checkpoint if 'model' is not provided.
        model: Pre-loaded model. Either model_path or model is required.
        structure_names: Names of the structures to be segmented, one per channel
            of the network output. The default network segments the ribbon,
            presynaptic density (PD) and local membrane.
        verbose: Whether to print progress/timing information. Also forwarded to
            the scaler and the prediction helper.
        tiling: The tiling configuration for the prediction.
        threshold: The threshold for binarizing predictions. Either a single value
            applied to every structure, or a dictionary mapping each structure name
            to its own threshold. If None, the predictions are not binarized.
        scale: The scale factor to use for rescaling the input volume before prediction.
        mask: An optional mask that is used to restrict the segmentation.

    Returns:
        A dictionary mapping each structure name to its prediction. Each channel is
        passed back through the scaler's `rescale_output`. If `threshold` is given,
        the values are boolean masks (prediction > threshold).

    Raises:
        ValueError: If the number of structure names does not match the number of
            channels returned by the prediction.
        KeyError: If `threshold` is a dictionary without an entry for a structure.
    """
    if verbose:
        print("Segmenting ribbon synapse structures in volume of shape", input_volume.shape)

    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)
    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)

    predictions = get_prediction(
        input_volume, model_path=model_path, model=model, tiling=tiling, mask=mask, verbose=verbose
    )
    # Explicit check instead of `assert`, so the validation survives `python -O`.
    n_channels = predictions.shape[0]
    if len(structure_names) != n_channels:
        raise ValueError(
            f"Got {len(structure_names)} structure names {list(structure_names)}, "
            f"but the prediction has {n_channels} channels."
        )

    predictions = {
        name: scaler.rescale_output(predictions[i], is_segmentation=False)
        for i, name in enumerate(structure_names)
    }

    if threshold is not None:
        # Either a single threshold for all structures or one per structure given
        # as a dictionary. Dispatch on the dict type rather than on `float`, so that
        # int and numpy scalar thresholds (e.g. 0, 1, np.float32(0.5)) also work.
        for name in structure_names:
            this_threshold = threshold[name] if isinstance(threshold, dict) else threshold
            predictions[name] = predictions[name] > this_threshold

    return predictions

Segment ribbon synapse structures.

Arguments:
  • input_volume: The input volume to segment.
  • model_path: The path to the model checkpoint if 'model' is not provided.
  • model: Pre-loaded model. Either model_path or model is required.
  • structure_names: Names of the structures to be segmented. The default network segments the ribbon, presynaptic density (PD) and local membrane.
  • tiling: The tiling configuration for the prediction.
  • verbose: Whether to print timing information.
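  • threshold: The threshold for binarizing predictions. Either a single value applied to all structures, or a dictionary with one threshold per structure name. If not given, the predictions are not binarized.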
  • scale: The scale factor to use for rescaling the input volume before prediction.
  • mask: An optional mask that is used to restrict the segmentation.
Returns:

A dictionary that maps each structure name to its prediction. If a threshold is given, the predictions are binarized (prediction > threshold).