synapse_net.inference.cristae
import time
from typing import Dict, List, Optional, Tuple, Union

import elf.parallel as parallel
from elf.wrapper.base import (
    SimpleTransformationWrapper,
    MultiTransformationWrapper,
)
from skimage.morphology import binary_erosion, ball
from skimage.measure import regionprops
import numpy as np
import torch

from synapse_net.inference.util import get_prediction, _Scaler


def _erode_instances(mito_data, erode_voxels, verbose):
    """Erode every instance separately and return the union as a boolean mask.

    Args:
        mito_data: Instance segmentation; 0 is background, every other label is one instance.
        erode_voxels: Radius of the ball footprint used for the erosion.
        verbose: Whether to print timing information.

    Returns:
        Boolean array with the shape of `mito_data` that is True inside the eroded instances.
    """
    if verbose:
        t_erode = time.time()
        print("Eroding mitochondria instances globally...")

    footprint = ball(erode_voxels)
    # Half-width of the footprint along each axis. skimage's binary_erosion treats voxels
    # outside the array as foreground, so eroding inside the *tight* bounding box would leave
    # every instance face that touches the box un-eroded. Padding the crop by this halo puts
    # real background voxels next to those faces.
    halo = [size // 2 for size in footprint.shape]
    shape = mito_data.shape

    # Allocate a boolean output mask.
    eroded_binary_mask = np.zeros(shape, dtype=bool)

    for prop in regionprops(mito_data):
        # Bounding box of this instance, padded by the halo and clipped to the volume.
        bb = tuple(
            slice(max(sl.start - h, 0), min(sl.stop + h, dim))
            for sl, h, dim in zip(prop.slice, halo, shape)
        )

        # Isolate this specific instance within its padded bounding box.
        instance_mask = mito_data[bb] == prop.label
        eroded_mask = binary_erosion(instance_mask, footprint=footprint)

        # Merge into the global mask (in-place through a view of eroded_binary_mask).
        eroded_binary_mask[bb] |= eroded_mask

    if verbose:
        print(f"Instance erosion completed in {time.time() - t_erode:.2f} s")

    return eroded_binary_mask


def _run_segmentation(
    foreground, verbose, min_size,
    # blocking shapes for parallel computation
    block_shape=(128, 256, 256),
    mito_seg=None,
    erode_voxels=3,
):
    """Compute a binary cristae segmentation from the foreground prediction.

    The foreground is restricted to the eroded mitochondria instances (if given),
    thresholded at 0.5, labeled into connected components and size filtered.

    Args:
        foreground: The foreground prediction.
        verbose: Whether to print timing information.
        min_size: Components smaller than this are removed; values <= 0 disable the filter.
        block_shape: Block shape for the parallel labeling and size filtering.
        mito_seg: Mitochondria instance segmentation used to restrict the foreground.
            If None, the foreground is not masked.
        erode_voxels: Erosion radius applied to each mitochondrion before masking.

    Returns:
        Integer array that is 1 for cristae and 0 for background.
    """
    if mito_seg is not None:
        mito_mask = _erode_instances(mito_seg, erode_voxels, verbose)

        # Mask the foreground lazily.
        # Even though the mito mask is in memory, foreground might not be.
        # MultiTransformationWrapper safely handles this mix.
        def mask_foreground(inputs):
            fg_block, mito_block = inputs
            return np.where(mito_block != 0, fg_block, 0)

        foreground = MultiTransformationWrapper(
            mask_foreground,
            foreground,
            mito_mask,
            apply_to_list=True,
        )

    # Apply the threshold lazily.
    def threshold_block(block):
        return block > 0.5

    binary_foreground = SimpleTransformationWrapper(
        foreground,
        transformation=threshold_block,
    )

    t0 = time.time()
    seg = parallel.label(binary_foreground, block_shape=block_shape, verbose=verbose)
    if verbose:
        print("Compute connected components in", time.time() - t0, "s")

    # Size filter (written back into seg).
    if min_size > 0:
        t0 = time.time()
        parallel.size_filter(seg, out=seg, min_size=min_size, block_shape=block_shape, verbose=verbose)
        if verbose:
            print("Size filter in", time.time() - t0, "s")

    # Collapse the component ids into a binary mask.
    return np.where(seg > 0, 1, 0)


def segment_cristae(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    min_size: int = 500,
    verbose: bool = True,
    distance_based_segmentation: bool = False,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
    **kwargs
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment cristae in an input volume.

    The network is run on two stacked channels (raw data and mitochondria segmentation).
    Its foreground output is restricted to the eroded mitochondria, thresholded, labeled
    into connected components, size filtered and returned as a binary mask.

    Args:
        input_volume: The input volume to segment. Either a 4D array that stacks the raw
            3D volume (index 0) and the 3D mitochondria segmentation (index 1) along the
            first axis, or only the raw 3D volume if the mitochondria segmentation is
            passed via the `extra_segmentation` keyword argument.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        min_size: The minimum size of a cristae component; smaller connected components
            are removed. Applied at the (possibly rescaled) prediction resolution.
        verbose: Whether to print timing information.
        distance_based_segmentation: Currently ignored by this function.
        return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        mask: An optional mask that is used to restrict the segmentation.
        kwargs: Optional keys `extra_segmentation` (the mitochondria segmentation),
            `with_channels` (default True) and `channels_to_standardize` (default [0]);
            the latter two are forwarded to the prediction. Other keys are ignored.

    Returns:
        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
        and the predictions if return_predictions is True.

    Raises:
        ValueError: If no mitochondria segmentation is available.
    """
    mitochondria = kwargs.pop("extra_segmentation", None)
    if mitochondria is None:
        # try extract from input volume
        if input_volume.ndim == 4:
            mitochondria = input_volume[1]
            input_volume = input_volume[0]
    if mitochondria is None:
        raise ValueError("Mitochondria segmentation is required")
    with_channels = kwargs.pop("with_channels", True)
    channels_to_standardize = kwargs.pop("channels_to_standardize", [0])
    if verbose:
        print("Segmenting cristae in volume of shape", input_volume.shape)
    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    # rescale each channel
    volume = scaler.scale_input(input_volume)
    mito_seg = scaler.scale_input(mitochondria, is_segmentation=True)
    input_volume = np.stack([volume, mito_seg], axis=0)

    # Run prediction and segmentation.
    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)
    pred = get_prediction(
        input_volume, model_path=model_path, model=model, mask=mask,
        tiling=tiling, with_channels=with_channels, channels_to_standardize=channels_to_standardize, verbose=verbose
    )
    # Only the first channel (foreground) is used; unpacking two channels keeps the
    # requirement that the prediction has at least two channels.
    foreground, _ = pred[:2]
    seg = _run_segmentation(foreground, verbose=verbose, min_size=min_size, mito_seg=mito_seg)
    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        pred = scaler.rescale_output(pred, is_segmentation=False)
        return seg, pred
    return seg
def
segment_cristae( input_volume: numpy.ndarray, model_path: Optional[str] = None, model: Optional[torch.nn.modules.module.Module] = None, tiling: Optional[Dict[str, Dict[str, int]]] = None, min_size: int = 500, verbose: bool = True, distance_based_segmentation: bool = False, return_predictions: bool = False, scale: Optional[List[float]] = None, mask: Optional[numpy.ndarray] = None, **kwargs) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]:
def segment_cristae(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    min_size: int = 500,
    verbose: bool = True,
    distance_based_segmentation: bool = False,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
    **kwargs
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment cristae in an input volume.

    The network is run on two stacked channels (raw data and mitochondria segmentation).
    Its foreground output is restricted to the mitochondria, thresholded, labeled into
    connected components, size filtered and returned as a binary mask.

    Args:
        input_volume: The input volume to segment. Either a 4D array that stacks the raw
            3D volume (index 0) and the 3D mitochondria segmentation (index 1) along the
            first axis, or only the raw 3D volume if the mitochondria segmentation is
            passed via the `extra_segmentation` keyword argument.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        min_size: The minimum size of a cristae component; smaller connected components
            are removed. Applied at the (possibly rescaled) prediction resolution.
        verbose: Whether to print timing information.
        distance_based_segmentation: Currently ignored by this function.
        return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        mask: An optional mask that is used to restrict the segmentation.
        kwargs: Optional keys `extra_segmentation` (the mitochondria segmentation),
            `with_channels` (default True) and `channels_to_standardize` (default [0]);
            the latter two are forwarded to the prediction. Other keys are ignored.

    Returns:
        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
        and the predictions if return_predictions is True.

    Raises:
        ValueError: If no mitochondria segmentation is available.
    """
    mitochondria = kwargs.pop("extra_segmentation", None)
    if mitochondria is None:
        # try extract from input volume
        if input_volume.ndim == 4:
            mitochondria = input_volume[1]
            input_volume = input_volume[0]
    if mitochondria is None:
        raise ValueError("Mitochondria segmentation is required")
    with_channels = kwargs.pop("with_channels", True)
    channels_to_standardize = kwargs.pop("channels_to_standardize", [0])
    if verbose:
        print("Segmenting cristae in volume of shape", input_volume.shape)
    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    # rescale each channel
    volume = scaler.scale_input(input_volume)
    mito_seg = scaler.scale_input(mitochondria, is_segmentation=True)
    input_volume = np.stack([volume, mito_seg], axis=0)

    # Run prediction and segmentation.
    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)
    pred = get_prediction(
        input_volume, model_path=model_path, model=model, mask=mask,
        tiling=tiling, with_channels=with_channels, channels_to_standardize=channels_to_standardize, verbose=verbose
    )
    # Only the first channel (foreground) is used; unpacking two channels keeps the
    # requirement that the prediction has at least two channels.
    foreground, _ = pred[:2]
    seg = _run_segmentation(foreground, verbose=verbose, min_size=min_size, mito_seg=mito_seg)
    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        pred = scaler.rescale_output(pred, is_segmentation=False)
        return seg, pred
    return seg
Segment cristae in an input volume.
Arguments:
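- input_volume: The input volume to segment. Either a 4D array that stacks the raw 3D volume (index 0) and the 3D mitochondria segmentation (index 1) along the first axis, or only the raw 3D volume if the mitochondria segmentation is passed via the `extra_segmentation` keyword argument.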
- model_path: The path to the model checkpoint if `model` is not provided.
- model: Pre-loaded model. Either `model_path` or `model` is required.
- tiling: The tiling configuration for the prediction.
- min_size: The minimum size of a cristae component; smaller connected components are removed.
- verbose: Whether to print timing information.
- distance_based_segmentation: Currently ignored by this function.
- return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
- scale: The scale factor to use for rescaling the input volume before prediction.
- mask: An optional mask that is used to restrict the segmentation.
Returns:
The segmentation mask as a numpy array, or a tuple containing the segmentation mask and the predictions if return_predictions is True.