synapse_net.inference.cristae

  1import time
  2from typing import Dict, List, Optional, Tuple, Union
  3
  4import elf.parallel as parallel
  5from elf.wrapper.base import (
  6    SimpleTransformationWrapper,
  7    MultiTransformationWrapper,
  8)
  9from skimage.morphology import binary_erosion, ball
 10from skimage.measure import regionprops
 11import numpy as np
 12import torch
 13
 14from synapse_net.inference.util import get_prediction, _Scaler
 15
 16
 17def _erode_instances(mito_data, erode_voxels, verbose):
 18    """Erodes instances globally and returns a memory-efficient boolean mask."""
 19    if verbose:
 20        t_erode = time.time()
 21        print("Eroding mitochondria instances globally...")
 22
 23    footprint = ball(erode_voxels)
 24    props = regionprops(mito_data)
 25
 26    # Allocate a boolean array
 27    eroded_binary_mask = np.zeros(mito_data.shape, dtype=bool)
 28
 29    for prop in props:
 30        sl = prop.slice
 31
 32        # Isolate this specific instance within its bounding box
 33        instance_mask = (mito_data[sl] == prop.label)
 34
 35        # Apply erosion
 36        eroded_mask = binary_erosion(instance_mask, footprint=footprint)
 37
 38        # Write True to the newly eroded locations in our boolean array
 39        eroded_binary_mask[sl][eroded_mask] = True
 40
 41    if verbose:
 42        print(f"Instance erosion completed in {time.time() - t_erode:.2f} s")
 43
 44    return eroded_binary_mask
 45
 46
 47def _run_segmentation(
 48    foreground, verbose, min_size,
 49    # blocking shapes for parallel computation
 50    block_shape=(128, 256, 256),
 51    mito_seg=None,
 52    erode_voxels=3,
 53):
 54    mito_seg = _erode_instances(mito_seg, erode_voxels, verbose)
 55
 56    # Mask the foreground lazily
 57    # Even though mito_seg is now in memory, foreground might not be.
 58    # MultiTransformationWrapper safely handles this mix.
 59    def mask_foreground(inputs):
 60        fg_block, mito_block = inputs
 61        return np.where(mito_block != 0, fg_block, 0)
 62
 63    foreground = MultiTransformationWrapper(
 64        mask_foreground,
 65        foreground,
 66        mito_seg,
 67        apply_to_list=True
 68    )
 69
 70    # Apply the threshold lazily
 71    def threshold_block(block):
 72        return block > 0.5
 73
 74    binary_foreground = SimpleTransformationWrapper(
 75        foreground,
 76        transformation=threshold_block
 77    )
 78
 79    t0 = time.time()
 80    seg = parallel.label(binary_foreground, block_shape=block_shape, verbose=verbose)
 81    if verbose:
 82        print("Compute connected components in", time.time() - t0, "s")
 83
 84    # Size filter
 85    if min_size > 0:
 86        t0 = time.time()
 87        parallel.size_filter(seg, out=seg, min_size=min_size, block_shape=block_shape, verbose=verbose)
 88        if verbose:
 89            print("Size filter in", time.time() - t0, "s")
 90
 91    seg = np.where(seg > 0, 1, 0)
 92
 93    return seg
 94
 95
 96def segment_cristae(
 97    input_volume: np.ndarray,
 98    model_path: Optional[str] = None,
 99    model: Optional[torch.nn.Module] = None,
100    tiling: Optional[Dict[str, Dict[str, int]]] = None,
101    min_size: int = 500,
102    verbose: bool = True,
103    distance_based_segmentation: bool = False,
104    return_predictions: bool = False,
105    scale: Optional[List[float]] = None,
106    mask: Optional[np.ndarray] = None,
107    **kwargs
108) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
109    """Segment cristae in an input volume.
110
111    Args:
112        input_volume: The input volume to segment. Expects 2 3D volumes: raw and mitochondria
113        model_path: The path to the model checkpoint if `model` is not provided.
114        model: Pre-loaded model. Either `model_path` or `model` is required.
115        tiling: The tiling configuration for the prediction.
116        min_size: The minimum size of a cristae to be considered.
117        verbose: Whether to print timing information.
118        distance_based_segmentation: Whether to use distance-based segmentation.
119        return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
120        scale: The scale factor to use for rescaling the input volume before prediction.
121        mask: An optional mask that is used to restrict the segmentation.
122
123    Returns:
124        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
125        and the predictions if return_predictions is True.
126    """
127    mitochondria = kwargs.pop("extra_segmentation", None)
128    if mitochondria is None:
129        # try extract from input volume
130        if input_volume.ndim == 4:
131            mitochondria = input_volume[1]
132            input_volume = input_volume[0]
133    if mitochondria is None:
134        raise ValueError("Mitochondria segmentation is required")
135    with_channels = kwargs.pop("with_channels", True)
136    channels_to_standardize = kwargs.pop("channels_to_standardize", [0])
137    if verbose:
138        print("Segmenting cristae in volume of shape", input_volume.shape)
139    # Create the scaler to handle prediction with a different scaling factor.
140    scaler = _Scaler(scale, verbose)
141    # rescale each channel
142    volume = scaler.scale_input(input_volume)
143    mito_seg = scaler.scale_input(mitochondria, is_segmentation=True)
144    input_volume = np.stack([volume, mito_seg], axis=0)
145
146    # Run prediction and segmentation.
147    if mask is not None:
148        mask = scaler.scale_input(mask, is_segmentation=True)
149    pred = get_prediction(
150        input_volume, model_path=model_path, model=model, mask=mask,
151        tiling=tiling, with_channels=with_channels, channels_to_standardize=channels_to_standardize, verbose=verbose
152    )
153    foreground, boundaries = pred[:2]
154    seg = _run_segmentation(foreground, verbose=verbose, min_size=min_size, mito_seg=mito_seg)
155    seg = scaler.rescale_output(seg, is_segmentation=True)
156
157    if return_predictions:
158        pred = scaler.rescale_output(pred, is_segmentation=False)
159        return seg, pred
160    return seg
def segment_cristae(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    min_size: int = 500,
    verbose: bool = True,
    distance_based_segmentation: bool = False,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
    **kwargs
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment cristae in an input volume.

    The raw data and the mitochondria segmentation are stacked into a two-channel
    network input. The predicted foreground is post-processed by
    ``_run_segmentation`` together with the mitochondria segmentation.

    Args:
        input_volume: The input volume to segment. Either a 4D array holding the raw data
            at index 0 and the mitochondria segmentation at index 1 of the first axis, or a
            3D raw volume if the mitochondria segmentation is passed as ``extra_segmentation``.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        min_size: The minimum size of a crista (connected component) to be kept.
        verbose: Whether to print timing information.
        distance_based_segmentation: Currently not used by this function.
        return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        mask: An optional mask that is passed to the prediction to restrict the segmentation.
        kwargs: Recognized keys are ``extra_segmentation`` (the mitochondria segmentation;
            when given, ``input_volume`` is not split), ``with_channels`` (default True) and
            ``channels_to_standardize`` (default [0]); the latter two are forwarded to
            ``get_prediction``. Any other keyword arguments are ignored.

    Returns:
        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
        and the predictions if return_predictions is True.

    Raises:
        ValueError: If no mitochondria segmentation is available.
    """
    mitochondria = kwargs.pop("extra_segmentation", None)
    # Without an explicit segmentation, take raw data and mitochondria from the first axis.
    if mitochondria is None and input_volume.ndim == 4:
        input_volume, mitochondria = input_volume[0], input_volume[1]
    if mitochondria is None:
        raise ValueError("Mitochondria segmentation is required")

    with_channels = kwargs.pop("with_channels", True)
    channels_to_standardize = kwargs.pop("channels_to_standardize", [0])

    if verbose:
        print("Segmenting cristae in volume of shape", input_volume.shape)

    # The scaler rescales the inputs before prediction and the outputs afterwards.
    scaler = _Scaler(scale, verbose)
    raw_scaled = scaler.scale_input(input_volume)
    mito_scaled = scaler.scale_input(mitochondria, is_segmentation=True)
    network_input = np.stack([raw_scaled, mito_scaled], axis=0)

    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)

    pred = get_prediction(
        network_input,
        model_path=model_path,
        model=model,
        mask=mask,
        tiling=tiling,
        with_channels=with_channels,
        channels_to_standardize=channels_to_standardize,
        verbose=verbose,
    )
    # Only the foreground channel is needed; the boundary channel is unused here.
    foreground, _ = pred[:2]

    seg = _run_segmentation(foreground, verbose=verbose, min_size=min_size, mito_seg=mito_scaled)
    seg = scaler.rescale_output(seg, is_segmentation=True)

    if not return_predictions:
        return seg
    return seg, scaler.rescale_output(pred, is_segmentation=False)

Segment cristae in an input volume.

Arguments:
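  • input_volume: The input volume to segment. Expects a 4D array that stacks two 3D volumes along the first axis: the raw data and the mitochondria segmentation. If the mitochondria segmentation is passed via the extra_segmentation keyword argument instead, a 3D raw volume is expected.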
  • model_path: The path to the model checkpoint if model is not provided.
  • model: Pre-loaded model. Either model_path or model is required.
  • tiling: The tiling configuration for the prediction.
  • min_size: The minimum size of a crista (connected component) to be kept.
  • verbose: Whether to print timing information.
  • distance_based_segmentation: Whether to use distance-based segmentation. Currently not used by the implementation.
  • return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
  • scale: The scale factor to use for rescaling the input volume before prediction.
  • mask: An optional mask that is used to restrict the segmentation.
Returns:

The segmentation mask as a numpy array, or a tuple containing the segmentation mask and the predictions if return_predictions is True.