synapse_net.inference.cristae

 1import time
 2from typing import Dict, List, Optional, Tuple, Union
 3
 4import elf.parallel as parallel
 5import numpy as np
 6import torch
 7
 8from synapse_net.inference.util import get_prediction, _Scaler
 9
10
def _run_segmentation(
    foreground, verbose, min_size,
    # blocking shapes for parallel computation
    block_shape=(128, 256, 256),
):
    """Turn a foreground prediction into a size-filtered binary mask.

    Args:
        foreground: The foreground prediction; voxels with a value > 0.5 are foreground.
        verbose: Whether to print timing information.
        min_size: Connected components with fewer voxels than this are removed.
        block_shape: The block shape used for the blockwise (parallel) computations.

    Returns:
        A binary mask with value 1 for kept components and 0 elsewhere.
    """
    # Label connected components of the thresholded foreground (no watershed involved).
    start = time.time()
    labels = parallel.label(foreground > 0.5, block_shape=block_shape, verbose=verbose)
    if verbose:
        print("Compute connected components in", time.time() - start, "s")

    # Count voxels per component and zero out every component below min_size.
    start = time.time()
    label_ids, counts = parallel.unique(labels, return_counts=True, block_shape=block_shape, verbose=verbose)
    too_small = label_ids[counts < min_size]
    labels[np.isin(labels, too_small)] = 0
    if verbose:
        print("Size filter in", time.time() - start, "s")

    # The individual component ids are discarded; only the binary mask is returned.
    return np.where(labels > 0, 1, 0)
32
33
def segment_cristae(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    min_size: int = 500,
    verbose: bool = True,
    distance_based_segmentation: bool = False,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment cristae in an input volume.

    The foreground channel of the prediction is thresholded at 0.5, split into
    connected components, components smaller than `min_size` voxels are removed,
    and the result is returned as a binary mask.

    Args:
        input_volume: The input volume to segment. Expects 2 3D volumes: raw and mitochondria.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        min_size: The minimum size, in voxels of the (possibly rescaled) prediction,
            of a connected cristae component to be kept.
        verbose: Whether to print timing information.
        distance_based_segmentation: Currently ignored. Cristae are always segmented
            via connected components of the thresholded foreground prediction.
        return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        mask: An optional mask that is used to restrict the segmentation.

    Returns:
        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
        and the predictions if return_predictions is True.
    """
    if verbose:
        print("Segmenting cristae in volume of shape", input_volume.shape)
    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    # The mask must live in the same (scaled) space as the input volume.
    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)
    pred = get_prediction(
        input_volume, model_path=model_path, model=model, mask=mask,
        tiling=tiling, with_channels=True, verbose=verbose
    )
    # Channel 0 is the foreground; the boundary channel is only exposed via `pred`.
    foreground = pred[0]
    seg = _run_segmentation(foreground, verbose=verbose, min_size=min_size)
    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        pred = scaler.rescale_output(pred, is_segmentation=False)
        return seg, pred
    return seg
def segment_cristae( input_volume: numpy.ndarray, model_path: Optional[str] = None, model: Optional[torch.nn.modules.module.Module] = None, tiling: Optional[Dict[str, Dict[str, int]]] = None, min_size: int = 500, verbose: bool = True, distance_based_segmentation: bool = False, return_predictions: bool = False, scale: Optional[List[float]] = None, mask: Optional[numpy.ndarray] = None) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]:
def segment_cristae(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    min_size: int = 500,
    verbose: bool = True,
    distance_based_segmentation: bool = False,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment cristae in an input volume.

    The foreground channel of the prediction is thresholded at 0.5, split into
    connected components, components smaller than `min_size` voxels are removed,
    and the result is returned as a binary mask.

    Args:
        input_volume: The input volume to segment. Expects 2 3D volumes: raw and mitochondria.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        min_size: The minimum size, in voxels of the (possibly rescaled) prediction,
            of a connected cristae component to be kept.
        verbose: Whether to print timing information.
        distance_based_segmentation: Currently ignored. Cristae are always segmented
            via connected components of the thresholded foreground prediction.
        return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        mask: An optional mask that is used to restrict the segmentation.

    Returns:
        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
        and the predictions if return_predictions is True.
    """
    if verbose:
        print("Segmenting cristae in volume of shape", input_volume.shape)
    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    # The mask must live in the same (scaled) space as the input volume.
    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)
    pred = get_prediction(
        input_volume, model_path=model_path, model=model, mask=mask,
        tiling=tiling, with_channels=True, verbose=verbose
    )
    # Channel 0 is the foreground; the boundary channel is only exposed via `pred`.
    foreground = pred[0]
    seg = _run_segmentation(foreground, verbose=verbose, min_size=min_size)
    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        pred = scaler.rescale_output(pred, is_segmentation=False)
        return seg, pred
    return seg

Segment cristae in an input volume.

Arguments:
  • input_volume: The input volume to segment. Expects two 3D volumes: raw and mitochondria.
  • model_path: The path to the model checkpoint if model is not provided.
  • model: Pre-loaded model. Either model_path or model is required.
  • tiling: The tiling configuration for the prediction.
  • min_size: The minimum size, in voxels, of a connected cristae component to be kept; smaller components are removed.
  • verbose: Whether to print timing information.
  • distance_based_segmentation: Currently ignored by the implementation; cristae are always segmented via connected components of the thresholded foreground prediction.
  • return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
  • scale: The scale factor to use for rescaling the input volume before prediction.
  • mask: An optional mask that is used to restrict the segmentation.
Returns:

The segmentation mask as a numpy array, or a tuple containing the segmentation mask and the predictions if return_predictions is True.