synapse_net.inference.vesicles

  1import time
  2from typing import Dict, List, Optional, Tuple, Union
  3
  4import elf.parallel as parallel
  5import numpy as np
  6
  7import torch
  8
  9from synapse_net.inference.util import apply_size_filter, get_prediction, _Scaler
 10from synapse_net.inference.postprocessing.vesicles import filter_border_objects
 11
 12
 13def distance_based_vesicle_segmentation(
 14    foreground: np.ndarray,
 15    boundaries: np.ndarray,
 16    verbose: bool,
 17    min_size: int,
 18    boundary_threshold: float = 0.5,  # previous default value was 0.9
 19    distance_threshold: int = 8,
 20    block_shape: Tuple[int, int, int] = (128, 256, 256),
 21    halo: Tuple[int, int, int] = (48, 48, 48),
 22) -> np.ndarray:
 23    """Segment vesicles using a seeded watershed from connected components derived from
 24    distance transform of the boundary predictions.
 25
 26    This approach can prevent false merges that occur with the `simple_vesicle_segmentation`.
 27
 28    Args:
 29        foreground: The foreground prediction.
 30        boundaries: The boundary prediction.
 31        verbose: Whether to print timing information.
 32        min_size: The minimal vesicle size.
 33        boundary_threshold: The threshold for binarizing the boundary predictions for the distance computation.
 34        distance_threshold: The threshold for finding connected components in the boundary distances.
 35        block_shape: Block shape for parallelizing the operations.
 36        halo: Halo for parallelizing the operations.
 37
 38    Returns:
 39        The vesicle segmentation.
 40    """
 41    # Compute the boundary distances.
 42    t0 = time.time()
 43    bd_dist = parallel.distance_transform(
 44        boundaries < boundary_threshold, halo=halo, verbose=verbose, block_shape=block_shape
 45    )
 46    bd_dist[foreground < 0.5] = 0
 47    if verbose:
 48        print("Compute distance transform in", time.time() - t0, "s")
 49
 50    # Get the segmentation via seeded watershed of components in the boundary distances.
 51    t0 = time.time()
 52    seeds = parallel.label(bd_dist > distance_threshold, block_shape=block_shape, verbose=verbose)
 53    if verbose:
 54        print("Compute connected components in", time.time() - t0, "s")
 55
 56    # Compute distances from the seeds, which are used as heightmap for the watershed,
 57    # to assign all pixels to the nearest seed.
 58    t0 = time.time()
 59    dist = parallel.distance_transform(seeds == 0, halo=halo, verbose=verbose, block_shape=block_shape)
 60    if verbose:
 61        print("Compute distance transform in", time.time() - t0, "s")
 62
 63    t0 = time.time()
 64    mask = (foreground + boundaries) > 0.5
 65    seg = np.zeros_like(seeds)
 66    seg = parallel.seeded_watershed(
 67        dist, seeds, block_shape=block_shape,
 68        out=seg, mask=mask, verbose=verbose, halo=halo,
 69    )
 70    if verbose:
 71        print("Compute watershed in", time.time() - t0, "s")
 72
 73    seg = apply_size_filter(seg, min_size, verbose, block_shape)
 74    return seg
 75
 76
 77def simple_vesicle_segmentation(
 78    foreground: np.ndarray,
 79    boundaries: np.ndarray,
 80    verbose: bool,
 81    min_size: int,
 82    block_shape: Tuple[int, int, int] = (128, 256, 256),
 83    halo: Tuple[int, int, int] = (48, 48, 48),
 84) -> np.ndarray:
 85    """Segment vesicles by subtracting boundary from foreground prediction and
 86    applying connected components.
 87
 88    Args:
 89        foreground: The foreground prediction.
 90        boundaries: The boundary prediction.
 91        verbose: Whether to print timing information.
 92        min_size: The minimal vesicle size.
 93        block_shape: Block shape for parallelizing the operations.
 94        halo: Halo for parallelizing the operations.
 95
 96    Returns:
 97        The vesicle segmentation.
 98    """
 99
100    t0 = time.time()
101    seeds = parallel.label((foreground - boundaries) > 0.5, block_shape=block_shape, verbose=verbose)
102    if verbose:
103        print("Compute connected components in", time.time() - t0, "s")
104
105    t0 = time.time()
106    dist = parallel.distance_transform(seeds == 0, halo=halo, verbose=verbose, block_shape=block_shape)
107    if verbose:
108        print("Compute distance transform in", time.time() - t0, "s")
109
110    t0 = time.time()
111    mask = (foreground + boundaries) > 0.5
112    seg = np.zeros_like(seeds)
113    seg = parallel.seeded_watershed(
114        dist, seeds, block_shape=block_shape,
115        out=seg, mask=mask, verbose=verbose, halo=halo,
116    )
117    if verbose:
118        print("Compute watershed in", time.time() - t0, "s")
119
120    seg = apply_size_filter(seg, min_size, verbose, block_shape)
121    return seg
122
123
def segment_vesicles(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    min_size: int = 500,
    verbose: bool = True,
    distance_based_segmentation: bool = True,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    exclude_boundary: bool = False,
    mask: Optional[np.ndarray] = None,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment vesicles in an input volume or image.

    The input (and the optional mask) is rescaled according to `scale`, and the network is run
    on it. The first two prediction channels (foreground, boundaries) are turned into an
    instance segmentation, either with `distance_based_vesicle_segmentation` or with
    `simple_vesicle_segmentation`. Outputs are passed back through the scaler's
    `rescale_output` before being returned.

    Args:
        input_volume: The input volume to segment.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        min_size: The minimum size of a vesicle to be considered.
        verbose: Whether to print timing information.
        distance_based_segmentation: Whether to use distance-based segmentation.
        return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        exclude_boundary: Whether to exclude vesicles that touch the upper / lower border in z.
        mask: An optional mask that is used to restrict the segmentation.

    Returns:
        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
        and the predictions if return_predictions is True.
    """
    if verbose:
        print("Segmenting vesicles in volume of shape", input_volume.shape)

    # The scaler rescales the inputs before prediction; its rescale_output is applied
    # to the results at the end.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)
    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)

    pred = get_prediction(
        input_volume, tiling=tiling, model_path=model_path, model=model, verbose=verbose, mask=mask
    )
    foreground, boundaries = pred[:2]

    # 2D images need a 2D block shape and halo; 3D volumes use the segmentation defaults.
    is_2d = len(input_volume.shape) == 2
    parallel_kwargs = dict(block_shape=(256, 256), halo=(48, 48)) if is_2d else {}

    segment_fn = (
        distance_based_vesicle_segmentation if distance_based_segmentation
        else simple_vesicle_segmentation
    )
    seg = segment_fn(foreground, boundaries, verbose=verbose, min_size=min_size, **parallel_kwargs)

    if exclude_boundary:
        seg = filter_border_objects(seg)
    seg = scaler.rescale_output(seg, is_segmentation=True)

    if not return_predictions:
        return seg
    return seg, scaler.rescale_output(pred, is_segmentation=False)
def distance_based_vesicle_segmentation( foreground: numpy.ndarray, boundaries: numpy.ndarray, verbose: bool, min_size: int, boundary_threshold: float = 0.5, distance_threshold: int = 8, block_shape: Tuple[int, int, int] = (128, 256, 256), halo: Tuple[int, int, int] = (48, 48, 48)) -> numpy.ndarray:
def distance_based_vesicle_segmentation(
    foreground: np.ndarray,
    boundaries: np.ndarray,
    verbose: bool,
    min_size: int,
    boundary_threshold: float = 0.5,  # previous default value was 0.9
    distance_threshold: int = 8,
    block_shape: Tuple[int, int, int] = (128, 256, 256),
    halo: Tuple[int, int, int] = (48, 48, 48),
) -> np.ndarray:
    """Segment vesicles with a seeded watershed whose seeds come from the boundary distances.

    Seeds are the connected components of the region that lies far enough from the
    thresholded boundary prediction. This can prevent false merges that occur with
    `simple_vesicle_segmentation`.

    Processing steps:
        1. Distance transform of ``boundaries < boundary_threshold`` (the boundary distances),
           with all values where ``foreground < 0.5`` set to zero.
        2. Connected components of ``boundary distances > distance_threshold`` -> seeds.
        3. Distance transform of ``seeds == 0`` (distance from the seeds), used as heightmap.
        4. Seeded watershed restricted to ``foreground + boundaries > 0.5``, followed by
           ``apply_size_filter`` with ``min_size``.

    Args:
        foreground: The foreground prediction.
        boundaries: The boundary prediction.
        verbose: Whether to print timing information.
        min_size: The minimal vesicle size.
        boundary_threshold: The threshold for binarizing the boundary predictions for the distance computation.
        distance_threshold: The threshold for finding connected components in the boundary distances.
        block_shape: Block shape for parallelizing the operations.
        halo: Halo for parallelizing the operations.

    Returns:
        The vesicle segmentation.
    """
    def _log_elapsed(step: str, since: float) -> None:
        # Timing output is only produced in verbose mode.
        if verbose:
            print(step, time.time() - since, "s")

    # Step 1: boundary distances, suppressed outside of the foreground.
    started = time.time()
    boundary_distances = parallel.distance_transform(
        boundaries < boundary_threshold, halo=halo, verbose=verbose, block_shape=block_shape
    )
    boundary_distances[foreground < 0.5] = 0
    _log_elapsed("Compute distance transform in", started)

    # Step 2: seeds are the connected components of the region far from any boundary.
    started = time.time()
    seeds = parallel.label(
        boundary_distances > distance_threshold, block_shape=block_shape, verbose=verbose
    )
    _log_elapsed("Compute connected components in", started)

    # Step 3: distances from the seeds serve as the watershed heightmap,
    # so that every pixel is assigned to its nearest seed.
    started = time.time()
    seed_distances = parallel.distance_transform(
        seeds == 0, halo=halo, verbose=verbose, block_shape=block_shape
    )
    _log_elapsed("Compute distance transform in", started)

    # Step 4: grow the seeds inside the combined foreground / boundary mask.
    started = time.time()
    watershed_mask = (foreground + boundaries) > 0.5
    segmentation = parallel.seeded_watershed(
        seed_distances, seeds, block_shape=block_shape,
        out=np.zeros_like(seeds), mask=watershed_mask, verbose=verbose, halo=halo,
    )
    _log_elapsed("Compute watershed in", started)

    return apply_size_filter(segmentation, min_size, verbose, block_shape)

Segment vesicles with a seeded watershed whose seeds are the connected components of the thresholded distance transform of the boundary predictions.

This approach can prevent false merges that occur with simple_vesicle_segmentation.

Arguments:
  • foreground: The foreground prediction.
  • boundaries: The boundary prediction.
  • verbose: Whether to print timing information.
  • min_size: The minimal vesicle size.
  • boundary_threshold: The threshold for binarizing the boundary predictions for the distance computation.
  • distance_threshold: The threshold for finding connected components in the boundary distances.
  • block_shape: Block shape for parallelizing the operations.
  • halo: Halo for parallelizing the operations.
Returns:

The vesicle segmentation.

def simple_vesicle_segmentation( foreground: numpy.ndarray, boundaries: numpy.ndarray, verbose: bool, min_size: int, block_shape: Tuple[int, int, int] = (128, 256, 256), halo: Tuple[int, int, int] = (48, 48, 48)) -> numpy.ndarray:
 78def simple_vesicle_segmentation(
 79    foreground: np.ndarray,
 80    boundaries: np.ndarray,
 81    verbose: bool,
 82    min_size: int,
 83    block_shape: Tuple[int, int, int] = (128, 256, 256),
 84    halo: Tuple[int, int, int] = (48, 48, 48),
 85) -> np.ndarray:
 86    """Segment vesicles by subtracting boundary from foreground prediction and
 87    applying connected components.
 88
 89    Args:
 90        foreground: The foreground prediction.
 91        boundaries: The boundary prediction.
 92        verbose: Whether to print timing information.
 93        min_size: The minimal vesicle size.
 94        block_shape: Block shape for parallelizing the operations.
 95        halo: Halo for parallelizing the operations.
 96
 97    Returns:
 98        The vesicle segmentation.
 99    """
100
101    t0 = time.time()
102    seeds = parallel.label((foreground - boundaries) > 0.5, block_shape=block_shape, verbose=verbose)
103    if verbose:
104        print("Compute connected components in", time.time() - t0, "s")
105
106    t0 = time.time()
107    dist = parallel.distance_transform(seeds == 0, halo=halo, verbose=verbose, block_shape=block_shape)
108    if verbose:
109        print("Compute distance transform in", time.time() - t0, "s")
110
111    t0 = time.time()
112    mask = (foreground + boundaries) > 0.5
113    seg = np.zeros_like(seeds)
114    seg = parallel.seeded_watershed(
115        dist, seeds, block_shape=block_shape,
116        out=seg, mask=mask, verbose=verbose, halo=halo,
117    )
118    if verbose:
119        print("Compute watershed in", time.time() - t0, "s")
120
121    seg = apply_size_filter(seg, min_size, verbose, block_shape)
122    return seg

Segment vesicles by subtracting the boundary prediction from the foreground prediction, applying connected components to the result, and growing these components with a seeded watershed.

Arguments:
  • foreground: The foreground prediction.
  • boundaries: The boundary prediction.
  • verbose: Whether to print timing information.
  • min_size: The minimal vesicle size.
  • block_shape: Block shape for parallelizing the operations.
  • halo: Halo for parallelizing the operations.
Returns:

The vesicle segmentation.

def segment_vesicles( input_volume: numpy.ndarray, model_path: Optional[str] = None, model: Optional[torch.nn.modules.module.Module] = None, tiling: Optional[Dict[str, Dict[str, int]]] = None, min_size: int = 500, verbose: bool = True, distance_based_segmentation: bool = True, return_predictions: bool = False, scale: Optional[List[float]] = None, exclude_boundary: bool = False, mask: Optional[numpy.ndarray] = None) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]:
def segment_vesicles(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    min_size: int = 500,
    verbose: bool = True,
    distance_based_segmentation: bool = True,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    exclude_boundary: bool = False,
    mask: Optional[np.ndarray] = None,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment vesicles in an input volume or image.

    The input (and the optional mask) is rescaled according to `scale`, and the network is run
    on it. The first two prediction channels (foreground, boundaries) are turned into an
    instance segmentation, either with `distance_based_vesicle_segmentation` or with
    `simple_vesicle_segmentation`. Outputs are passed back through the scaler's
    `rescale_output` before being returned.

    Args:
        input_volume: The input volume to segment.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        min_size: The minimum size of a vesicle to be considered.
        verbose: Whether to print timing information.
        distance_based_segmentation: Whether to use distance-based segmentation.
        return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        exclude_boundary: Whether to exclude vesicles that touch the upper / lower border in z.
        mask: An optional mask that is used to restrict the segmentation.

    Returns:
        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
        and the predictions if return_predictions is True.
    """
    if verbose:
        print("Segmenting vesicles in volume of shape", input_volume.shape)

    # The scaler rescales the inputs before prediction; its rescale_output is applied
    # to the results at the end.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)
    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)

    pred = get_prediction(
        input_volume, tiling=tiling, model_path=model_path, model=model, verbose=verbose, mask=mask
    )
    foreground, boundaries = pred[:2]

    # 2D images need a 2D block shape and halo; 3D volumes use the segmentation defaults.
    is_2d = len(input_volume.shape) == 2
    parallel_kwargs = dict(block_shape=(256, 256), halo=(48, 48)) if is_2d else {}

    segment_fn = (
        distance_based_vesicle_segmentation if distance_based_segmentation
        else simple_vesicle_segmentation
    )
    seg = segment_fn(foreground, boundaries, verbose=verbose, min_size=min_size, **parallel_kwargs)

    if exclude_boundary:
        seg = filter_border_objects(seg)
    seg = scaler.rescale_output(seg, is_segmentation=True)

    if not return_predictions:
        return seg
    return seg, scaler.rescale_output(pred, is_segmentation=False)

Segment vesicles in an input volume or image.

Arguments:
  • input_volume: The input volume to segment.
  • model_path: The path to the model checkpoint if model is not provided.
  • model: Pre-loaded model. Either model_path or model is required.
  • tiling: The tiling configuration for the prediction.
  • min_size: The minimum size of a vesicle to be considered.
  • verbose: Whether to print timing information.
  • distance_based_segmentation: Whether to use distance-based segmentation.
  • return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
  • scale: The scale factor to use for rescaling the input volume before prediction.
  • exclude_boundary: Whether to exclude vesicles that touch the upper / lower border in z.
  • mask: An optional mask that is used to restrict the segmentation.
Returns:

The segmentation mask as a numpy array, or a tuple containing the segmentation mask and the predictions if return_predictions is True.