synapse_net.inference.cristae
import time
from typing import Dict, List, Optional, Tuple, Union

import elf.parallel as parallel
from elf.wrapper.base import (
    SimpleTransformationWrapper,
    MultiTransformationWrapper,
)
from skimage.morphology import binary_erosion, ball
from skimage.measure import regionprops
import numpy as np
import torch

from synapse_net.inference.util import get_prediction, _Scaler


def _erode_instances(mito_data, erode_voxels, verbose):
    """Erode every instance of a label image separately and merge the results.

    Each instance is compared only against its own label, so voxels bordering
    the background or a neighbouring instance are both eroded. The result is a
    single boolean mask, which is cheaper to hold than an integer label volume.

    Args:
        mito_data: The instance segmentation (integer labels, 0 = background).
            Assumed to be 3D, since the footprint is a `ball`; TODO confirm for 2D use.
        erode_voxels: The radius (in voxels) of the spherical erosion footprint.
        verbose: Whether to print timing information.

    Returns:
        A boolean array with the shape of `mito_data` that is True wherever an
        instance remains after erosion.
    """
    if verbose:
        t_erode = time.time()
        print(f"Eroding mitochondria instances globally by {erode_voxels} voxels ...")

    footprint = ball(erode_voxels)
    shape = mito_data.shape

    # Allocate a boolean array
    eroded_binary_mask = np.zeros(shape, dtype=bool)

    for prop in regionprops(mito_data):
        # Grow the tight bounding box by the erosion radius (clipped to the volume).
        # skimage's binary_erosion treats everything outside the array it receives
        # as foreground, so on a tight bounding box the voxels lying on its faces
        # would never be eroded from that side. The margin exposes the real
        # background / neighbouring labels to the footprint.
        sl = tuple(
            slice(max(s.start - erode_voxels, 0), min(s.stop + erode_voxels, dim))
            for s, dim in zip(prop.slice, shape)
        )

        # Isolate this specific instance within its (padded) bounding box
        instance_mask = mito_data[sl] == prop.label
        eroded_mask = binary_erosion(instance_mask, footprint=footprint)

        # Basic slicing yields a view, so the in-place OR writes into the full mask.
        # Padded boxes of neighbouring instances may overlap; OR keeps their results.
        eroded_binary_mask[sl] |= eroded_mask

    if verbose:
        print(f"Instance erosion completed in {time.time() - t_erode:.2f} s")

    return eroded_binary_mask


def _run_segmentation(
    foreground, verbose, min_size,
    # blocking shapes for parallel computation
    block_shape=(128, 256, 256),
    mito_seg=None,
    erode_voxels=3,
):
    """Turn a foreground prediction into a binary cristae segmentation.

    Args:
        foreground: The foreground prediction (array-like, possibly lazy).
        verbose: Whether to print timing information.
        min_size: Connected components smaller than this are removed (0 disables).
        block_shape: The block shape for the parallel labeling / size filter.
        mito_seg: Optional mitochondria instance segmentation. If given, it is
            eroded by `erode_voxels` and the foreground outside it is zeroed.
        erode_voxels: The erosion radius (in voxels) applied to `mito_seg`.

    Returns:
        An integer array with 1 for cristae and 0 elsewhere.
    """
    if mito_seg is not None:
        mito_mask = _erode_instances(mito_seg, erode_voxels, verbose)

        # Mask the foreground lazily.
        # Even though the mito mask is in memory, foreground might not be;
        # MultiTransformationWrapper handles this mix block by block.
        def mask_foreground(inputs):
            fg_block, mito_block = inputs
            return np.where(mito_block != 0, fg_block, 0)

        foreground = MultiTransformationWrapper(
            mask_foreground,
            foreground,
            mito_mask,
            apply_to_list=True,
        )

    # Apply the threshold lazily
    def threshold_block(block):
        return block > 0.5

    binary_foreground = SimpleTransformationWrapper(
        foreground,
        transformation=threshold_block,
    )

    t0 = time.time()
    seg = parallel.label(binary_foreground, block_shape=block_shape, verbose=verbose)
    if verbose:
        print("Compute connected components in", time.time() - t0, "s")

    # Size filter
    if min_size > 0:
        t0 = time.time()
        parallel.size_filter(seg, out=seg, min_size=min_size, block_shape=block_shape, verbose=verbose)
        if verbose:
            print("Size filter in", time.time() - t0, "s")

    # Collapse the instance labels into a binary semantic segmentation.
    seg = np.where(seg > 0, 1, 0)

    return seg


def segment_cristae(
    input_volume: np.ndarray,
    voxel_size: float,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    min_size: int = 2000,
    verbose: bool = True,
    distance_based_segmentation: bool = False,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
    **kwargs
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment cristae in an input volume.

    The mitochondria segmentation is eroded by roughly 10 nm (computed from
    `voxel_size`), and the foreground prediction outside the eroded mitochondria
    is discarded before thresholding, labeling and size filtering.

    Args:
        input_volume: The input volume to segment. Either a 4D array holding the raw
            data and the mitochondria segmentation stacked along the first axis, or
            the 3D raw data when the mitochondria are passed via `extra_segmentation`.
        voxel_size: The voxel size of the model's training data, presumably in nm
            (it sets the ~10 nm mitochondria erosion). Must be positive.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        min_size: The minimum size of a crista to be kept; smaller segments are removed.
        verbose: Whether to print timing information.
        distance_based_segmentation: Whether to use distance-based segmentation.
            Currently not used by this function.
        return_predictions: Whether to return the predictions alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        mask: An optional mask that is used to restrict the segmentation. It is
            rescaled like the input and passed on to the prediction.
        **kwargs: Optional extra arguments:
            extra_segmentation: The 3D mitochondria segmentation.
            with_channels: Passed to `get_prediction` (default: True).
            channels_to_standardize: Passed to `get_prediction` (default: [0]).

    Returns:
        The binary segmentation (1 = cristae, 0 = background), rescaled back by the
        scaler, or a tuple of the segmentation and the full rescaled prediction if
        `return_predictions` is True.

    Raises:
        ValueError: If `voxel_size` is not positive or no mitochondria
            segmentation is available.
    """
    # Fail early with a clear message instead of a ZeroDivisionError below.
    if voxel_size <= 0:
        raise ValueError(f"voxel_size must be positive, got {voxel_size}")

    mitochondria = kwargs.pop("extra_segmentation", None)
    if mitochondria is None:
        # try extract from input volume
        if input_volume.ndim == 4:
            mitochondria = input_volume[1]
            input_volume = input_volume[0]
    if mitochondria is None:
        raise ValueError("Mitochondria segmentation is required")
    with_channels = kwargs.pop("with_channels", True)
    channels_to_standardize = kwargs.pop("channels_to_standardize", [0])
    if verbose:
        print("Segmenting cristae in volume of shape", input_volume.shape)
    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    # rescale each channel
    volume = scaler.scale_input(input_volume)
    mito_seg = scaler.scale_input(mitochondria, is_segmentation=True)
    input_volume = np.stack([volume, mito_seg], axis=0)

    # target 10nm erosion for mitochondria
    # voxel_size is the model's training voxel size, which is the space we erode in
    erode_voxels = max(1, round(10.0 / voxel_size))

    # Run prediction and segmentation.
    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)
    pred = get_prediction(
        input_volume, model_path=model_path, model=model, mask=mask,
        tiling=tiling, with_channels=with_channels, channels_to_standardize=channels_to_standardize, verbose=verbose
    )
    # Only the first (foreground) channel is used for the segmentation.
    foreground = pred[0]
    seg = _run_segmentation(foreground, verbose=verbose, min_size=min_size, mito_seg=mito_seg,
                            erode_voxels=erode_voxels)
    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        pred = scaler.rescale_output(pred, is_segmentation=False)
        return seg, pred
    return seg
def
segment_cristae( input_volume: numpy.ndarray, voxel_size: float, model_path: Optional[str] = None, model: Optional[torch.nn.modules.module.Module] = None, tiling: Optional[Dict[str, Dict[str, int]]] = None, min_size: int = 2000, verbose: bool = True, distance_based_segmentation: bool = False, return_predictions: bool = False, scale: Optional[List[float]] = None, mask: Optional[numpy.ndarray] = None, **kwargs) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]:
def segment_cristae(
    input_volume: np.ndarray,
    voxel_size: float,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    min_size: int = 2000,
    verbose: bool = True,
    distance_based_segmentation: bool = False,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
    **kwargs
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment cristae in an input volume.

    The mitochondria segmentation is eroded by roughly 10 nm (computed from
    `voxel_size`), and the foreground prediction outside the eroded mitochondria
    is discarded before thresholding, labeling and size filtering.

    Args:
        input_volume: The input volume to segment. Either a 4D array holding the raw
            data and the mitochondria segmentation stacked along the first axis, or
            the 3D raw data when the mitochondria are passed via `extra_segmentation`.
        voxel_size: The voxel size of the model's training data, presumably in nm
            (it sets the ~10 nm mitochondria erosion). Must be positive.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        min_size: The minimum size of a crista to be kept; smaller segments are removed.
        verbose: Whether to print timing information.
        distance_based_segmentation: Whether to use distance-based segmentation.
            Currently not used by this function.
        return_predictions: Whether to return the predictions alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        mask: An optional mask that is used to restrict the segmentation. It is
            rescaled like the input and passed on to the prediction.
        **kwargs: Optional extra arguments:
            extra_segmentation: The 3D mitochondria segmentation.
            with_channels: Passed to `get_prediction` (default: True).
            channels_to_standardize: Passed to `get_prediction` (default: [0]).

    Returns:
        The binary segmentation (1 = cristae, 0 = background), rescaled back by the
        scaler, or a tuple of the segmentation and the full rescaled prediction if
        `return_predictions` is True.

    Raises:
        ValueError: If `voxel_size` is not positive or no mitochondria
            segmentation is available.
    """
    # Fail early with a clear message instead of a ZeroDivisionError below.
    if voxel_size <= 0:
        raise ValueError(f"voxel_size must be positive, got {voxel_size}")

    mitochondria = kwargs.pop("extra_segmentation", None)
    if mitochondria is None:
        # try extract from input volume
        if input_volume.ndim == 4:
            mitochondria = input_volume[1]
            input_volume = input_volume[0]
    if mitochondria is None:
        raise ValueError("Mitochondria segmentation is required")
    with_channels = kwargs.pop("with_channels", True)
    channels_to_standardize = kwargs.pop("channels_to_standardize", [0])
    if verbose:
        print("Segmenting cristae in volume of shape", input_volume.shape)
    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    # rescale each channel
    volume = scaler.scale_input(input_volume)
    mito_seg = scaler.scale_input(mitochondria, is_segmentation=True)
    input_volume = np.stack([volume, mito_seg], axis=0)

    # target 10nm erosion for mitochondria
    # voxel_size is the model's training voxel size, which is the space we erode in
    erode_voxels = max(1, round(10.0 / voxel_size))

    # Run prediction and segmentation.
    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)
    pred = get_prediction(
        input_volume, model_path=model_path, model=model, mask=mask,
        tiling=tiling, with_channels=with_channels, channels_to_standardize=channels_to_standardize, verbose=verbose
    )
    # Only the first (foreground) channel is used for the segmentation.
    foreground = pred[0]
    seg = _run_segmentation(foreground, verbose=verbose, min_size=min_size, mito_seg=mito_seg,
                            erode_voxels=erode_voxels)
    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        pred = scaler.rescale_output(pred, is_segmentation=False)
        return seg, pred
    return seg
Segment cristae in an input volume.
Arguments:
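- input_volume: The input volume to segment. Expects two 3D volumes, the raw data and the mitochondria segmentation, stacked along the first axis; alternatively, the 3D raw data with the mitochondria passed via `extra_segmentation`.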
- voxel_size: The voxel size of the model's training data.
- model_path: The path to the model checkpoint if `model` is not provided.
- model: Pre-loaded model. Either `model_path` or `model` is required.
- tiling: The tiling configuration for the prediction.
- min_size: The minimum size of a crista to be kept; smaller segments are removed.
- verbose: Whether to print timing information.
- distance_based_segmentation: Whether to use distance-based segmentation (currently not used).
- return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
- scale: The scale factor to use for rescaling the input volume before prediction.
- mask: An optional mask that is used to restrict the segmentation.
Returns:
The segmentation mask as a numpy array, or a tuple containing the segmentation mask and the predictions if return_predictions is True.