synapse_net.inference.actin
from typing import Optional, Dict, List, Union, Tuple

import numpy as np
import torch

from skimage.measure import label
from synapse_net.inference.util import apply_size_filter, get_prediction, _Scaler


# TODO: How exactly do we post-process the actin?
# Do we want to run an instance segmentation to extract
# individual fibers?
# For now we only do connected components to remove small
# fragments and then binarize again.
def segment_actin(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    foreground_threshold: float = 0.5,
    min_size: int = 0,
    verbose: bool = True,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment actin in an input volume.

    The foreground prediction of the network is thresholded. If `min_size` is set,
    connected components smaller than it are removed before binarizing again.

    Args:
        input_volume: The input volume to segment.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        foreground_threshold: Threshold for binarizing foreground predictions.
        min_size: The minimum size of an actin fiber to be considered.
        verbose: Whether to print progress information (e.g. the input shape);
            also forwarded to the prediction and size-filter helpers.
        return_predictions: Whether to return the foreground prediction alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        mask: An optional mask that is used to restrict the segmentation.
            Voxels where the mask is zero are set to background.

    Returns:
        The binary (uint8) segmentation mask as a numpy array, or a tuple containing the
        segmentation mask and the foreground prediction if return_predictions is True.
    """
    if verbose:
        print("Segmenting actin in volume of shape", input_volume.shape)
    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    # Bring the mask into the same (rescaled) space as the input, so that it
    # can be applied to the segmentation computed from the prediction below.
    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)

    # Run the prediction. Channel 0 is the foreground prediction; further
    # channels (e.g. boundaries) are currently not used for actin.
    pred = get_prediction(input_volume, model=model, model_path=model_path, tiling=tiling, verbose=verbose)
    foreground = pred[0]

    # TODO proper segmentation procedure
    # NOTE: actin fiber recall may improve by choosing a lower foreground threshold
    seg = foreground > foreground_threshold
    # Restrict the segmentation to the mask. This happens before the size filter
    # so that fragments cut by the mask border are filtered by their remaining size.
    if mask is not None:
        seg = np.logical_and(seg, mask)
    if min_size > 0:
        seg = label(seg)
        seg = apply_size_filter(seg, min_size, verbose=verbose)
    seg = (seg > 0).astype("uint8")
    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        foreground = scaler.rescale_output(foreground, is_segmentation=True)
        return seg, foreground
    return seg
def
segment_actin( input_volume: numpy.ndarray, model_path: Optional[str] = None, model: Optional[torch.nn.modules.module.Module] = None, tiling: Optional[Dict[str, Dict[str, int]]] = None, foreground_threshold: float = 0.5, min_size: int = 0, verbose: bool = True, return_predictions: bool = False, scale: Optional[List[float]] = None, mask: Optional[numpy.ndarray] = None) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]:
def segment_actin(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    foreground_threshold: float = 0.5,
    min_size: int = 0,
    verbose: bool = True,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment actin in an input volume.

    The foreground prediction of the network is thresholded. If `min_size` is set,
    connected components smaller than it are removed before binarizing again.

    Args:
        input_volume: The input volume to segment.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        foreground_threshold: Threshold for binarizing foreground predictions.
        min_size: The minimum size of an actin fiber to be considered.
        verbose: Whether to print progress information (e.g. the input shape);
            also forwarded to the prediction and size-filter helpers.
        return_predictions: Whether to return the foreground prediction alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        mask: An optional mask that is used to restrict the segmentation.
            Voxels where the mask is zero are set to background.

    Returns:
        The binary (uint8) segmentation mask as a numpy array, or a tuple containing the
        segmentation mask and the foreground prediction if return_predictions is True.
    """
    if verbose:
        print("Segmenting actin in volume of shape", input_volume.shape)
    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    # Bring the mask into the same (rescaled) space as the input, so that it
    # can be applied to the segmentation computed from the prediction below.
    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)

    # Run the prediction. Channel 0 is the foreground prediction; further
    # channels (e.g. boundaries) are currently not used for actin.
    pred = get_prediction(input_volume, model=model, model_path=model_path, tiling=tiling, verbose=verbose)
    foreground = pred[0]

    # TODO proper segmentation procedure
    # NOTE: actin fiber recall may improve by choosing a lower foreground threshold
    seg = foreground > foreground_threshold
    # Restrict the segmentation to the mask. This happens before the size filter
    # so that fragments cut by the mask border are filtered by their remaining size.
    if mask is not None:
        seg = np.logical_and(seg, mask)
    if min_size > 0:
        seg = label(seg)
        seg = apply_size_filter(seg, min_size, verbose=verbose)
    seg = (seg > 0).astype("uint8")
    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        foreground = scaler.rescale_output(foreground, is_segmentation=True)
        return seg, foreground
    return seg
Segment actin in an input volume.
Arguments:
- input_volume: The input volume to segment.
- model_path: The path to the model checkpoint if `model` is not provided.
- model: Pre-loaded model. Either `model_path` or `model` is required.
- tiling: The tiling configuration for the prediction.
- foreground_threshold: Threshold for binarizing foreground predictions.
- min_size: The minimum size of an actin fiber to be considered.
- verbose: Whether to print timing information.
- return_predictions: Whether to return the foreground prediction alongside the segmentation.
- scale: The scale factor to use for rescaling the input volume before prediction.
- mask: An optional mask that is used to restrict the segmentation.
Returns:
The segmentation mask as a numpy array, or a tuple containing the segmentation mask and the foreground prediction if return_predictions is True.