synapse_net.inference.actin

 1from typing import Optional, Dict, List, Union, Tuple
 2
 3import numpy as np
 4import torch
 5
 6from skimage.measure import label
 7from synapse_net.inference.util import apply_size_filter, get_prediction, _Scaler
 8
 9
# TODO: How exactly do we post-process the actin?
# Do we want to run an instance segmentation to extract
# individual fibers?
# For now we only do connected components to remove small
# fragments and then binarize again.
def segment_actin(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    foreground_threshold: float = 0.5,
    min_size: int = 0,
    verbose: bool = True,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment actin in an input volume.

    The first prediction channel (foreground) is thresholded. If `min_size` is
    positive, connected components smaller than it are removed and the result
    is binarized again.

    Args:
        input_volume: The input volume to segment.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        foreground_threshold: Threshold for binarizing the foreground prediction.
            Voxels with a prediction strictly above this value are foreground.
        min_size: The minimum size of a connected actin component. Smaller
            components are removed. A value of 0 disables the size filter.
        verbose: Whether to print progress and timing information.
        return_predictions: Whether to return the foreground prediction alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        mask: An optional mask that is used to restrict the segmentation.
            Voxels where the mask is zero are excluded from the segmentation.

    Returns:
        The binary segmentation (dtype uint8) as a numpy array, or a tuple containing
        the segmentation and the foreground prediction if return_predictions is True.
    """
    if verbose:
        print("Segmenting actin in volume of shape", input_volume.shape)
    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    # Bring the mask to the same scale as the input so it lines up with the prediction.
    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)

    # Run the prediction. Only the first channel (foreground) is used for now.
    pred = get_prediction(input_volume, model=model, model_path=model_path, tiling=tiling, verbose=verbose)
    foreground = pred[0]

    # TODO proper segmentation procedure
    # NOTE: actin fiber recall may improve by choosing a lower foreground threshold
    seg = foreground > foreground_threshold

    # Exclude everything outside of the mask, before the size filter, so that
    # fragments left after masking are also subject to `min_size`.
    if mask is not None:
        seg = np.logical_and(seg, mask != 0)

    if min_size > 0:
        seg = label(seg)
        seg = apply_size_filter(seg, min_size, verbose=verbose)
        seg = seg > 0

    # Return the same dtype regardless of whether the size filter ran.
    seg = seg.astype("uint8")
    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        # The foreground is a continuous prediction, not a label image,
        # so it must not be rescaled like a segmentation.
        foreground = scaler.rescale_output(foreground, is_segmentation=False)
        return seg, foreground
    return seg
def segment_actin( input_volume: numpy.ndarray, model_path: Optional[str] = None, model: Optional[torch.nn.modules.module.Module] = None, tiling: Optional[Dict[str, Dict[str, int]]] = None, foreground_threshold: float = 0.5, min_size: int = 0, verbose: bool = True, return_predictions: bool = False, scale: Optional[List[float]] = None, mask: Optional[numpy.ndarray] = None) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]:
def segment_actin(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    foreground_threshold: float = 0.5,
    min_size: int = 0,
    verbose: bool = True,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment actin in an input volume.

    The first prediction channel (foreground) is thresholded. If `min_size` is
    positive, connected components smaller than it are removed and the result
    is binarized again.

    Args:
        input_volume: The input volume to segment.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        foreground_threshold: Threshold for binarizing the foreground prediction.
            Voxels with a prediction strictly above this value are foreground.
        min_size: The minimum size of a connected actin component. Smaller
            components are removed. A value of 0 disables the size filter.
        verbose: Whether to print progress and timing information.
        return_predictions: Whether to return the foreground prediction alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        mask: An optional mask that is used to restrict the segmentation.
            Voxels where the mask is zero are excluded from the segmentation.

    Returns:
        The binary segmentation (dtype uint8) as a numpy array, or a tuple containing
        the segmentation and the foreground prediction if return_predictions is True.
    """
    if verbose:
        print("Segmenting actin in volume of shape", input_volume.shape)
    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    # Bring the mask to the same scale as the input so it lines up with the prediction.
    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)

    # Run the prediction. Only the first channel (foreground) is used for now.
    pred = get_prediction(input_volume, model=model, model_path=model_path, tiling=tiling, verbose=verbose)
    foreground = pred[0]

    # TODO proper segmentation procedure
    # NOTE: actin fiber recall may improve by choosing a lower foreground threshold
    seg = foreground > foreground_threshold

    # Exclude everything outside of the mask, before the size filter, so that
    # fragments left after masking are also subject to `min_size`.
    if mask is not None:
        seg = np.logical_and(seg, mask != 0)

    if min_size > 0:
        seg = label(seg)
        seg = apply_size_filter(seg, min_size, verbose=verbose)
        seg = seg > 0

    # Return the same dtype regardless of whether the size filter ran.
    seg = seg.astype("uint8")
    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        # The foreground is a continuous prediction, not a label image,
        # so it must not be rescaled like a segmentation.
        foreground = scaler.rescale_output(foreground, is_segmentation=False)
        return seg, foreground
    return seg

Segment actin in an input volume.

Arguments:
  • input_volume: The input volume to segment.
  • model_path: The path to the model checkpoint if model is not provided.
  • model: Pre-loaded model. Either model_path or model is required.
  • tiling: The tiling configuration for the prediction.
  • foreground_threshold: Threshold for binarizing foreground predictions.
  • min_size: The minimum size of an actin fiber to be considered.
  • verbose: Whether to print timing information.
  • return_predictions: Whether to return the foreground prediction alongside the segmentation.
  • scale: The scale factor to use for rescaling the input volume before prediction.
  • mask: An optional mask that is used to restrict the segmentation.
Returns:

The segmentation mask as a numpy array, or a tuple containing the segmentation mask and the foreground prediction if return_predictions is True.