synapse_net.inference.vesicles
import time
from typing import Dict, List, Optional, Tuple, Union

import elf.parallel as parallel
import numpy as np

import torch

from synapse_net.inference.util import apply_size_filter, get_prediction, _Scaler
from synapse_net.inference.postprocessing.vesicles import filter_border_objects


def _watershed_from_seeds(
    foreground: np.ndarray,
    boundaries: np.ndarray,
    seeds: np.ndarray,
    verbose: bool,
    min_size: int,
    block_shape: Tuple[int, ...],
    halo: Tuple[int, ...],
) -> np.ndarray:
    """Grow seeds into a segmentation via seeded watershed and apply the size filter.

    This is the stage shared by `distance_based_vesicle_segmentation` and
    `simple_vesicle_segmentation`; only the way the seeds are derived differs between them.

    Args:
        foreground: The foreground prediction.
        boundaries: The boundary prediction.
        seeds: The labeled seeds; 0 marks pixels that do not belong to any seed.
        verbose: Whether to print timing information.
        min_size: The minimal vesicle size, passed on to `apply_size_filter`.
        block_shape: Block shape for parallelizing the operations.
        halo: Halo for parallelizing the operations.

    Returns:
        The size-filtered segmentation.
    """
    # Compute distances from the seeds, which are used as heightmap for the watershed,
    # to assign all pixels to the nearest seed.
    t0 = time.time()
    dist = parallel.distance_transform(seeds == 0, halo=halo, verbose=verbose, block_shape=block_shape)
    if verbose:
        print("Compute distance transform in", time.time() - t0, "s")

    t0 = time.time()
    # The watershed is restricted to pixels where the summed predictions exceed 0.5.
    mask = (foreground + boundaries) > 0.5
    seg = np.zeros_like(seeds)
    seg = parallel.seeded_watershed(
        dist, seeds, block_shape=block_shape,
        out=seg, mask=mask, verbose=verbose, halo=halo,
    )
    if verbose:
        print("Compute watershed in", time.time() - t0, "s")

    return apply_size_filter(seg, min_size, verbose, block_shape)


def distance_based_vesicle_segmentation(
    foreground: np.ndarray,
    boundaries: np.ndarray,
    verbose: bool,
    min_size: int,
    boundary_threshold: float = 0.5,  # previous default value was 0.9
    distance_threshold: int = 8,
    block_shape: Tuple[int, ...] = (128, 256, 256),
    halo: Tuple[int, ...] = (48, 48, 48),
) -> np.ndarray:
    """Segment vesicles using a seeded watershed from connected components derived from
    the distance transform of the boundary predictions.

    This approach can prevent false merges that occur with `simple_vesicle_segmentation`.

    Args:
        foreground: The foreground prediction.
        boundaries: The boundary prediction.
        verbose: Whether to print timing information.
        min_size: The minimal vesicle size.
        boundary_threshold: The threshold for binarizing the boundary predictions for the distance computation.
        distance_threshold: The threshold for finding connected components in the boundary distances.
        block_shape: Block shape for parallelizing the operations.
            `segment_vesicles` passes a 2-tuple here for 2D inputs.
        halo: Halo for parallelizing the operations.
            `segment_vesicles` passes a 2-tuple here for 2D inputs.

    Returns:
        The vesicle segmentation.
    """
    # Compute the boundary distances and zero them wherever the foreground is below 0.5.
    t0 = time.time()
    bd_dist = parallel.distance_transform(
        boundaries < boundary_threshold, halo=halo, verbose=verbose, block_shape=block_shape
    )
    bd_dist[foreground < 0.5] = 0
    if verbose:
        print("Compute distance transform in", time.time() - t0, "s")

    # The seeds are the connected components of the thresholded boundary distances.
    t0 = time.time()
    seeds = parallel.label(bd_dist > distance_threshold, block_shape=block_shape, verbose=verbose)
    if verbose:
        print("Compute connected components in", time.time() - t0, "s")

    return _watershed_from_seeds(foreground, boundaries, seeds, verbose, min_size, block_shape, halo)


def simple_vesicle_segmentation(
    foreground: np.ndarray,
    boundaries: np.ndarray,
    verbose: bool,
    min_size: int,
    block_shape: Tuple[int, ...] = (128, 256, 256),
    halo: Tuple[int, ...] = (48, 48, 48),
) -> np.ndarray:
    """Segment vesicles by subtracting boundary from foreground prediction and
    applying connected components.

    Args:
        foreground: The foreground prediction.
        boundaries: The boundary prediction.
        verbose: Whether to print timing information.
        min_size: The minimal vesicle size.
        block_shape: Block shape for parallelizing the operations.
            `segment_vesicles` passes a 2-tuple here for 2D inputs.
        halo: Halo for parallelizing the operations.
            `segment_vesicles` passes a 2-tuple here for 2D inputs.

    Returns:
        The vesicle segmentation.
    """
    # The seeds are the connected components of (foreground - boundaries) > 0.5.
    t0 = time.time()
    seeds = parallel.label((foreground - boundaries) > 0.5, block_shape=block_shape, verbose=verbose)
    if verbose:
        print("Compute connected components in", time.time() - t0, "s")

    return _watershed_from_seeds(foreground, boundaries, seeds, verbose, min_size, block_shape, halo)


def segment_vesicles(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    min_size: int = 500,
    verbose: bool = True,
    distance_based_segmentation: bool = True,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    exclude_boundary: bool = False,
    mask: Optional[np.ndarray] = None,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment vesicles in an input volume or image.

    Args:
        input_volume: The input volume to segment.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        min_size: The minimum size of a vesicle to be considered.
        verbose: Whether to print timing information.
        distance_based_segmentation: Whether to use distance-based segmentation.
        return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        exclude_boundary: Whether to exclude vesicles that touch the upper / lower border in z.
        mask: An optional mask that is used to restrict the segmentation.

    Returns:
        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
        and the predictions if return_predictions is True.
    """
    if verbose:
        print("Segmenting vesicles in volume of shape", input_volume.shape)
    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    # Rescale the mask if it was given and run prediction.
    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)
    pred = get_prediction(input_volume, tiling=tiling, model_path=model_path, model=model, verbose=verbose, mask=mask)
    # The first two prediction channels are used as foreground and boundaries.
    foreground, boundaries = pred[:2]

    # Deal with 2D segmentation case: override the 3D default block shape and halo.
    kwargs = {}
    if len(input_volume.shape) == 2:
        kwargs["block_shape"] = (256, 256)
        kwargs["halo"] = (48, 48)

    if distance_based_segmentation:
        seg = distance_based_vesicle_segmentation(
            foreground, boundaries, verbose=verbose, min_size=min_size, **kwargs
        )
    else:
        seg = simple_vesicle_segmentation(
            foreground, boundaries, verbose=verbose, min_size=min_size, **kwargs
        )

    if exclude_boundary:
        seg = filter_border_objects(seg)
    # Undo the input scaling so that the outputs match the original input.
    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        pred = scaler.rescale_output(pred, is_segmentation=False)
        return seg, pred
    return seg
def
distance_based_vesicle_segmentation( foreground: numpy.ndarray, boundaries: numpy.ndarray, verbose: bool, min_size: int, boundary_threshold: float = 0.5, distance_threshold: int = 8, block_shape: Tuple[int, int, int] = (128, 256, 256), halo: Tuple[int, int, int] = (48, 48, 48)) -> numpy.ndarray:
def distance_based_vesicle_segmentation(
    foreground: np.ndarray,
    boundaries: np.ndarray,
    verbose: bool,
    min_size: int,
    boundary_threshold: float = 0.5,  # previous default value was 0.9
    distance_threshold: int = 8,
    block_shape: Tuple[int, int, int] = (128, 256, 256),
    halo: Tuple[int, int, int] = (48, 48, 48),
) -> np.ndarray:
    """Segment vesicles with a seeded watershed whose seeds come from boundary distances.

    The seeds are the connected components of the thresholded distance transform of the
    boundary prediction. This approach can prevent false merges that occur with
    `simple_vesicle_segmentation`.

    Args:
        foreground: The foreground prediction.
        boundaries: The boundary prediction.
        verbose: Whether to print timing information.
        min_size: The minimal vesicle size.
        boundary_threshold: The threshold for binarizing the boundary predictions for the distance computation.
        distance_threshold: The threshold for finding connected components in the boundary distances.
        block_shape: Block shape for parallelizing the operations.
        halo: Halo for parallelizing the operations.

    Returns:
        The vesicle segmentation.
    """

    def _log_elapsed(message: str, started: float) -> None:
        # Print the time spent in a stage, only if timing output was requested.
        if verbose:
            print(message, time.time() - started, "s")

    # Stage 1: boundary distances, zeroed wherever the foreground is below 0.5.
    started = time.time()
    non_boundary = boundaries < boundary_threshold
    boundary_distances = parallel.distance_transform(
        non_boundary, block_shape=block_shape, halo=halo, verbose=verbose
    )
    boundary_distances[foreground < 0.5] = 0
    _log_elapsed("Compute distance transform in", started)

    # Stage 2: the seeds are the connected components of the thresholded distances.
    started = time.time()
    seeds = parallel.label(
        boundary_distances > distance_threshold, verbose=verbose, block_shape=block_shape
    )
    _log_elapsed("Compute connected components in", started)

    # Stage 3: distances from the seeds serve as heightmap for the watershed,
    # to assign all pixels to the nearest seed.
    started = time.time()
    seed_distances = parallel.distance_transform(
        seeds == 0, block_shape=block_shape, halo=halo, verbose=verbose
    )
    _log_elapsed("Compute distance transform in", started)

    # Stage 4: watershed from the seeds, restricted to the summed predictions above 0.5.
    started = time.time()
    watershed_mask = (foreground + boundaries) > 0.5
    segmentation = parallel.seeded_watershed(
        seed_distances,
        seeds,
        out=np.zeros_like(seeds),
        mask=watershed_mask,
        block_shape=block_shape,
        halo=halo,
        verbose=verbose,
    )
    _log_elapsed("Compute watershed in", started)

    return apply_size_filter(segmentation, min_size, verbose, block_shape)
Segment vesicles using a seeded watershed from connected components derived from the distance transform of the boundary predictions.
This approach can prevent false merges that occur with `simple_vesicle_segmentation`.
Arguments:
- foreground: The foreground prediction.
- boundaries: The boundary prediction.
- verbose: Whether to print timing information.
- min_size: The minimal vesicle size.
- boundary_threshold: The threshold for binarizing the boundary predictions for the distance computation.
- distance_threshold: The threshold for finding connected components in the boundary distances.
- block_shape: Block shape for parallelizing the operations.
- halo: Halo for parallelizing the operations.
Returns:
The vesicle segmentation.
def
simple_vesicle_segmentation( foreground: numpy.ndarray, boundaries: numpy.ndarray, verbose: bool, min_size: int, block_shape: Tuple[int, int, int] = (128, 256, 256), halo: Tuple[int, int, int] = (48, 48, 48)) -> numpy.ndarray:
def simple_vesicle_segmentation(
    foreground: np.ndarray,
    boundaries: np.ndarray,
    verbose: bool,
    min_size: int,
    block_shape: Tuple[int, int, int] = (128, 256, 256),
    halo: Tuple[int, int, int] = (48, 48, 48),
) -> np.ndarray:
    """Segment vesicles from the connected components of foreground minus boundaries.

    The components of `(foreground - boundaries) > 0.5` are used as seeds for a seeded
    watershed, and the result is passed through the size filter.

    Args:
        foreground: The foreground prediction.
        boundaries: The boundary prediction.
        verbose: Whether to print timing information.
        min_size: The minimal vesicle size.
        block_shape: Block shape for parallelizing the operations.
        halo: Halo for parallelizing the operations.

    Returns:
        The vesicle segmentation.
    """
    # Seeds: connected components of the boundary-subtracted foreground.
    tic = time.time()
    seed_mask = (foreground - boundaries) > 0.5
    seeds = parallel.label(seed_mask, verbose=verbose, block_shape=block_shape)
    if verbose:
        print("Compute connected components in", time.time() - tic, "s")

    # Heightmap: distance of every pixel to the seeds.
    tic = time.time()
    background = seeds == 0
    heightmap = parallel.distance_transform(
        background, block_shape=block_shape, halo=halo, verbose=verbose
    )
    if verbose:
        print("Compute distance transform in", time.time() - tic, "s")

    # Watershed from the seeds, restricted to the summed predictions above 0.5.
    tic = time.time()
    valid = (foreground + boundaries) > 0.5
    labels = parallel.seeded_watershed(
        heightmap,
        seeds,
        out=np.zeros_like(seeds),
        mask=valid,
        block_shape=block_shape,
        halo=halo,
        verbose=verbose,
    )
    if verbose:
        print("Compute watershed in", time.time() - tic, "s")

    return apply_size_filter(labels, min_size, verbose, block_shape)
Segment vesicles by subtracting boundary from foreground prediction and applying connected components.
Arguments:
- foreground: The foreground prediction.
- boundaries: The boundary prediction.
- verbose: Whether to print timing information.
- min_size: The minimal vesicle size.
- block_shape: Block shape for parallelizing the operations.
- halo: Halo for parallelizing the operations.
Returns:
The vesicle segmentation.
def
segment_vesicles( input_volume: numpy.ndarray, model_path: Optional[str] = None, model: Optional[torch.nn.modules.module.Module] = None, tiling: Optional[Dict[str, Dict[str, int]]] = None, min_size: int = 500, verbose: bool = True, distance_based_segmentation: bool = True, return_predictions: bool = False, scale: Optional[List[float]] = None, exclude_boundary: bool = False, mask: Optional[numpy.ndarray] = None) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]:
def segment_vesicles(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    min_size: int = 500,
    verbose: bool = True,
    distance_based_segmentation: bool = True,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    exclude_boundary: bool = False,
    mask: Optional[np.ndarray] = None,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Run network prediction and post-processing to segment vesicles in a volume or image.

    Args:
        input_volume: The input volume to segment.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        min_size: The minimum size of a vesicle to be considered.
        verbose: Whether to print timing information.
        distance_based_segmentation: Whether to use distance-based segmentation.
        return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        exclude_boundary: Whether to exclude vesicles that touch the upper / lower border in z.
        mask: An optional mask that is used to restrict the segmentation.

    Returns:
        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
        and the predictions if return_predictions is True.
    """
    if verbose:
        print("Segmenting vesicles in volume of shape", input_volume.shape)

    # The scaler handles prediction at a different scaling factor; the optional mask
    # is rescaled with it so that it matches the rescaled input.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)
    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)

    pred = get_prediction(
        input_volume,
        tiling=tiling,
        model_path=model_path,
        model=model,
        verbose=verbose,
        mask=mask,
    )
    # The first two prediction channels are used as foreground and boundaries.
    foreground, boundaries = pred[:2]

    # 2D inputs need a 2D block shape and halo; otherwise the callee defaults apply.
    is_2d = len(input_volume.shape) == 2
    blocking = {"block_shape": (256, 256), "halo": (48, 48)} if is_2d else {}

    segment = (
        distance_based_vesicle_segmentation
        if distance_based_segmentation
        else simple_vesicle_segmentation
    )
    seg = segment(foreground, boundaries, verbose=verbose, min_size=min_size, **blocking)

    if exclude_boundary:
        seg = filter_border_objects(seg)
    seg = scaler.rescale_output(seg, is_segmentation=True)

    if not return_predictions:
        return seg
    pred = scaler.rescale_output(pred, is_segmentation=False)
    return seg, pred
Segment vesicles in an input volume or image.
Arguments:
- input_volume: The input volume to segment.
- model_path: The path to the model checkpoint if `model` is not provided.
- model: Pre-loaded model. Either `model_path` or `model` is required.
- tiling: The tiling configuration for the prediction.
- min_size: The minimum size of a vesicle to be considered.
- verbose: Whether to print timing information.
- distance_based_segmentation: Whether to use distance-based segmentation.
- return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
- scale: The scale factor to use for rescaling the input volume before prediction.
- exclude_boundary: Whether to exclude vesicles that touch the upper / lower border in z.
- mask: An optional mask that is used to restrict the segmentation.
Returns:
The segmentation mask as a numpy array, or a tuple containing the segmentation mask and the predictions if return_predictions is True.