synapse_net.inference.vesicles
import time
import warnings
from typing import Dict, List, Optional, Tuple, Union

import elf.parallel as parallel
import numpy as np

import torch

from synapse_net.inference.util import apply_size_filter, get_prediction, _Scaler
from synapse_net.inference.postprocessing.vesicles import filter_border_objects, filter_border_vesicles
from skimage.segmentation import relabel_sequential


def distance_based_vesicle_segmentation(
    foreground: np.ndarray,
    boundaries: np.ndarray,
    verbose: bool,
    min_size: int,
    boundary_threshold: float = 0.5,  # previous default value was 0.9
    distance_threshold: int = 8,
    block_shape: Tuple[int, int, int] = (128, 256, 256),
    halo: Tuple[int, int, int] = (48, 48, 48),
) -> np.ndarray:
    """Segment vesicles using a seeded watershed from connected components derived from
    distance transform of the boundary predictions.

    This approach can prevent false merges that occur with the `simple_vesicle_segmentation`.

    Args:
        foreground: The foreground prediction.
        boundaries: The boundary prediction.
        verbose: Whether to print timing information.
        min_size: The minimal vesicle size.
        boundary_threshold: The threshold for binarizing the boundary predictions for the distance computation.
        distance_threshold: The threshold for finding connected components in the boundary distances.
        block_shape: Block shape for parallelizing the operations.
        halo: Halo for parallelizing the operations.

    Returns:
        The vesicle segmentation.
    """
    # Compute the boundary distances.
    t0 = time.time()
    bd_dist = parallel.distance_transform(
        boundaries < boundary_threshold, halo=halo, verbose=verbose, block_shape=block_shape
    )
    bd_dist[foreground < 0.5] = 0
    if verbose:
        print("Compute distance transform in", time.time() - t0, "s")

    # Get the segmentation via seeded watershed of components in the boundary distances.
    t0 = time.time()
    seeds = parallel.label(bd_dist > distance_threshold, block_shape=block_shape, verbose=verbose)
    if verbose:
        print("Compute connected components in", time.time() - t0, "s")

    # Compute distances from the seeds, which are used as heightmap for the watershed,
    # to assign all pixels to the nearest seed.
    t0 = time.time()
    dist = parallel.distance_transform(seeds == 0, halo=halo, verbose=verbose, block_shape=block_shape)
    if verbose:
        print("Compute distance transform in", time.time() - t0, "s")

    t0 = time.time()
    mask = (foreground + boundaries) > 0.5
    seg = np.zeros_like(seeds)
    seg = parallel.seeded_watershed(
        dist, seeds, block_shape=block_shape,
        out=seg, mask=mask, verbose=verbose, halo=halo,
    )
    if verbose:
        print("Compute watershed in", time.time() - t0, "s")

    seg = apply_size_filter(seg, min_size, verbose, block_shape)
    return seg


def simple_vesicle_segmentation(
    foreground: np.ndarray,
    boundaries: np.ndarray,
    verbose: bool,
    min_size: int,
    block_shape: Tuple[int, int, int] = (128, 256, 256),
    halo: Tuple[int, int, int] = (48, 48, 48),
) -> np.ndarray:
    """Segment vesicles by subtracting boundary from foreground prediction and
    applying connected components.

    Args:
        foreground: The foreground prediction.
        boundaries: The boundary prediction.
        verbose: Whether to print timing information.
        min_size: The minimal vesicle size.
        block_shape: Block shape for parallelizing the operations.
        halo: Halo for parallelizing the operations.

    Returns:
        The vesicle segmentation.
    """

    t0 = time.time()
    seeds = parallel.label((foreground - boundaries) > 0.5, block_shape=block_shape, verbose=verbose)
    if verbose:
        print("Compute connected components in", time.time() - t0, "s")

    t0 = time.time()
    dist = parallel.distance_transform(seeds == 0, halo=halo, verbose=verbose, block_shape=block_shape)
    if verbose:
        print("Compute distance transform in", time.time() - t0, "s")

    t0 = time.time()
    mask = (foreground + boundaries) > 0.5
    seg = np.zeros_like(seeds)
    seg = parallel.seeded_watershed(
        dist, seeds, block_shape=block_shape,
        out=seg, mask=mask, verbose=verbose, halo=halo,
    )
    if verbose:
        print("Compute watershed in", time.time() - t0, "s")

    seg = apply_size_filter(seg, min_size, verbose, block_shape)
    return seg


def segment_vesicles(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    min_size: int = 500,
    verbose: bool = True,
    distance_based_segmentation: bool = True,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    exclude_boundary: bool = False,
    exclude_boundary_vesicles: bool = False,
    mask: Optional[np.ndarray] = None,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment vesicles in an input volume or image.

    Args:
        input_volume: The input volume to segment.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        min_size: The minimum size of a vesicle to be considered.
        verbose: Whether to print timing information.
        distance_based_segmentation: Whether to use distance-based segmentation.
        return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        exclude_boundary: Whether to exclude vesicles that touch the upper / lower border in z.
        exclude_boundary_vesicles: Whether to exclude vesicles on the boundary that have less than the full diameter
            inside of the volume. This is an alternative to post-processing with `exclude_boundary` that filters
            out fewer vesicles at the boundary and is better suited for volumes with small context in z.
            If `exclude_boundary` is also set to True, then this option will have no effect
            and a warning is emitted.
        mask: An optional mask that is used to restrict the segmentation.

    Returns:
        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
        and the predictions if return_predictions is True.
    """
    if verbose:
        print("Segmenting vesicles in volume of shape", input_volume.shape)
    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    # Rescale the mask if it was given and run prediction.
    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)
    pred = get_prediction(input_volume, tiling=tiling, model_path=model_path, model=model, verbose=verbose, mask=mask)
    # The first two prediction channels are used as foreground and boundary predictions.
    foreground, boundaries = pred[:2]

    # The default block shape / halo of the segmentation functions are 3D; use 2D values for images.
    kwargs = {}
    if len(input_volume.shape) == 2:
        kwargs["block_shape"] = (256, 256)
        kwargs["halo"] = (48, 48)

    if distance_based_segmentation:
        seg = distance_based_vesicle_segmentation(
            foreground, boundaries, verbose=verbose, min_size=min_size, **kwargs
        )
    else:
        seg = simple_vesicle_segmentation(
            foreground, boundaries, verbose=verbose, min_size=min_size, **kwargs
        )

    if exclude_boundary:
        if exclude_boundary_vesicles:
            # Trailing space is required: the two literals are concatenated into one message.
            # stacklevel=2 attributes the warning to the caller that passed both options.
            warnings.warn(
                "You have set both 'exclude_boundary' and 'exclude_boundary_vesicles' to True. "
                "The 'exclude_boundary_vesicles' option will have no effect.",
                stacklevel=2,
            )
        seg = filter_border_objects(seg)

    elif exclude_boundary_vesicles:
        # Filter the vesicles that are at the z-border with less than their full diameter;
        # filter_border_vesicles returns the IDs that are kept.
        seg_ids = filter_border_vesicles(seg)

        # Remove everything not in seg_ids and relabel the remaining IDs consecutively.
        seg[~np.isin(seg, seg_ids)] = 0
        seg = relabel_sequential(seg)[0]

    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        pred = scaler.rescale_output(pred, is_segmentation=False)
        return seg, pred
    return seg
def
distance_based_vesicle_segmentation( foreground: numpy.ndarray, boundaries: numpy.ndarray, verbose: bool, min_size: int, boundary_threshold: float = 0.5, distance_threshold: int = 8, block_shape: Tuple[int, int, int] = (128, 256, 256), halo: Tuple[int, int, int] = (48, 48, 48)) -> numpy.ndarray:
def distance_based_vesicle_segmentation(
    foreground: np.ndarray,
    boundaries: np.ndarray,
    verbose: bool,
    min_size: int,
    boundary_threshold: float = 0.5,  # previous default value was 0.9
    distance_threshold: int = 8,
    block_shape: Tuple[int, int, int] = (128, 256, 256),
    halo: Tuple[int, int, int] = (48, 48, 48),
) -> np.ndarray:
    """Segment vesicles with a seeded watershed whose seeds come from boundary distances.

    Seeds are the connected components of foreground pixels that lie further than
    `distance_threshold` from the binarized boundary prediction. All pixels inside the
    foreground + boundary mask are then assigned to their nearest seed. Compared to
    `simple_vesicle_segmentation`, this can avoid false merges of touching vesicles.

    Args:
        foreground: The foreground prediction.
        boundaries: The boundary prediction.
        verbose: Whether to print timing information.
        min_size: The minimal vesicle size; smaller objects are removed.
        boundary_threshold: Threshold for binarizing the boundary prediction before the distance transform.
        distance_threshold: Minimal boundary distance for a pixel to become part of a seed.
        block_shape: Block shape for parallelizing the operations.
        halo: Halo for parallelizing the operations.

    Returns:
        The vesicle segmentation.
    """
    def _report(step, start):
        # Print the elapsed time of a processing step when verbose output is requested.
        if verbose:
            print(step, time.time() - start, "s")

    # Distance of each pixel to the nearest boundary, zeroed outside of the foreground.
    start = time.time()
    boundary_distances = parallel.distance_transform(
        boundaries < boundary_threshold, halo=halo, verbose=verbose, block_shape=block_shape
    )
    boundary_distances[foreground < 0.5] = 0
    _report("Compute distance transform in", start)

    # Seeds: connected components of the pixels that are far enough away from any boundary.
    start = time.time()
    seed_labels = parallel.label(
        boundary_distances > distance_threshold, block_shape=block_shape, verbose=verbose
    )
    _report("Compute connected components in", start)

    # The distance to the closest seed serves as the watershed heightmap,
    # so that every pixel ends up in its nearest seed.
    start = time.time()
    heightmap = parallel.distance_transform(
        seed_labels == 0, halo=halo, verbose=verbose, block_shape=block_shape
    )
    _report("Compute distance transform in", start)

    start = time.time()
    watershed_mask = (foreground + boundaries) > 0.5
    output = np.zeros_like(seed_labels)
    segmentation = parallel.seeded_watershed(
        heightmap,
        seed_labels,
        block_shape=block_shape,
        out=output,
        mask=watershed_mask,
        verbose=verbose,
        halo=halo,
    )
    _report("Compute watershed in", start)

    return apply_size_filter(segmentation, min_size, verbose, block_shape)
Segment vesicles using a seeded watershed from connected components derived from distance transform of the boundary predictions.
This approach can prevent false merges that occur with `simple_vesicle_segmentation`.
Arguments:
- foreground: The foreground prediction.
- boundaries: The boundary prediction.
- verbose: Whether to print timing information.
- min_size: The minimal vesicle size.
- boundary_threshold: The threshold for binarizing the boundary predictions for the distance computation.
- distance_threshold: The threshold for finding connected components in the boundary distances.
- block_shape: Block shape for parallelizing the operations.
- halo: Halo for parallelizing the operations.
Returns:
The vesicle segmentation.
def
simple_vesicle_segmentation( foreground: numpy.ndarray, boundaries: numpy.ndarray, verbose: bool, min_size: int, block_shape: Tuple[int, int, int] = (128, 256, 256), halo: Tuple[int, int, int] = (48, 48, 48)) -> numpy.ndarray:
def simple_vesicle_segmentation(
    foreground: np.ndarray,
    boundaries: np.ndarray,
    verbose: bool,
    min_size: int,
    block_shape: Tuple[int, int, int] = (128, 256, 256),
    halo: Tuple[int, int, int] = (48, 48, 48),
) -> np.ndarray:
    """Segment vesicles from connected components of foreground minus boundary prediction.

    The components of `(foreground - boundaries) > 0.5` are used as seeds of a watershed
    on the distance to the seeds, restricted to `(foreground + boundaries) > 0.5`.

    Args:
        foreground: The foreground prediction.
        boundaries: The boundary prediction.
        verbose: Whether to print timing information.
        min_size: The minimal vesicle size; smaller objects are removed.
        block_shape: Block shape for parallelizing the operations.
        halo: Halo for parallelizing the operations.

    Returns:
        The vesicle segmentation.
    """
    def _report(step, start):
        # Print the elapsed time of a processing step when verbose output is requested.
        if verbose:
            print(step, time.time() - start, "s")

    start = time.time()
    interior = (foreground - boundaries) > 0.5
    seed_labels = parallel.label(interior, block_shape=block_shape, verbose=verbose)
    _report("Compute connected components in", start)

    start = time.time()
    heightmap = parallel.distance_transform(
        seed_labels == 0, halo=halo, verbose=verbose, block_shape=block_shape
    )
    _report("Compute distance transform in", start)

    start = time.time()
    watershed_mask = (foreground + boundaries) > 0.5
    output = np.zeros_like(seed_labels)
    segmentation = parallel.seeded_watershed(
        heightmap,
        seed_labels,
        block_shape=block_shape,
        out=output,
        mask=watershed_mask,
        verbose=verbose,
        halo=halo,
    )
    _report("Compute watershed in", start)

    return apply_size_filter(segmentation, min_size, verbose, block_shape)
Segment vesicles by subtracting boundary from foreground prediction and applying connected components.
Arguments:
- foreground: The foreground prediction.
- boundaries: The boundary prediction.
- verbose: Whether to print timing information.
- min_size: The minimal vesicle size.
- block_shape: Block shape for parallelizing the operations.
- halo: Halo for parallelizing the operations.
Returns:
The vesicle segmentation.
def
segment_vesicles( input_volume: numpy.ndarray, model_path: Optional[str] = None, model: Optional[torch.nn.modules.module.Module] = None, tiling: Optional[Dict[str, Dict[str, int]]] = None, min_size: int = 500, verbose: bool = True, distance_based_segmentation: bool = True, return_predictions: bool = False, scale: Optional[List[float]] = None, exclude_boundary: bool = False, exclude_boundary_vesicles: bool = False, mask: Optional[numpy.ndarray] = None) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]:
def segment_vesicles(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    min_size: int = 500,
    verbose: bool = True,
    distance_based_segmentation: bool = True,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    exclude_boundary: bool = False,
    exclude_boundary_vesicles: bool = False,
    mask: Optional[np.ndarray] = None,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment vesicles in an input volume or image.

    Args:
        input_volume: The input volume to segment.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        min_size: The minimum size of a vesicle to be considered.
        verbose: Whether to print timing information.
        distance_based_segmentation: Whether to use distance-based segmentation.
        return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        exclude_boundary: Whether to exclude vesicles that touch the upper / lower border in z.
        exclude_boundary_vesicles: Whether to exclude vesicles on the boundary that have less than the full diameter
            inside of the volume. This is an alternative to post-processing with `exclude_boundary` that filters
            out fewer vesicles at the boundary and is better suited for volumes with small context in z.
            If `exclude_boundary` is also set to True, then this option will have no effect
            and a warning is emitted.
        mask: An optional mask that is used to restrict the segmentation.

    Returns:
        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
        and the predictions if return_predictions is True.
    """
    if verbose:
        print("Segmenting vesicles in volume of shape", input_volume.shape)
    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    # Rescale the mask if it was given and run prediction.
    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)
    pred = get_prediction(input_volume, tiling=tiling, model_path=model_path, model=model, verbose=verbose, mask=mask)
    # The first two prediction channels are used as foreground and boundary predictions.
    foreground, boundaries = pred[:2]

    # The default block shape / halo of the segmentation functions are 3D; use 2D values for images.
    kwargs = {}
    if len(input_volume.shape) == 2:
        kwargs["block_shape"] = (256, 256)
        kwargs["halo"] = (48, 48)

    if distance_based_segmentation:
        seg = distance_based_vesicle_segmentation(
            foreground, boundaries, verbose=verbose, min_size=min_size, **kwargs
        )
    else:
        seg = simple_vesicle_segmentation(
            foreground, boundaries, verbose=verbose, min_size=min_size, **kwargs
        )

    if exclude_boundary:
        if exclude_boundary_vesicles:
            # Trailing space is required: the two literals are concatenated into one message.
            # stacklevel=2 attributes the warning to the caller that passed both options.
            warnings.warn(
                "You have set both 'exclude_boundary' and 'exclude_boundary_vesicles' to True. "
                "The 'exclude_boundary_vesicles' option will have no effect.",
                stacklevel=2,
            )
        seg = filter_border_objects(seg)

    elif exclude_boundary_vesicles:
        # Filter the vesicles that are at the z-border with less than their full diameter;
        # filter_border_vesicles returns the IDs that are kept.
        seg_ids = filter_border_vesicles(seg)

        # Remove everything not in seg_ids and relabel the remaining IDs consecutively.
        seg[~np.isin(seg, seg_ids)] = 0
        seg = relabel_sequential(seg)[0]

    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        pred = scaler.rescale_output(pred, is_segmentation=False)
        return seg, pred
    return seg
Segment vesicles in an input volume or image.
Arguments:
- input_volume: The input volume to segment.
- model_path: The path to the model checkpoint if `model` is not provided.
- model: Pre-loaded model. Either `model_path` or `model` is required.
- tiling: The tiling configuration for the prediction.
- min_size: The minimum size of a vesicle to be considered.
- verbose: Whether to print timing information.
- distance_based_segmentation: Whether to use distance-based segmentation.
- return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
- scale: The scale factor to use for rescaling the input volume before prediction.
- exclude_boundary: Whether to exclude vesicles that touch the upper / lower border in z.
- exclude_boundary_vesicles: Whether to exclude vesicles on the boundary that have less than the full diameter inside of the volume. This is an alternative to post-processing with `exclude_boundary` that filters out fewer vesicles at the boundary and is better suited for volumes with small context in z. If `exclude_boundary` is also set to True, then this option will have no effect.
- mask: An optional mask that is used to restrict the segmentation.
Returns:
The segmentation mask as a numpy array, or a tuple containing the segmentation mask and the predictions if return_predictions is True.