synapse_net.inference.vesicles

  1import time
  2import warnings
  3from typing import Dict, List, Optional, Tuple, Union
  4
  5import elf.parallel as parallel
  6import numpy as np
  7
  8import torch
  9
 10from synapse_net.inference.util import apply_size_filter, get_prediction, _Scaler
 11from synapse_net.inference.postprocessing.vesicles import filter_border_objects, filter_border_vesicles
 12from skimage.segmentation import relabel_sequential
 13
 14
 15def distance_based_vesicle_segmentation(
 16    foreground: np.ndarray,
 17    boundaries: np.ndarray,
 18    verbose: bool,
 19    min_size: int,
 20    boundary_threshold: float = 0.5,  # previous default value was 0.9
 21    distance_threshold: int = 8,
 22    block_shape: Tuple[int, int, int] = (128, 256, 256),
 23    halo: Tuple[int, int, int] = (48, 48, 48),
 24) -> np.ndarray:
 25    """Segment vesicles using a seeded watershed from connected components derived from
 26    distance transform of the boundary predictions.
 27
 28    This approach can prevent false merges that occur with the `simple_vesicle_segmentation`.
 29
 30    Args:
 31        foreground: The foreground prediction.
 32        boundaries: The boundary prediction.
 33        verbose: Whether to print timing information.
 34        min_size: The minimal vesicle size.
 35        boundary_threshold: The threshold for binarizing the boundary predictions for the distance computation.
 36        distance_threshold: The threshold for finding connected components in the boundary distances.
 37        block_shape: Block shape for parallelizing the operations.
 38        halo: Halo for parallelizing the operations.
 39
 40    Returns:
 41        The vesicle segmentation.
 42    """
 43    # Compute the boundary distances.
 44    t0 = time.time()
 45    bd_dist = parallel.distance_transform(
 46        boundaries < boundary_threshold, halo=halo, verbose=verbose, block_shape=block_shape
 47    )
 48    bd_dist[foreground < 0.5] = 0
 49    if verbose:
 50        print("Compute distance transform in", time.time() - t0, "s")
 51
 52    # Get the segmentation via seeded watershed of components in the boundary distances.
 53    t0 = time.time()
 54    seeds = parallel.label(bd_dist > distance_threshold, block_shape=block_shape, verbose=verbose)
 55    if verbose:
 56        print("Compute connected components in", time.time() - t0, "s")
 57
 58    # Compute distances from the seeds, which are used as heightmap for the watershed,
 59    # to assign all pixels to the nearest seed.
 60    t0 = time.time()
 61    dist = parallel.distance_transform(seeds == 0, halo=halo, verbose=verbose, block_shape=block_shape)
 62    if verbose:
 63        print("Compute distance transform in", time.time() - t0, "s")
 64
 65    t0 = time.time()
 66    mask = (foreground + boundaries) > 0.5
 67    seg = np.zeros_like(seeds)
 68    seg = parallel.seeded_watershed(
 69        dist, seeds, block_shape=block_shape,
 70        out=seg, mask=mask, verbose=verbose, halo=halo,
 71    )
 72    if verbose:
 73        print("Compute watershed in", time.time() - t0, "s")
 74
 75    seg = apply_size_filter(seg, min_size, verbose, block_shape)
 76    return seg
 77
 78
 79def simple_vesicle_segmentation(
 80    foreground: np.ndarray,
 81    boundaries: np.ndarray,
 82    verbose: bool,
 83    min_size: int,
 84    block_shape: Tuple[int, int, int] = (128, 256, 256),
 85    halo: Tuple[int, int, int] = (48, 48, 48),
 86) -> np.ndarray:
 87    """Segment vesicles by subtracting boundary from foreground prediction and
 88    applying connected components.
 89
 90    Args:
 91        foreground: The foreground prediction.
 92        boundaries: The boundary prediction.
 93        verbose: Whether to print timing information.
 94        min_size: The minimal vesicle size.
 95        block_shape: Block shape for parallelizing the operations.
 96        halo: Halo for parallelizing the operations.
 97
 98    Returns:
 99        The vesicle segmentation.
100    """
101
102    t0 = time.time()
103    seeds = parallel.label((foreground - boundaries) > 0.5, block_shape=block_shape, verbose=verbose)
104    if verbose:
105        print("Compute connected components in", time.time() - t0, "s")
106
107    t0 = time.time()
108    dist = parallel.distance_transform(seeds == 0, halo=halo, verbose=verbose, block_shape=block_shape)
109    if verbose:
110        print("Compute distance transform in", time.time() - t0, "s")
111
112    t0 = time.time()
113    mask = (foreground + boundaries) > 0.5
114    seg = np.zeros_like(seeds)
115    seg = parallel.seeded_watershed(
116        dist, seeds, block_shape=block_shape,
117        out=seg, mask=mask, verbose=verbose, halo=halo,
118    )
119    if verbose:
120        print("Compute watershed in", time.time() - t0, "s")
121
122    seg = apply_size_filter(seg, min_size, verbose, block_shape)
123    return seg
124
125
def segment_vesicles(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    min_size: int = 500,
    verbose: bool = True,
    distance_based_segmentation: bool = True,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    exclude_boundary: bool = False,
    exclude_boundary_vesicles: bool = False,
    mask: Optional[np.ndarray] = None,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment vesicles in an input volume or image.

    Args:
        input_volume: The input volume to segment.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        min_size: The minimum size of a vesicle to be considered.
        verbose: Whether to print timing information.
        distance_based_segmentation: Whether to use distance-based segmentation
            (`distance_based_vesicle_segmentation`) instead of `simple_vesicle_segmentation`.
        return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        exclude_boundary: Whether to exclude vesicles that touch the upper / lower border in z.
        exclude_boundary_vesicles: Whether to exclude vesicles on the boundary that have less than the full
            diameter inside of the volume. This is an alternative to post-processing with `exclude_boundary`
            that filters out fewer vesicles at the boundary and is better suited for volumes with small
            context in z. If `exclude_boundary` is also set to True, then this option will have no effect.
        mask: An optional mask that is used to restrict the segmentation. It is rescaled like the input
            and passed to the prediction.

    Returns:
        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
        and the predictions if return_predictions is True.
    """
    if verbose:
        print("Segmenting vesicles in volume of shape", input_volume.shape)

    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    # Rescale the mask if it was given and run prediction.
    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)
    pred = get_prediction(input_volume, tiling=tiling, model_path=model_path, model=model, verbose=verbose, mask=mask)
    foreground, boundaries = pred[:2]

    # Deal with the 2D segmentation case: the segmentation functions default to a 3D block shape and halo.
    kwargs = {}
    if len(input_volume.shape) == 2:
        kwargs["block_shape"] = (256, 256)
        kwargs["halo"] = (48, 48)

    segmentation_function = (
        distance_based_vesicle_segmentation if distance_based_segmentation else simple_vesicle_segmentation
    )
    seg = segmentation_function(foreground, boundaries, verbose=verbose, min_size=min_size, **kwargs)

    if exclude_boundary and exclude_boundary_vesicles:
        # Note the trailing space in the first literal: adjacent string literals are concatenated as-is.
        warnings.warn(
            "You have set both 'exclude_boundary' and 'exclude_boundary_vesicles' to True. "
            "The 'exclude_boundary_vesicles' option will have no effect.",
            stacklevel=2,
        )

    # `exclude_boundary` takes precedence over `exclude_boundary_vesicles`.
    if exclude_boundary:
        seg = filter_border_objects(seg)
    elif exclude_boundary_vesicles:
        # Filter the vesicles that are at the z-border with less than their full diameter.
        seg_ids = filter_border_vesicles(seg)

        # Remove everything not in seg ids and relabel the remaining IDs consecutively.
        seg[~np.isin(seg, seg_ids)] = 0
        seg = relabel_sequential(seg)[0]

    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        pred = scaler.rescale_output(pred, is_segmentation=False)
        return seg, pred
    return seg
def distance_based_vesicle_segmentation( foreground: numpy.ndarray, boundaries: numpy.ndarray, verbose: bool, min_size: int, boundary_threshold: float = 0.5, distance_threshold: int = 8, block_shape: Tuple[int, int, int] = (128, 256, 256), halo: Tuple[int, int, int] = (48, 48, 48)) -> numpy.ndarray:
def distance_based_vesicle_segmentation(
    foreground: np.ndarray,
    boundaries: np.ndarray,
    verbose: bool,
    min_size: int,
    boundary_threshold: float = 0.5,  # previous default value was 0.9
    distance_threshold: int = 8,
    block_shape: Tuple[int, int, int] = (128, 256, 256),
    halo: Tuple[int, int, int] = (48, 48, 48),
) -> np.ndarray:
    """Segment vesicles with a seeded watershed whose seeds come from boundary distances.

    The boundary prediction is binarized (`boundaries < boundary_threshold`) and a distance
    transform is computed on that mask. The distances are set to zero where `foreground < 0.5`.
    Connected components of the pixels whose distance exceeds `distance_threshold` become the
    watershed seeds. All pixels with `foreground + boundaries > 0.5` are then assigned to the
    nearest seed via a watershed on the distance to the seeds. Finally, objects smaller than
    `min_size` are removed.

    This approach can prevent false merges that occur with the `simple_vesicle_segmentation`.

    Args:
        foreground: The foreground prediction.
        boundaries: The boundary prediction.
        verbose: Whether to print timing information.
        min_size: The minimal vesicle size.
        boundary_threshold: The threshold for binarizing the boundary predictions for the distance computation.
        distance_threshold: The threshold for finding connected components in the boundary distances.
        block_shape: Block shape for parallelizing the operations.
        halo: Halo for parallelizing the operations.

    Returns:
        The vesicle segmentation.
    """
    def report(label, start):
        # Timing output is only produced in verbose mode.
        if verbose:
            print(label, time.time() - start, "s")

    # Distances computed on the non-boundary mask, restricted to the foreground.
    start = time.time()
    non_boundary = boundaries < boundary_threshold
    boundary_distances = parallel.distance_transform(
        non_boundary, halo=halo, verbose=verbose, block_shape=block_shape
    )
    boundary_distances[foreground < 0.5] = 0
    report("Compute distance transform in", start)

    # Seeds: connected components of the pixels far enough away from the boundaries.
    start = time.time()
    seed_mask = boundary_distances > distance_threshold
    seeds = parallel.label(seed_mask, block_shape=block_shape, verbose=verbose)
    report("Compute connected components in", start)

    # Distance to the nearest seed serves as the watershed heightmap,
    # so that every pixel is assigned to the nearest seed.
    start = time.time()
    seed_distances = parallel.distance_transform(
        seeds == 0, halo=halo, verbose=verbose, block_shape=block_shape
    )
    report("Compute distance transform in", start)

    # Grow the seeds into every pixel covered by foreground or boundary prediction.
    start = time.time()
    watershed_mask = (foreground + boundaries) > 0.5
    output = np.zeros_like(seeds)
    segmentation = parallel.seeded_watershed(
        seed_distances,
        seeds,
        block_shape=block_shape,
        out=output,
        mask=watershed_mask,
        verbose=verbose,
        halo=halo,
    )
    report("Compute watershed in", start)

    return apply_size_filter(segmentation, min_size, verbose, block_shape)

Segment vesicles using a seeded watershed from connected components derived from distance transform of the boundary predictions.

This approach can prevent false merges that occur with the simple_vesicle_segmentation.

Arguments:
  • foreground: The foreground prediction.
  • boundaries: The boundary prediction.
  • verbose: Whether to print timing information.
  • min_size: The minimal vesicle size.
  • boundary_threshold: The threshold for binarizing the boundary predictions for the distance computation.
  • distance_threshold: The threshold for finding connected components in the boundary distances.
  • block_shape: Block shape for parallelizing the operations.
  • halo: Halo for parallelizing the operations.
Returns:

The vesicle segmentation.

def simple_vesicle_segmentation( foreground: numpy.ndarray, boundaries: numpy.ndarray, verbose: bool, min_size: int, block_shape: Tuple[int, int, int] = (128, 256, 256), halo: Tuple[int, int, int] = (48, 48, 48)) -> numpy.ndarray:
 80def simple_vesicle_segmentation(
 81    foreground: np.ndarray,
 82    boundaries: np.ndarray,
 83    verbose: bool,
 84    min_size: int,
 85    block_shape: Tuple[int, int, int] = (128, 256, 256),
 86    halo: Tuple[int, int, int] = (48, 48, 48),
 87) -> np.ndarray:
 88    """Segment vesicles by subtracting boundary from foreground prediction and
 89    applying connected components.
 90
 91    Args:
 92        foreground: The foreground prediction.
 93        boundaries: The boundary prediction.
 94        verbose: Whether to print timing information.
 95        min_size: The minimal vesicle size.
 96        block_shape: Block shape for parallelizing the operations.
 97        halo: Halo for parallelizing the operations.
 98
 99    Returns:
100        The vesicle segmentation.
101    """
102
103    t0 = time.time()
104    seeds = parallel.label((foreground - boundaries) > 0.5, block_shape=block_shape, verbose=verbose)
105    if verbose:
106        print("Compute connected components in", time.time() - t0, "s")
107
108    t0 = time.time()
109    dist = parallel.distance_transform(seeds == 0, halo=halo, verbose=verbose, block_shape=block_shape)
110    if verbose:
111        print("Compute distance transform in", time.time() - t0, "s")
112
113    t0 = time.time()
114    mask = (foreground + boundaries) > 0.5
115    seg = np.zeros_like(seeds)
116    seg = parallel.seeded_watershed(
117        dist, seeds, block_shape=block_shape,
118        out=seg, mask=mask, verbose=verbose, halo=halo,
119    )
120    if verbose:
121        print("Compute watershed in", time.time() - t0, "s")
122
123    seg = apply_size_filter(seg, min_size, verbose, block_shape)
124    return seg

Segment vesicles by subtracting boundary from foreground prediction and applying connected components.

Arguments:
  • foreground: The foreground prediction.
  • boundaries: The boundary prediction.
  • verbose: Whether to print timing information.
  • min_size: The minimal vesicle size.
  • block_shape: Block shape for parallelizing the operations.
  • halo: Halo for parallelizing the operations.
Returns:

The vesicle segmentation.

def segment_vesicles( input_volume: numpy.ndarray, model_path: Optional[str] = None, model: Optional[torch.nn.modules.module.Module] = None, tiling: Optional[Dict[str, Dict[str, int]]] = None, min_size: int = 500, verbose: bool = True, distance_based_segmentation: bool = True, return_predictions: bool = False, scale: Optional[List[float]] = None, exclude_boundary: bool = False, exclude_boundary_vesicles: bool = False, mask: Optional[numpy.ndarray] = None) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]:
def segment_vesicles(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    min_size: int = 500,
    verbose: bool = True,
    distance_based_segmentation: bool = True,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    exclude_boundary: bool = False,
    exclude_boundary_vesicles: bool = False,
    mask: Optional[np.ndarray] = None,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment vesicles in an input volume or image.

    Args:
        input_volume: The input volume to segment.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        min_size: The minimum size of a vesicle to be considered.
        verbose: Whether to print timing information.
        distance_based_segmentation: Whether to use distance-based segmentation
            (`distance_based_vesicle_segmentation`) instead of `simple_vesicle_segmentation`.
        return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        exclude_boundary: Whether to exclude vesicles that touch the upper / lower border in z.
        exclude_boundary_vesicles: Whether to exclude vesicles on the boundary that have less than the full
            diameter inside of the volume. This is an alternative to post-processing with `exclude_boundary`
            that filters out fewer vesicles at the boundary and is better suited for volumes with small
            context in z. If `exclude_boundary` is also set to True, then this option will have no effect.
        mask: An optional mask that is used to restrict the segmentation. It is rescaled like the input
            and passed to the prediction.

    Returns:
        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
        and the predictions if return_predictions is True.
    """
    if verbose:
        print("Segmenting vesicles in volume of shape", input_volume.shape)

    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    # Rescale the mask if it was given and run prediction.
    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)
    pred = get_prediction(input_volume, tiling=tiling, model_path=model_path, model=model, verbose=verbose, mask=mask)
    foreground, boundaries = pred[:2]

    # Deal with the 2D segmentation case: the segmentation functions default to a 3D block shape and halo.
    kwargs = {}
    if len(input_volume.shape) == 2:
        kwargs["block_shape"] = (256, 256)
        kwargs["halo"] = (48, 48)

    segmentation_function = (
        distance_based_vesicle_segmentation if distance_based_segmentation else simple_vesicle_segmentation
    )
    seg = segmentation_function(foreground, boundaries, verbose=verbose, min_size=min_size, **kwargs)

    if exclude_boundary and exclude_boundary_vesicles:
        # Note the trailing space in the first literal: adjacent string literals are concatenated as-is.
        warnings.warn(
            "You have set both 'exclude_boundary' and 'exclude_boundary_vesicles' to True. "
            "The 'exclude_boundary_vesicles' option will have no effect.",
            stacklevel=2,
        )

    # `exclude_boundary` takes precedence over `exclude_boundary_vesicles`.
    if exclude_boundary:
        seg = filter_border_objects(seg)
    elif exclude_boundary_vesicles:
        # Filter the vesicles that are at the z-border with less than their full diameter.
        seg_ids = filter_border_vesicles(seg)

        # Remove everything not in seg ids and relabel the remaining IDs consecutively.
        seg[~np.isin(seg, seg_ids)] = 0
        seg = relabel_sequential(seg)[0]

    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        pred = scaler.rescale_output(pred, is_segmentation=False)
        return seg, pred
    return seg

Segment vesicles in an input volume or image.

Arguments:
  • input_volume: The input volume to segment.
  • model_path: The path to the model checkpoint if model is not provided.
  • model: Pre-loaded model. Either model_path or model is required.
  • tiling: The tiling configuration for the prediction.
  • min_size: The minimum size of a vesicle to be considered.
  • verbose: Whether to print timing information.
  • distance_based_segmentation: Whether to use distance-based segmentation.
  • return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
  • scale: The scale factor to use for rescaling the input volume before prediction.
  • exclude_boundary: Whether to exclude vesicles that touch the upper / lower border in z.
  • exclude_boundary_vesicles: Whether to exclude vesicles on the boundary that have less than the full diameter inside of the volume. This is an alternative to post-processing with exclude_boundary that filters out fewer vesicles at the boundary and is better suited for volumes with small context in z. If exclude_boundary is also set to True, then this option will have no effect.
  • mask: An optional mask that is used to restrict the segmentation.
Returns:

The segmentation mask as a numpy array, or a tuple containing the segmentation mask and the predictions if return_predictions is True.