synapse_net.ground_truth.shape_refinement

  1import multiprocessing as mp
  2from concurrent import futures
  3from functools import partial
  4from typing import List, Dict, Optional
  5
  6import numpy as np
  7
  8import bioimage_cpp as bic
  9from scipy.ndimage import binary_erosion, binary_dilation
 10from skimage import img_as_ubyte
 11from skimage.filters import gaussian, rank, sato, sobel
 12from skimage.measure import regionprops
 13from skimage.morphology import disk
 14from skimage.segmentation import watershed
 15from tqdm import tqdm
 16
# Names of the edge filter methods supported by `edge_filter` (values for its `method` argument).
FILTERS = ("sobel", "laplace", "ggm", "structure-tensor", "sato")
 18
 19
 20def _sato_filter(raw, sigma, max_window=16):
 21    if raw.ndim != 2:
 22        raise NotImplementedError("The sato filter is only implemented for 2D data.")
 23    hmap = sato(raw)
 24    hmap = gaussian(hmap, sigma=sigma)
 25    hmap -= hmap.min()
 26    hmap /= hmap.max()
 27    hmap = rank.autolevel(img_as_ubyte(hmap), disk(max_window)).astype("float") / 255
 28    return hmap
 29
 30
 31def edge_filter(
 32    data: np.ndarray,
 33    sigma: float,
 34    method: str = "sato",
 35    per_slice: bool = True,
 36    n_threads: Optional[int] = None,
 37) -> np.ndarray:
 38    """Find edges in the image data.
 39
 40    Args:
 41        data: The input data.
 42        sigma: The smoothing factor applied before the edge filter.
 43        method: The method for finding edges. The following methods are supported:
 44            - "sobel": Edges are found by smoothing the data and then applying a sobel filter.
 45            - "laplace": Edges are found with a laplacian of gaussian filter.
 46            - "ggm": Edges are found with a gaussian gradient magnitude filter.
 47            - "structure-tensor": Edges are found based on the 1st eigenvalue of the structure tensor.
 48            - "sato": Edges are found with a sato-filter, followed by smoothing and leveling.
 49        per_slice: Compute the filter per slice instead of for the whole volume.
 50        n_threads: Number of threads for parallel computation over the slices.
 51
 52    Returns:
 53        Edge filter response.
 54    """
 55    if method not in FILTERS:
 56        raise ValueError(f"Invalid edge filter method: {method}. Expect one of {FILTERS}.")
 57
 58    if per_slice and data.ndim == 3:
 59        n_threads = mp.cpu_count() if n_threads is None else n_threads
 60        filter_func = partial(edge_filter, sigma=sigma, method=method, per_slice=False)
 61        with futures.ThreadPoolExecutor(n_threads) as tp:
 62            edge_map = list(tp.map(filter_func, data))
 63        edge_map = np.stack(edge_map)
 64        return edge_map
 65
 66    if method == "sobel":
 67        edge_map = gaussian(data, sigma=sigma)
 68        edge_map = sobel(edge_map)
 69    elif method == "laplace":
 70        edge_map = bic.filters.laplacian_of_gaussian(data.astype("float32"), sigma)
 71    elif method == "ggm":
 72        edge_map = bic.filters.gaussian_gradient_magnitude(data.astype("float32"), sigma)
 73    elif method == "structure-tensor":
 74        inner_scale, outer_scale = sigma, sigma * 0.5
 75        edge_map = bic.filters.structure_tensor_eigenvalues(
 76            data.astype("float32"), inner_scale, outer_scale
 77        )[..., 0]
 78    elif method == "sato":
 79        edge_map = _sato_filter(data, sigma)
 80
 81    return edge_map
 82
 83
 84def check_filters(
 85    data: np.ndarray,
 86    filters: List[str] = FILTERS,
 87    sigmas: List[float] = [2.0, 4.0],
 88    show: bool = True,
 89) -> Dict[str, np.ndarray]:
 90    """Apply different edge filters to the input data.
 91
 92    Args:
 93        data: The input data volume.
 94        filters: The names of edge filters to apply.
 95            The filter names must match `method` in `edge_filter`.
 96        sigmas: The sigma values to use for the filters.
 97        show: Whether to show the filter responses in napari.
 98
 99    Returns:
100        Dictionary with the filter responses.
101    """
102
103    n_filters = len(filters) * len(sigmas)
104    pbar = tqdm(total=n_filters, desc="Compute filters")
105
106    responses = {}
107    for filter_ in filters:
108        for sigma in sigmas:
109            name = f"{filter_}_{sigma}"
110            responses[name] = edge_filter(data, sigma, method=filter_)
111            pbar.update(1)
112
113    if show:
114        import napari
115
116        v = napari.Viewer()
117        v.add_image(data)
118        for name, response in responses.items():
119            v.add_image(response, name=name)
120        napari.run()
121
122    return responses
123
124
def refine_vesicle_shapes(
    vesicles: np.ndarray,
    edge_map: np.ndarray,
    foreground_erosion: int = 2,
    background_erosion: int = 6,
    fit_to_outer_boundary: bool = False,
    return_seeds: bool = False,
    compactness: float = 1.0,
) -> np.ndarray:
    """Refine vesicle shapes by fitting vesicles to a boundary map.

    This function erodes the segmented vesicles, and then fits them
    to a boundary using a seeded watershed. This is done with two watersheds,
    one to separate foreground from background and one to separate vesicles within
    the foreground.

    Args:
        vesicles: The vesicle segmentation.
        edge_map: Volume with high intensities for vesicle membrane.
            You can use `edge_filter` to compute this based on the tomogram.
        foreground_erosion: By how many pixels the foreground should be eroded in the seeds.
            If 0, the foreground itself is used as seed.
        background_erosion: By how many pixels the background should be eroded in the seeds.
            If 0, the background itself is used as seed.
        fit_to_outer_boundary: Whether to fit the seeds to the outer membrane by
            applying a second edge filter to `edge_map`.
        return_seeds: Whether to return the seeds used for the watershed.
        compactness: The compactness parameter passed to the watershed function.
            Higher compactness leads to more regular sized vesicles.

    Returns:
        The refined vesicles. If `return_seeds` is True, a tuple of the refined vesicles
        and the seeds of the foreground / background watershed
        (1: foreground, 2: background, 0: no seed).
    """
    fg = vesicles != 0
    bg = vesicles == 0

    # scipy repeats an erosion with iterations < 1 until the result no longer changes,
    # which would remove the seeds entirely. An erosion of 0 thus uses the mask as is.
    fg_seeds = binary_erosion(fg, iterations=foreground_erosion) if foreground_erosion > 0 else fg
    bg_seeds = binary_erosion(bg, iterations=background_erosion) if background_erosion > 0 else bg

    # Add a 1-pixel wide boundary along all faces of the data to the background seeds
    # (works for any number of dimensions).
    boundary_mask = np.ones(bg.shape, dtype=bool)
    boundary_mask[(slice(1, -1),) * bg.ndim] = False
    bg_seeds = bg_seeds | boundary_mask

    # Seeds: 1 for foreground, 2 for background. The seeds can only overlap at the data
    # boundary if the foreground is not eroded; the foreground seed takes precedence there
    # instead of producing a spurious third label.
    seeds = np.zeros(vesicles.shape, dtype="uint8")
    seeds[bg_seeds] = 2
    seeds[fg_seeds] = 1

    outer_edge_map = edge_filter(edge_map, sigma=2) if fit_to_outer_boundary else edge_map

    # First watershed: separate the foreground from the background.
    refined_mask = watershed(outer_edge_map, seeds, compactness=compactness)
    refined_mask[refined_mask == 2] = 0

    # Second watershed: separate the individual vesicles within the foreground.
    refined_vesicles = watershed(edge_map, vesicles, mask=refined_mask, compactness=compactness)

    if return_seeds:
        return refined_vesicles, seeds
    return refined_vesicles
193
194
def refine_individual_vesicle_shapes(
    vesicles: np.ndarray,
    edge_map: np.ndarray,
    foreground_erosion: int = 4,
    background_erosion: int = 8,
    halo: List[int] = (0, 12, 12),
    n_threads: int = 8,
) -> np.ndarray:
    """Refine vesicle shapes by fitting vesicles to a boundary map.

    This function erodes the segmented vesicles, and then fits them
    to a boundary using a seeded watershed. This is done individually for each vesicle
    and slice by slice along the first axis, to avoid effects caused by anisotropy.

    Args:
        vesicles: The vesicle segmentation. The bounding box handling and the per-slice
            refinement assume 3D data.
        edge_map: Volume with high intensities for vesicle membrane.
            You can use `edge_filter` to compute this based on the tomogram.
        foreground_erosion: By how many pixels the vesicle mask is eroded (per slice)
            to obtain the foreground seed. If 0, the mask itself is used.
        background_erosion: By how many pixels the vesicle mask is dilated (per slice);
            everything outside of the dilated mask is used as background seed.
            If 0, everything outside of the mask is used.
        halo: The margin added to the bounding box of each vesicle, per axis.
        n_threads: The number of threads used to refine the vesicles in parallel.

    Returns:
        The refined vesicles.
    """
    refined_vesicles = np.zeros_like(vesicles)

    def fit_vesicle(prop):
        # Crop the bounding box of the vesicle, enlarged by the halo.
        bb = tuple(
            slice(max(start - ha, 0), min(stop + ha, sh))
            for start, stop, ha, sh in zip(prop.bbox[:3], prop.bbox[3:], halo, vesicles.shape)
        )
        vesicle_sub = vesicles[bb]
        vesicle_mask = vesicle_sub == prop.label
        # All other vesicles in the local bbox are part of the background seed,
        # to avoid leaking into other vesicles.
        other_vesicles = (vesicle_sub != 0) & ~vesicle_mask
        hmap = edge_map[bb]

        # Do refinement in 2d to avoid effects caused by anisotropy.
        seg = np.zeros_like(vesicle_mask)
        for z in range(seg.shape[0]):
            m = vesicle_mask[z]
            # scipy repeats erosion / dilation with iterations < 1 until convergence,
            # so a value of 0 has to use the mask directly.
            fg_seed = binary_erosion(m, iterations=foreground_erosion) if foreground_erosion > 0 else m
            if not fg_seed.any():
                # Nothing left after the erosion: keep the original shape in this slice.
                seg[z] = m
                continue
            dilated = binary_dilation(m, iterations=background_erosion) if background_erosion > 0 else m
            bg_seed = ~dilated | other_vesicles[z]

            # Run seeded watershed to fit the shapes.
            seeds = fg_seed.astype("uint8") + 2 * bg_seed.astype("uint8")
            seg[z] = watershed(hmap[z], seeds) == 1

        return bb, seg

    props = regionprops(vesicles)
    with futures.ThreadPoolExecutor(n_threads) as tp:
        results = tp.map(fit_vesicle, props)
        # The results are written from this thread, in the order of `props`: the enlarged
        # bounding boxes of neighboring vesicles can overlap, and concurrent writes from the
        # worker threads would make the label in these overlaps nondeterministic.
        for prop, (bb, seg) in tqdm(zip(props, results), total=len(props), desc="refine vesicles"):
            refined_vesicles[bb][seg] = prop.label

    return refined_vesicles
# Names of the edge filter methods supported by `edge_filter` (values for its `method` argument).
FILTERS = ('sobel', 'laplace', 'ggm', 'structure-tensor', 'sato')
def edge_filter(
    data: np.ndarray,
    sigma: float,
    method: str = "sato",
    per_slice: bool = True,
    n_threads: Optional[int] = None,
) -> np.ndarray:
    """Compute an edge response map for the image data.

    Args:
        data: The input data.
        sigma: The scale of the edge filter. For "sobel" it is the sigma of the gaussian
            smoothing applied before the sobel filter, for "laplace" and "ggm" it is the
            filter scale, for "structure-tensor" it is the inner scale (the outer scale is
            half of it) and for "sato" it is the sigma of the smoothing applied after the
            sato filter.
        method: The method for finding edges. The following methods are supported:
            - "sobel": Gaussian smoothing followed by a sobel filter.
            - "laplace": Laplacian of gaussian filter.
            - "ggm": Gaussian gradient magnitude filter.
            - "structure-tensor": The first eigenvalue of the structure tensor.
            - "sato": Sato filter, followed by smoothing and leveling.
              Only implemented for 2D data, so 3D data requires `per_slice=True`.
        per_slice: Whether to filter 3D data slice by slice (along the first axis)
            instead of filtering the whole volume at once.
        n_threads: Number of threads for the per-slice computation.
            Defaults to the number of CPUs.

    Returns:
        The edge filter response.

    Raises:
        ValueError: If `method` is not one of `FILTERS`.
    """
    if method not in FILTERS:
        raise ValueError(f"Invalid edge filter method: {method}. Expect one of {FILTERS}.")

    # Volumes are filtered slice by slice in a thread pool, each slice as 2D data.
    if per_slice and data.ndim == 3:
        workers = n_threads if n_threads is not None else mp.cpu_count()
        slice_filter = partial(edge_filter, sigma=sigma, method=method, per_slice=False)
        with futures.ThreadPoolExecutor(workers) as executor:
            slice_responses = [response for response in executor.map(slice_filter, data)]
        return np.stack(slice_responses)

    if method == "sato":
        return _sato_filter(data, sigma)
    if method == "sobel":
        return sobel(gaussian(data, sigma=sigma))

    # The remaining filters are computed by bioimage_cpp on float32 input.
    data_float = data.astype("float32")
    if method == "laplace":
        return bic.filters.laplacian_of_gaussian(data_float, sigma)
    if method == "ggm":
        return bic.filters.gaussian_gradient_magnitude(data_float, sigma)

    # Only "structure-tensor" is left: inner scale sigma, outer scale sigma / 2.
    eigenvalues = bic.filters.structure_tensor_eigenvalues(data_float, sigma, sigma * 0.5)
    return eigenvalues[..., 0]

Find edges in the image data.

Arguments:
  • data: The input data.
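  • sigma: The scale of the edge filter. For "sobel" it is the sigma of the smoothing applied before the sobel filter, for "laplace" and "ggm" the filter scale, for "structure-tensor" the inner scale (the outer scale is half of it), and for "sato" the sigma of the smoothing applied after the sato filter.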
  • method: The method for finding edges. The following methods are supported:
    • "sobel": Edges are found by smoothing the data and then applying a sobel filter.
    • "laplace": Edges are found with a laplacian of gaussian filter.
    • "ggm": Edges are found with a gaussian gradient magnitude filter.
    • "structure-tensor": Edges are found based on the 1st eigenvalue of the structure tensor.
    • "sato": Edges are found with a sato-filter, followed by smoothing and leveling.
  • per_slice: Compute the filter per slice instead of for the whole volume.
  • n_threads: Number of threads for parallel computation over the slices.
Returns:

Edge filter response.

def check_filters(
    data: np.ndarray,
    filters: List[str] = FILTERS,
    sigmas: List[float] = (2.0, 4.0),
    show: bool = True,
) -> Dict[str, np.ndarray]:
    """Apply different edge filters to the input data.

    Every filter in `filters` is applied with every sigma in `sigmas`.

    Args:
        data: The input data volume.
        filters: The names of edge filters to apply.
            The filter names must match `method` in `edge_filter`.
        sigmas: The sigma values to use for the filters.
        show: Whether to show the filter responses in napari.

    Returns:
        Dictionary with the filter responses, keyed by "<filter>_<sigma>".
    """
    responses = {}
    # The context manager closes the progress bar, also if a filter raises.
    with tqdm(total=len(filters) * len(sigmas), desc="Compute filters") as pbar:
        for filter_ in filters:
            for sigma in sigmas:
                responses[f"{filter_}_{sigma}"] = edge_filter(data, sigma, method=filter_)
                pbar.update(1)

    if show:
        # napari is only needed for the visualization, so it is imported lazily.
        import napari

        viewer = napari.Viewer()
        viewer.add_image(data)
        for name, response in responses.items():
            viewer.add_image(response, name=name)
        napari.run()

    return responses

Apply different edge filters to the input data.

Arguments:
  • data: The input data volume.
  • filters: The names of edge filters to apply. The filter names must match method in edge_filter.
  • sigmas: The sigma values to use for the filters.
  • show: Whether to show the filter responses in napari.
Returns:

Dictionary with the filter responses.

def refine_vesicle_shapes(
    vesicles: np.ndarray,
    edge_map: np.ndarray,
    foreground_erosion: int = 2,
    background_erosion: int = 6,
    fit_to_outer_boundary: bool = False,
    return_seeds: bool = False,
    compactness: float = 1.0,
) -> np.ndarray:
    """Refine vesicle shapes by fitting vesicles to a boundary map.

    This function erodes the segmented vesicles, and then fits them
    to a boundary using a seeded watershed. This is done with two watersheds,
    one to separate foreground from background and one to separate vesicles within
    the foreground.

    Args:
        vesicles: The vesicle segmentation.
        edge_map: Volume with high intensities for vesicle membrane.
            You can use `edge_filter` to compute this based on the tomogram.
        foreground_erosion: By how many pixels the foreground should be eroded in the seeds.
            If 0, the foreground itself is used as seed.
        background_erosion: By how many pixels the background should be eroded in the seeds.
            If 0, the background itself is used as seed.
        fit_to_outer_boundary: Whether to fit the seeds to the outer membrane by
            applying a second edge filter to `edge_map`.
        return_seeds: Whether to return the seeds used for the watershed.
        compactness: The compactness parameter passed to the watershed function.
            Higher compactness leads to more regular sized vesicles.

    Returns:
        The refined vesicles. If `return_seeds` is True, a tuple of the refined vesicles
        and the seeds of the foreground / background watershed
        (1: foreground, 2: background, 0: no seed).
    """
    fg = vesicles != 0
    bg = vesicles == 0

    # scipy repeats an erosion with iterations < 1 until the result no longer changes,
    # which would remove the seeds entirely. An erosion of 0 thus uses the mask as is.
    fg_seeds = binary_erosion(fg, iterations=foreground_erosion) if foreground_erosion > 0 else fg
    bg_seeds = binary_erosion(bg, iterations=background_erosion) if background_erosion > 0 else bg

    # Add a 1-pixel wide boundary along all faces of the data to the background seeds
    # (works for any number of dimensions).
    boundary_mask = np.ones(bg.shape, dtype=bool)
    boundary_mask[(slice(1, -1),) * bg.ndim] = False
    bg_seeds = bg_seeds | boundary_mask

    # Seeds: 1 for foreground, 2 for background. The seeds can only overlap at the data
    # boundary if the foreground is not eroded; the foreground seed takes precedence there
    # instead of producing a spurious third label.
    seeds = np.zeros(vesicles.shape, dtype="uint8")
    seeds[bg_seeds] = 2
    seeds[fg_seeds] = 1

    outer_edge_map = edge_filter(edge_map, sigma=2) if fit_to_outer_boundary else edge_map

    # First watershed: separate the foreground from the background.
    refined_mask = watershed(outer_edge_map, seeds, compactness=compactness)
    refined_mask[refined_mask == 2] = 0

    # Second watershed: separate the individual vesicles within the foreground.
    refined_vesicles = watershed(edge_map, vesicles, mask=refined_mask, compactness=compactness)

    if return_seeds:
        return refined_vesicles, seeds
    return refined_vesicles

Refine vesicle shapes by fitting vesicles to a boundary map.

This function erodes the segmented vesicles, and then fits them to a boundary using a seeded watershed. This is done with two watersheds, one to separate foreground from background and one to separate vesicles within the foreground.

Arguments:
  • vesicles: The vesicle segmentation.
  • edge_map: Volume with high intensities for vesicle membrane. You can use edge_filter to compute this based on the tomogram.
  • foreground_erosion: By how many pixels the foreground should be eroded in the seeds.
  • background_erosion: By how many pixels the background should be eroded in the seeds.
  • fit_to_outer_boundary: Whether to fit the seeds to the outer membrane by applying a second edge filter to edge_map.
  • return_seeds: Whether to return the seeds used for the watershed.
  • compactness: The compactness parameter passed to the watershed function. Higher compactness leads to more regular sized vesicles.
Returns:

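The refined vesicles. If return_seeds is set, a tuple of the refined vesicles and the seeds used for the watershed is returned instead.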

def refine_individual_vesicle_shapes(
    vesicles: np.ndarray,
    edge_map: np.ndarray,
    foreground_erosion: int = 4,
    background_erosion: int = 8,
    halo: List[int] = (0, 12, 12),
    n_threads: int = 8,
) -> np.ndarray:
    """Refine vesicle shapes by fitting vesicles to a boundary map.

    This function erodes the segmented vesicles, and then fits them
    to a boundary using a seeded watershed. This is done individually for each vesicle
    and slice by slice along the first axis, to avoid effects caused by anisotropy.

    Args:
        vesicles: The vesicle segmentation. The bounding box handling and the per-slice
            refinement assume 3D data.
        edge_map: Volume with high intensities for vesicle membrane.
            You can use `edge_filter` to compute this based on the tomogram.
        foreground_erosion: By how many pixels the vesicle mask is eroded (per slice)
            to obtain the foreground seed. If 0, the mask itself is used.
        background_erosion: By how many pixels the vesicle mask is dilated (per slice);
            everything outside of the dilated mask is used as background seed.
            If 0, everything outside of the mask is used.
        halo: The margin added to the bounding box of each vesicle, per axis.
        n_threads: The number of threads used to refine the vesicles in parallel.

    Returns:
        The refined vesicles.
    """
    refined_vesicles = np.zeros_like(vesicles)

    def fit_vesicle(prop):
        # Crop the bounding box of the vesicle, enlarged by the halo.
        bb = tuple(
            slice(max(start - ha, 0), min(stop + ha, sh))
            for start, stop, ha, sh in zip(prop.bbox[:3], prop.bbox[3:], halo, vesicles.shape)
        )
        vesicle_sub = vesicles[bb]
        vesicle_mask = vesicle_sub == prop.label
        # All other vesicles in the local bbox are part of the background seed,
        # to avoid leaking into other vesicles.
        other_vesicles = (vesicle_sub != 0) & ~vesicle_mask
        hmap = edge_map[bb]

        # Do refinement in 2d to avoid effects caused by anisotropy.
        seg = np.zeros_like(vesicle_mask)
        for z in range(seg.shape[0]):
            m = vesicle_mask[z]
            # scipy repeats erosion / dilation with iterations < 1 until convergence,
            # so a value of 0 has to use the mask directly.
            fg_seed = binary_erosion(m, iterations=foreground_erosion) if foreground_erosion > 0 else m
            if not fg_seed.any():
                # Nothing left after the erosion: keep the original shape in this slice.
                seg[z] = m
                continue
            dilated = binary_dilation(m, iterations=background_erosion) if background_erosion > 0 else m
            bg_seed = ~dilated | other_vesicles[z]

            # Run seeded watershed to fit the shapes.
            seeds = fg_seed.astype("uint8") + 2 * bg_seed.astype("uint8")
            seg[z] = watershed(hmap[z], seeds) == 1

        return bb, seg

    props = regionprops(vesicles)
    with futures.ThreadPoolExecutor(n_threads) as tp:
        results = tp.map(fit_vesicle, props)
        # The results are written from this thread, in the order of `props`: the enlarged
        # bounding boxes of neighboring vesicles can overlap, and concurrent writes from the
        # worker threads would make the label in these overlaps nondeterministic.
        for prop, (bb, seg) in tqdm(zip(props, results), total=len(props), desc="refine vesicles"):
            refined_vesicles[bb][seg] = prop.label

    return refined_vesicles

Refine vesicle shapes by fitting vesicles to a boundary map.

This function erodes the segmented vesicles, and then fits them to a boundary using a seeded watershed. This is done individually for each vesicle.

Arguments:
  • vesicles: The vesicle segmentation.
  • edge_map: Volume with high intensities for vesicle membrane. You can use edge_filter to compute this based on the tomogram.
  • foreground_erosion: By how many pixels the foreground should be eroded in the seeds.
  • background_erosion: By how many pixels the background should be eroded in the seeds.
Returns:

The refined vesicles.