synapse_net.inference.cristae

  1import time
  2from typing import Dict, List, Optional, Tuple, Union
  3
  4import elf.parallel as parallel
  5from elf.wrapper.base import (
  6    SimpleTransformationWrapper,
  7    MultiTransformationWrapper,
  8)
  9from skimage.morphology import binary_erosion, ball
 10from skimage.measure import regionprops
 11import numpy as np
 12import torch
 13
 14from synapse_net.inference.util import get_prediction, _Scaler
 15
 16
 17def _erode_instances(mito_data, erode_voxels, verbose):
 18    """Erodes instances globally and returns a memory-efficient boolean mask."""
 19    if verbose:
 20        t_erode = time.time()
 21        print(f"Eroding mitochondria instances globally by {erode_voxels} voxels ...")
 22
 23    footprint = ball(erode_voxels)
 24    props = regionprops(mito_data)
 25
 26    # Allocate a boolean array
 27    eroded_binary_mask = np.zeros(mito_data.shape, dtype=bool)
 28
 29    for prop in props:
 30        sl = prop.slice
 31
 32        # Isolate this specific instance within its bounding box
 33        instance_mask = (mito_data[sl] == prop.label)
 34
 35        # Apply erosion
 36        eroded_mask = binary_erosion(instance_mask, footprint=footprint)
 37
 38        # Write True to the newly eroded locations in our boolean array
 39        eroded_binary_mask[sl][eroded_mask] = True
 40
 41    if verbose:
 42        print(f"Instance erosion completed in {time.time() - t_erode:.2f} s")
 43
 44    return eroded_binary_mask
 45
 46
 47def _run_segmentation(
 48    foreground, verbose, min_size,
 49    # blocking shapes for parallel computation
 50    block_shape=(128, 256, 256),
 51    mito_seg=None,
 52    erode_voxels=3,
 53):
 54    mito_seg = _erode_instances(mito_seg, erode_voxels, verbose)
 55
 56    # Mask the foreground lazily
 57    # Even though mito_seg is now in memory, foreground might not be.
 58    # MultiTransformationWrapper safely handles this mix.
 59    def mask_foreground(inputs):
 60        fg_block, mito_block = inputs
 61        return np.where(mito_block != 0, fg_block, 0)
 62
 63    foreground = MultiTransformationWrapper(
 64        mask_foreground,
 65        foreground,
 66        mito_seg,
 67        apply_to_list=True
 68    )
 69
 70    # Apply the threshold lazily
 71    def threshold_block(block):
 72        return block > 0.5
 73
 74    binary_foreground = SimpleTransformationWrapper(
 75        foreground,
 76        transformation=threshold_block
 77    )
 78
 79    t0 = time.time()
 80    seg = parallel.label(binary_foreground, block_shape=block_shape, verbose=verbose)
 81    if verbose:
 82        print("Compute connected components in", time.time() - t0, "s")
 83
 84    # Size filter
 85    if min_size > 0:
 86        t0 = time.time()
 87        parallel.size_filter(seg, out=seg, min_size=min_size, block_shape=block_shape, verbose=verbose)
 88        if verbose:
 89            print("Size filter in", time.time() - t0, "s")
 90
 91    seg = np.where(seg > 0, 1, 0)
 92
 93    return seg
 94
 95
 96def segment_cristae(
 97    input_volume: np.ndarray,
 98    voxel_size: float,
 99    model_path: Optional[str] = None,
100    model: Optional[torch.nn.Module] = None,
101    tiling: Optional[Dict[str, Dict[str, int]]] = None,
102    min_size: int = 2000,
103    verbose: bool = True,
104    distance_based_segmentation: bool = False,
105    return_predictions: bool = False,
106    scale: Optional[List[float]] = None,
107    mask: Optional[np.ndarray] = None,
108    **kwargs
109) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
110    """Segment cristae in an input volume.
111
112    Args:
113        input_volume: The input volume to segment. Expects 2 3D volumes: raw and mitochondria
114        voxel_size: The voxel size of the model's training data.
115        model_path: The path to the model checkpoint if `model` is not provided.
116        model: Pre-loaded model. Either `model_path` or `model` is required.
117        tiling: The tiling configuration for the prediction.
118        min_size: The minimum size of a cristae to be considered.
119        verbose: Whether to print timing information.
120        distance_based_segmentation: Whether to use distance-based segmentation.
121        return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
122        scale: The scale factor to use for rescaling the input volume before prediction.
123        mask: An optional mask that is used to restrict the segmentation.
124
125    Returns:
126        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
127        and the predictions if return_predictions is True.
128    """
129    mitochondria = kwargs.pop("extra_segmentation", None)
130    if mitochondria is None:
131        # try extract from input volume
132        if input_volume.ndim == 4:
133            mitochondria = input_volume[1]
134            input_volume = input_volume[0]
135    if mitochondria is None:
136        raise ValueError("Mitochondria segmentation is required")
137    with_channels = kwargs.pop("with_channels", True)
138    channels_to_standardize = kwargs.pop("channels_to_standardize", [0])
139    if verbose:
140        print("Segmenting cristae in volume of shape", input_volume.shape)
141    # Create the scaler to handle prediction with a different scaling factor.
142    scaler = _Scaler(scale, verbose)
143    # rescale each channel
144    volume = scaler.scale_input(input_volume)
145    mito_seg = scaler.scale_input(mitochondria, is_segmentation=True)
146    input_volume = np.stack([volume, mito_seg], axis=0)
147
148    # target 10nm erosion for mitochondria
149    # voxel_size is the model's training voxel size, which is the space we erode in
150    erode_voxels = max(1, round(10.0 / voxel_size))
151
152    # Run prediction and segmentation.
153    if mask is not None:
154        mask = scaler.scale_input(mask, is_segmentation=True)
155    pred = get_prediction(
156        input_volume, model_path=model_path, model=model, mask=mask,
157        tiling=tiling, with_channels=with_channels, channels_to_standardize=channels_to_standardize, verbose=verbose
158    )
159    foreground, boundaries = pred[:2]
160    seg = _run_segmentation(foreground, verbose=verbose, min_size=min_size, mito_seg=mito_seg,
161                            erode_voxels=erode_voxels)
162    seg = scaler.rescale_output(seg, is_segmentation=True)
163
164    if return_predictions:
165        pred = scaler.rescale_output(pred, is_segmentation=False)
166        return seg, pred
167    return seg
def segment_cristae( input_volume: numpy.ndarray, voxel_size: float, model_path: Optional[str] = None, model: Optional[torch.nn.modules.module.Module] = None, tiling: Optional[Dict[str, Dict[str, int]]] = None, min_size: int = 2000, verbose: bool = True, distance_based_segmentation: bool = False, return_predictions: bool = False, scale: Optional[List[float]] = None, mask: Optional[numpy.ndarray] = None, **kwargs) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]:
 97def segment_cristae(
 98    input_volume: np.ndarray,
 99    voxel_size: float,
100    model_path: Optional[str] = None,
101    model: Optional[torch.nn.Module] = None,
102    tiling: Optional[Dict[str, Dict[str, int]]] = None,
103    min_size: int = 2000,
104    verbose: bool = True,
105    distance_based_segmentation: bool = False,
106    return_predictions: bool = False,
107    scale: Optional[List[float]] = None,
108    mask: Optional[np.ndarray] = None,
109    **kwargs
110) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
111    """Segment cristae in an input volume.
112
113    Args:
114        input_volume: The input volume to segment. Expects 2 3D volumes: raw and mitochondria
115        voxel_size: The voxel size of the model's training data.
116        model_path: The path to the model checkpoint if `model` is not provided.
117        model: Pre-loaded model. Either `model_path` or `model` is required.
118        tiling: The tiling configuration for the prediction.
119        min_size: The minimum size of a cristae to be considered.
120        verbose: Whether to print timing information.
121        distance_based_segmentation: Whether to use distance-based segmentation.
122        return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
123        scale: The scale factor to use for rescaling the input volume before prediction.
124        mask: An optional mask that is used to restrict the segmentation.
125
126    Returns:
127        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
128        and the predictions if return_predictions is True.
129    """
130    mitochondria = kwargs.pop("extra_segmentation", None)
131    if mitochondria is None:
132        # try extract from input volume
133        if input_volume.ndim == 4:
134            mitochondria = input_volume[1]
135            input_volume = input_volume[0]
136    if mitochondria is None:
137        raise ValueError("Mitochondria segmentation is required")
138    with_channels = kwargs.pop("with_channels", True)
139    channels_to_standardize = kwargs.pop("channels_to_standardize", [0])
140    if verbose:
141        print("Segmenting cristae in volume of shape", input_volume.shape)
142    # Create the scaler to handle prediction with a different scaling factor.
143    scaler = _Scaler(scale, verbose)
144    # rescale each channel
145    volume = scaler.scale_input(input_volume)
146    mito_seg = scaler.scale_input(mitochondria, is_segmentation=True)
147    input_volume = np.stack([volume, mito_seg], axis=0)
148
149    # target 10nm erosion for mitochondria
150    # voxel_size is the model's training voxel size, which is the space we erode in
151    erode_voxels = max(1, round(10.0 / voxel_size))
152
153    # Run prediction and segmentation.
154    if mask is not None:
155        mask = scaler.scale_input(mask, is_segmentation=True)
156    pred = get_prediction(
157        input_volume, model_path=model_path, model=model, mask=mask,
158        tiling=tiling, with_channels=with_channels, channels_to_standardize=channels_to_standardize, verbose=verbose
159    )
160    foreground, boundaries = pred[:2]
161    seg = _run_segmentation(foreground, verbose=verbose, min_size=min_size, mito_seg=mito_seg,
162                            erode_voxels=erode_voxels)
163    seg = scaler.rescale_output(seg, is_segmentation=True)
164
165    if return_predictions:
166        pred = scaler.rescale_output(pred, is_segmentation=False)
167        return seg, pred
168    return seg

Segment cristae in an input volume.

Arguments:
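  • input_volume: The input volume to segment. Either a 4D array stacking the raw 3D volume and the mitochondria segmentation along the first axis, or the raw 3D volume alone if the mitochondria segmentation is passed via the extra_segmentation keyword argument.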
  • voxel_size: The voxel size of the model's training data.
  • model_path: The path to the model checkpoint if model is not provided.
  • model: Pre-loaded model. Either model_path or model is required.
  • tiling: The tiling configuration for the prediction.
  • min_size: The minimum size of a connected cristae component to be kept.
  • verbose: Whether to print timing information.
  • distance_based_segmentation: Whether to use distance-based segmentation. Currently not used by segment_cristae.
  • return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
  • scale: The scale factor to use for rescaling the input volume before prediction.
  • mask: An optional mask that is used to restrict the segmentation.
Returns:

The segmentation mask as a numpy array, or a tuple containing the segmentation mask and the predictions if return_predictions is True.