synapse_net.inference.ribbon_synapse
from typing import Dict, Sequence, Optional, Union

import numpy as np
import torch

from synapse_net.inference.util import get_prediction, _Scaler


def segment_ribbon_synapse_structures(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    structure_names: Sequence[str] = ("ribbon", "PD", "membrane"),
    verbose: bool = False,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    threshold: Optional[Union[float, Dict[str, float]]] = None,
    scale: Optional[Sequence[float]] = None,
    mask: Optional[np.ndarray] = None,
) -> Dict[str, np.ndarray]:
    """Segment ribbon synapse structures.

    Args:
        input_volume: The input volume to segment.
        model_path: The path to the model checkpoint if 'model' is not provided.
        model: Pre-loaded model. Either model_path or model is required.
        structure_names: Names of the structures to be segmented, one per prediction channel.
            The default network segments the ribbon, presynaptic density (PD) and local membrane.
        verbose: Whether to print progress information. Also forwarded to the scaler
            and to the prediction function.
        tiling: The tiling configuration for the prediction.
        threshold: The threshold for binarizing predictions. Either a single number that is
            applied to all structures, or a dictionary with one threshold per structure name.
            If None, the predictions are returned without binarization.
        scale: The scale factor to use for rescaling the input volume before prediction.
        mask: An optional mask that is used to restrict the segmentation.

    Returns:
        A dictionary mapping each structure name to its prediction, rescaled back via the scaler.
        The predictions are binarized if a threshold is given.

    Raises:
        ValueError: If the number of structure names does not match the number of
            prediction channels.
    """
    if verbose:
        print("Segmenting ribbon synapse structures in volume of shape", input_volume.shape)
    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)
    predictions = get_prediction(
        input_volume, model_path=model_path, model=model, tiling=tiling, mask=mask, verbose=verbose
    )
    # Raise explicitly instead of asserting, so that the check survives 'python -O'.
    if len(structure_names) != predictions.shape[0]:
        raise ValueError(
            f"Expected one structure name per prediction channel, got {len(structure_names)} "
            f"names for {predictions.shape[0]} channels."
        )

    predictions = {
        name: scaler.rescale_output(predictions[i], is_segmentation=False) for i, name in enumerate(structure_names)
    }
    if threshold is not None:
        # We can either have a single threshold value or a threshold per structure
        # that is given as a dictionary. Any real number (python int / float or numpy scalar)
        # counts as a single value; checking only for 'float' would try to index e.g. an int.
        is_single_value = isinstance(threshold, (int, float, np.number))
        for name in structure_names:
            this_threshold = threshold if is_single_value else threshold[name]
            predictions[name] = predictions[name] > this_threshold

    return predictions
def
segment_ribbon_synapse_structures( input_volume: numpy.ndarray, model_path: Optional[str] = None, model: Optional[torch.nn.modules.module.Module] = None, structure_names: Sequence[str] = ('ribbon', 'PD', 'membrane'), verbose: bool = False, tiling: Optional[Dict[str, Dict[str, int]]] = None, threshold: Union[float, Dict[str, float], NoneType] = None, scale: Optional[Sequence[float]] = None, mask: Optional[numpy.ndarray] = None) -> numpy.ndarray:
def segment_ribbon_synapse_structures(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    structure_names: Sequence[str] = ("ribbon", "PD", "membrane"),
    verbose: bool = False,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    threshold: Optional[Union[float, Dict[str, float]]] = None,
    scale: Optional[Sequence[float]] = None,
    mask: Optional[np.ndarray] = None,
) -> Dict[str, np.ndarray]:
    """Segment ribbon synapse structures.

    Args:
        input_volume: The input volume to segment.
        model_path: The path to the model checkpoint if 'model' is not provided.
        model: Pre-loaded model. Either model_path or model is required.
        structure_names: Names of the structures to be segmented, one per prediction channel.
            The default network segments the ribbon, presynaptic density (PD) and local membrane.
        verbose: Whether to print progress information. Also forwarded to the scaler
            and to the prediction function.
        tiling: The tiling configuration for the prediction.
        threshold: The threshold for binarizing predictions. Either a single number that is
            applied to all structures, or a dictionary with one threshold per structure name.
            If None, the predictions are returned without binarization.
        scale: The scale factor to use for rescaling the input volume before prediction.
        mask: An optional mask that is used to restrict the segmentation.

    Returns:
        A dictionary mapping each structure name to its prediction, rescaled back via the scaler.
        The predictions are binarized if a threshold is given.

    Raises:
        ValueError: If the number of structure names does not match the number of
            prediction channels.
    """
    if verbose:
        print("Segmenting ribbon synapse structures in volume of shape", input_volume.shape)
    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)
    predictions = get_prediction(
        input_volume, model_path=model_path, model=model, tiling=tiling, mask=mask, verbose=verbose
    )
    # Raise explicitly instead of asserting, so that the check survives 'python -O'.
    if len(structure_names) != predictions.shape[0]:
        raise ValueError(
            f"Expected one structure name per prediction channel, got {len(structure_names)} "
            f"names for {predictions.shape[0]} channels."
        )

    predictions = {
        name: scaler.rescale_output(predictions[i], is_segmentation=False) for i, name in enumerate(structure_names)
    }
    if threshold is not None:
        # We can either have a single threshold value or a threshold per structure
        # that is given as a dictionary. Any real number (python int / float or numpy scalar)
        # counts as a single value; checking only for 'float' would try to index e.g. an int.
        is_single_value = isinstance(threshold, (int, float, np.number))
        for name in structure_names:
            this_threshold = threshold if is_single_value else threshold[name]
            predictions[name] = predictions[name] > this_threshold

    return predictions
Segment ribbon synapse structures.
Arguments:
- input_volume: The input volume to segment.
- model_path: The path to the model checkpoint if 'model' is not provided.
- model: Pre-loaded model. Either model_path or model is required.
- structure_names: Names of the structures to be segmented. The default network segments the ribbon, presynaptic density (PD) and local membrane.
- tiling: The tiling configuration for the prediction.
- verbose: Whether to print timing information.
- threshold: The threshold for binarizing predictions.
- scale: The scale factor to use for rescaling the input volume before prediction.
- mask: An optional mask that is used to restrict the segmentation.
Returns:
A dictionary mapping each structure name to its prediction as a numpy array. The predictions are binarized if a threshold is given.