synapse_net.ground_truth.shape_refinement

  1import multiprocessing as mp
  2from concurrent import futures
  3from functools import partial
  4from typing import List, Dict, Optional
  5
  6import numpy as np
  7
  8from scipy.ndimage import binary_erosion, binary_dilation
  9from skimage import img_as_ubyte
 10from skimage.filters import gaussian, rank, sato, sobel
 11from skimage.measure import regionprops
 12from skimage.morphology import disk
 13from skimage.segmentation import watershed
 14from tqdm import tqdm
 15
 16try:
 17    import vigra
 18except ImportError:
 19    vigra = None
 20
 21FILTERS = ("sobel", "laplace", "ggm", "structure-tensor", "sato")
 22
 23
 24def _sato_filter(raw, sigma, max_window=16):
 25    if raw.ndim != 2:
 26        raise NotImplementedError("The sato filter is only implemented for 2D data.")
 27    hmap = sato(raw)
 28    hmap = gaussian(hmap, sigma=sigma)
 29    hmap -= hmap.min()
 30    hmap /= hmap.max()
 31    hmap = rank.autolevel(img_as_ubyte(hmap), disk(max_window)).astype("float") / 255
 32    return hmap
 33
 34
 35def edge_filter(
 36    data: np.ndarray,
 37    sigma: float,
 38    method: str = "sato",
 39    per_slice: bool = True,
 40    n_threads: Optional[int] = None,
 41) -> np.ndarray:
 42    """Find edges in the image data.
 43
 44    Args:
 45        data: The input data.
 46        sigma: The smoothing factor applied before the edge filter.
 47        method: The method for finding edges. The following methods are supported:
 48            - "sobel": Edges are found by smoothing the data and then applying a sobel filter.
 49            - "laplace": Edges are found with a laplacian of gaussian filter.
 50            - "ggm": Edges are found with a gaussian gradient magnitude filter.
 51            - "structure-tensor": Edges are found based on the 2nd eigenvalue of the structure tensor.
 52            - "sato": Edges are found with a sato-filter, followed by smoothing and leveling.
 53        per_slice: Compute the filter per slice instead of for the whole volume.
 54        n_threads: Number of threads for parallel computation over the slices.
 55
 56    Returns:
 57        Edge filter response.
 58    """
 59    if method not in FILTERS:
 60        raise ValueError(f"Invalid edge filter method: {method}. Expect one of {FILTERS}.")
 61    if method in FILTERS[1:] and vigra is None:
 62        raise ValueError(f"Filter {method} requires vigra.")
 63
 64    if per_slice and data.ndim == 3:
 65        n_threads = mp.cpu_count() if n_threads is None else n_threads
 66        filter_func = partial(edge_filter, sigma=sigma, method=method, per_slice=False)
 67        with futures.ThreadPoolExecutor(n_threads) as tp:
 68            edge_map = list(tp.map(filter_func, data))
 69        edge_map = np.stack(edge_map)
 70        return edge_map
 71
 72    if method == "sobel":
 73        edge_map = gaussian(data, sigma=sigma)
 74        edge_map = sobel(edge_map)
 75    elif method == "laplace":
 76        edge_map = vigra.filters.laplacianOfGaussian(data.astype("float32"), sigma)
 77    elif method == "ggm":
 78        edge_map = vigra.filters.gaussianGradientMagnitude(data.astype("float32"), sigma)
 79    elif method == "structure-tensor":
 80        inner_scale, outer_scale = sigma, sigma * 0.5
 81        edge_map = vigra.filters.structureTensorEigenvalues(
 82            data.astype("float32"), innerScale=inner_scale, outerScale=outer_scale
 83        )[..., 1]
 84    elif method == "sato":
 85        edge_map = _sato_filter(data, sigma)
 86
 87    return edge_map
 88
 89
 90def check_filters(
 91    data: np.ndarray,
 92    filters: List[str] = FILTERS,
 93    sigmas: List[float] = [2.0, 4.0],
 94    show: bool = True,
 95) -> Dict[str, np.ndarray]:
 96    """Apply different edge filters to the input data.
 97
 98    Args:
 99        data: The input data volume.
100        filters: The names of edge filters to apply.
101            The filter names must match `method` in `edge_filter`.
102        sigmas: The sigma values to use for the filters.
103        show: Whether to show the filter responses in napari.
104
105    Returns:
106        Dictionary with the filter responses.
107    """
108
109    n_filters = len(filters) * len(sigmas)
110    pbar = tqdm(total=n_filters, desc="Compute filters")
111
112    responses = {}
113    for filter_ in filters:
114        for sigma in sigmas:
115            name = f"{filter_}_{sigma}"
116            responses[name] = edge_filter(data, sigma, method=filter_)
117            pbar.update(1)
118
119    if show:
120        import napari
121
122        v = napari.Viewer()
123        v.add_image(data)
124        for name, response in responses.items():
125            v.add_image(response, name=name)
126        napari.run()
127
128    return responses
129
130
def refine_vesicle_shapes(
    vesicles: np.ndarray,
    edge_map: np.ndarray,
    foreground_erosion: int = 2,
    background_erosion: int = 6,
    fit_to_outer_boundary: bool = False,
    return_seeds: bool = False,
    compactness: float = 1.0,
) -> np.ndarray:
    """Refine vesicle shapes by fitting vesicles to a boundary map.

    This function erodes the segmented vesicles, and then fits them
    to a boundary using a seeded watershed. This is done with two watersheds,
    one to separate foreground from background and one to separate vesicles within
    the foreground.

    Args:
        vesicles: The vesicle segmentation (label image, 0 is background).
        edge_map: Volume with high intensities for vesicle membrane.
            You can use `edge_filter` to compute this based on the tomogram.
        foreground_erosion: By how many pixels the foreground should be eroded in the seeds.
            No erosion is applied for values <= 0.
        background_erosion: By how many pixels the background should be eroded in the seeds.
            No erosion is applied for values <= 0.
        fit_to_outer_boundary: Whether to fit the seeds to the outer membrane by
            applying a second edge filter to `edge_map`.
        return_seeds: Whether to return the seeds used for the watershed.
        compactness: The compactness parameter passed to the watershed function.
            Higher compactness leads to more regular sized vesicles.

    Returns:
        The refined vesicles. If `return_seeds` is True, a tuple of the refined vesicles
        and the seeds (1: foreground seed, 2: background seed, 0: no seed).
    """
    # scipy's binary_erosion repeats until convergence for iterations < 1,
    # so only erode for positive values.
    fg = vesicles != 0
    if foreground_erosion > 0:
        fg_seeds = binary_erosion(fg, iterations=foreground_erosion).astype("uint8")
    else:
        fg_seeds = fg.astype("uint8")

    bg = ~fg
    if background_erosion > 0:
        bg_seeds = binary_erosion(bg, iterations=background_erosion).astype("uint8")
    else:
        bg_seeds = bg.astype("uint8")

    # Add a 1-pixel wide border along every face of the data to the background seeds,
    # so that the background always has seeds, without overriding foreground seeds.
    border = np.zeros(vesicles.shape, dtype=bool)
    for axis in range(vesicles.ndim):
        index = [slice(None)] * vesicles.ndim
        for position in (0, -1):
            index[axis] = position
            border[tuple(index)] = True
    bg_seeds[border & (fg_seeds == 0)] = 1

    if fit_to_outer_boundary:
        outer_edge_map = edge_filter(edge_map, sigma=2)
    else:
        outer_edge_map = edge_map

    # First watershed: separate foreground (1) from background (2).
    seeds = fg_seeds + 2 * bg_seeds
    refined_mask = watershed(outer_edge_map, seeds, compactness=compactness)
    refined_mask[refined_mask == 2] = 0

    # Second watershed: separate the individual vesicles within the refined foreground.
    refined_vesicles = watershed(edge_map, vesicles, mask=refined_mask, compactness=compactness)

    if return_seeds:
        return refined_vesicles, seeds
    return refined_vesicles
199
200
def refine_individual_vesicle_shapes(
    vesicles: np.ndarray,
    edge_map: np.ndarray,
    foreground_erosion: int = 4,
    background_erosion: int = 8,
    n_threads: int = 8,
) -> np.ndarray:
    """Refine vesicle shapes by fitting vesicles to a boundary map.

    This function erodes the segmented vesicles, and then fits them
    to a boundary using a seeded watershed. This is done individually for each vesicle,
    slice by slice within the vesicle's padded bounding box.

    Args:
        vesicles: The vesicle segmentation, a 3D label image with 0 as background.
        edge_map: Volume with high intensities for vesicle membrane.
            You can use `edge_filter` to compute this based on the tomogram.
        foreground_erosion: By how many pixels the foreground should be eroded in the seeds.
            No erosion is applied for values <= 0.
        background_erosion: By how many pixels the vesicle is dilated before its complement
            is used as background seed. No dilation is applied for values <= 0.
        n_threads: Number of threads used to refine the vesicles in parallel.

    Returns:
        The refined vesicles.
    """
    refined_vesicles = np.zeros_like(vesicles)
    # Padding added around each bounding box along (z, y, x). There is no padding
    # along z because the fit is done independently for each z-slice.
    halo = (0, 12, 12)

    def fit_vesicle(prop):
        # Fit a single vesicle; returns (bounding box, label id, refined mask within the box).
        label_id = prop.label
        bb = tuple(
            slice(max(start - ha, 0), min(stop + ha, sh))
            for start, stop, ha, sh in zip(prop.bbox[:3], prop.bbox[3:], halo, vesicles.shape)
        )
        vesicle_sub = vesicles[bb]
        vesicle_mask = vesicle_sub == label_id
        # All other vesicles in the local bounding box become background seeds,
        # to avoid leaking into other vesicles.
        other_vesicles = (vesicle_sub != 0) & ~vesicle_mask
        hmap = edge_map[bb]

        # Do refinement in 2d to avoid effects caused by anisotropy.
        seg = np.zeros_like(vesicle_mask)
        for z, mask in enumerate(vesicle_mask):
            # scipy repeats erosion / dilation until convergence for iterations < 1,
            # so they are only applied for positive values.
            if foreground_erosion > 0:
                fg_seed = binary_erosion(mask, iterations=foreground_erosion)
            else:
                fg_seed = mask
            if not fg_seed.any():
                # Too small in this slice to get a foreground seed: keep the original mask.
                seg[z] = mask
                continue

            if background_erosion > 0:
                bg_seed = ~binary_dilation(mask, iterations=background_erosion)
            else:
                bg_seed = ~mask
            bg_seed |= other_vesicles[z]

            # Run seeded watershed to fit the shapes.
            seeds = fg_seed.astype("uint8") + 2 * bg_seed.astype("uint8")
            seg[z] = watershed(hmap[z], seeds) == 1

        return bb, label_id, seg

    props = regionprops(vesicles)
    with futures.ThreadPoolExecutor(n_threads) as tp:
        results = tp.map(fit_vesicle, props)
        # Write the results from this thread, in the order of `props`, instead of from the
        # workers: the padded bounding boxes of neighboring vesicles overlap, so concurrent
        # writes would make the result depend on thread scheduling.
        for bb, label_id, seg in tqdm(results, total=len(props), desc="refine vesicles"):
            refined_vesicles[bb][seg] = label_id

    return refined_vesicles
# Names of the edge filters supported by `edge_filter`.
FILTERS = (
    "sobel",
    "laplace",
    "ggm",
    "structure-tensor",
    "sato",
)
def edge_filter(
    data: np.ndarray,
    sigma: float,
    method: str = "sato",
    per_slice: bool = True,
    n_threads: Optional[int] = None,
) -> np.ndarray:
    """Find edges in the image data.

    Args:
        data: The input data.
        sigma: The smoothing factor applied before the edge filter.
        method: The method for finding edges. The following methods are supported:
            - "sobel": Edges are found by smoothing the data and then applying a sobel filter.
            - "laplace": Edges are found with a laplacian of gaussian filter (requires vigra).
            - "ggm": Edges are found with a gaussian gradient magnitude filter (requires vigra).
            - "structure-tensor": Edges are found based on the 2nd eigenvalue of the structure tensor
              (requires vigra).
            - "sato": Edges are found with a sato-filter, followed by smoothing and leveling.
              Only 2D inputs are supported, so 3D data must be processed with `per_slice=True`.
        per_slice: Compute the filter per slice instead of for the whole volume.
            Only has an effect for 3D data.
        n_threads: Number of threads for parallel computation over the slices.
            Defaults to the number of CPUs.

    Returns:
        Edge filter response.

    Raises:
        ValueError: If `method` is unknown, or if it requires vigra and vigra is not installed.
        NotImplementedError: If `method` is "sato" and the data to filter is not 2D.
    """
    if method not in FILTERS:
        raise ValueError(f"Invalid edge filter method: {method}. Expect one of {FILTERS}.")
    # Only these filters are computed with vigra; "sobel" and "sato" use scikit-image.
    if method in ("laplace", "ggm", "structure-tensor") and vigra is None:
        raise ValueError(f"Filter {method} requires vigra.")

    if per_slice and data.ndim == 3:
        n_threads = mp.cpu_count() if n_threads is None else n_threads
        filter_func = partial(edge_filter, sigma=sigma, method=method, per_slice=False)
        with futures.ThreadPoolExecutor(n_threads) as tp:
            edge_map = list(tp.map(filter_func, data))
        edge_map = np.stack(edge_map)
        return edge_map

    if method == "sobel":
        edge_map = gaussian(data, sigma=sigma)
        edge_map = sobel(edge_map)
    elif method == "laplace":
        edge_map = vigra.filters.laplacianOfGaussian(data.astype("float32"), sigma)
    elif method == "ggm":
        edge_map = vigra.filters.gaussianGradientMagnitude(data.astype("float32"), sigma)
    elif method == "structure-tensor":
        inner_scale, outer_scale = sigma, sigma * 0.5
        edge_map = vigra.filters.structureTensorEigenvalues(
            data.astype("float32"), innerScale=inner_scale, outerScale=outer_scale
        )[..., 1]
    elif method == "sato":
        edge_map = _sato_filter(data, sigma)

    return edge_map

Find edges in the image data.

Arguments:
  • data: The input data.
  • sigma: The smoothing factor applied before the edge filter.
  • method: The method for finding edges. The following methods are supported:
    • "sobel": Edges are found by smoothing the data and then applying a sobel filter.
    • "laplace": Edges are found with a laplacian of gaussian filter.
    • "ggm": Edges are found with a gaussian gradient magnitude filter.
    • "structure-tensor": Edges are found based on the 2nd eigenvalue of the structure tensor.
    • "sato": Edges are found with a sato-filter, followed by smoothing and leveling.
  • per_slice: Compute the filter per slice instead of for the whole volume.
  • n_threads: Number of threads for parallel computation over the slices.
Returns:

Edge filter response.

def check_filters(
    data: np.ndarray,
    filters: List[str] = FILTERS,
    sigmas: List[float] = (2.0, 4.0),
    show: bool = True,
) -> Dict[str, np.ndarray]:
    """Apply different edge filters to the input data.

    Args:
        data: The input data volume.
        filters: The names of edge filters to apply.
            The filter names must match `method` in `edge_filter`.
        sigmas: The sigma values to use for the filters.
        show: Whether to show the filter responses in napari.

    Returns:
        Dictionary with the filter responses, keyed by "<filter>_<sigma>".
    """
    responses = {}
    # Use the progress bar as a context manager so that it is closed after the last filter.
    with tqdm(total=len(filters) * len(sigmas), desc="Compute filters") as pbar:
        for filter_ in filters:
            for sigma in sigmas:
                responses[f"{filter_}_{sigma}"] = edge_filter(data, sigma, method=filter_)
                pbar.update(1)

    if show:
        # Imported lazily so that napari is only required for visualization.
        import napari

        viewer = napari.Viewer()
        viewer.add_image(data)
        for name, response in responses.items():
            viewer.add_image(response, name=name)
        napari.run()

    return responses

Apply different edge filters to the input data.

Arguments:
  • data: The input data volume.
  • filters: The names of edge filters to apply. The filter names must match method in edge_filter.
  • sigmas: The sigma values to use for the filters.
  • show: Whether to show the filter responses in napari.
Returns:

Dictionary with the filter responses.

def refine_vesicle_shapes(
    vesicles: np.ndarray,
    edge_map: np.ndarray,
    foreground_erosion: int = 2,
    background_erosion: int = 6,
    fit_to_outer_boundary: bool = False,
    return_seeds: bool = False,
    compactness: float = 1.0,
) -> np.ndarray:
    """Refine vesicle shapes by fitting vesicles to a boundary map.

    This function erodes the segmented vesicles, and then fits them
    to a boundary using a seeded watershed. This is done with two watersheds,
    one to separate foreground from background and one to separate vesicles within
    the foreground.

    Args:
        vesicles: The vesicle segmentation (label image, 0 is background).
        edge_map: Volume with high intensities for vesicle membrane.
            You can use `edge_filter` to compute this based on the tomogram.
        foreground_erosion: By how many pixels the foreground should be eroded in the seeds.
            No erosion is applied for values <= 0.
        background_erosion: By how many pixels the background should be eroded in the seeds.
            No erosion is applied for values <= 0.
        fit_to_outer_boundary: Whether to fit the seeds to the outer membrane by
            applying a second edge filter to `edge_map`.
        return_seeds: Whether to return the seeds used for the watershed.
        compactness: The compactness parameter passed to the watershed function.
            Higher compactness leads to more regular sized vesicles.

    Returns:
        The refined vesicles. If `return_seeds` is True, a tuple of the refined vesicles
        and the seeds (1: foreground seed, 2: background seed, 0: no seed).
    """
    # scipy's binary_erosion repeats until convergence for iterations < 1,
    # so only erode for positive values.
    fg = vesicles != 0
    if foreground_erosion > 0:
        fg_seeds = binary_erosion(fg, iterations=foreground_erosion).astype("uint8")
    else:
        fg_seeds = fg.astype("uint8")

    bg = ~fg
    if background_erosion > 0:
        bg_seeds = binary_erosion(bg, iterations=background_erosion).astype("uint8")
    else:
        bg_seeds = bg.astype("uint8")

    # Add a 1-pixel wide border along every face of the data to the background seeds,
    # so that the background always has seeds, without overriding foreground seeds.
    border = np.zeros(vesicles.shape, dtype=bool)
    for axis in range(vesicles.ndim):
        index = [slice(None)] * vesicles.ndim
        for position in (0, -1):
            index[axis] = position
            border[tuple(index)] = True
    bg_seeds[border & (fg_seeds == 0)] = 1

    if fit_to_outer_boundary:
        outer_edge_map = edge_filter(edge_map, sigma=2)
    else:
        outer_edge_map = edge_map

    # First watershed: separate foreground (1) from background (2).
    seeds = fg_seeds + 2 * bg_seeds
    refined_mask = watershed(outer_edge_map, seeds, compactness=compactness)
    refined_mask[refined_mask == 2] = 0

    # Second watershed: separate the individual vesicles within the refined foreground.
    refined_vesicles = watershed(edge_map, vesicles, mask=refined_mask, compactness=compactness)

    if return_seeds:
        return refined_vesicles, seeds
    return refined_vesicles

Refine vesicle shapes by fitting vesicles to a boundary map.

This function erodes the segmented vesicles, and then fits them to a boundary using a seeded watershed. This is done with two watersheds, one to separate foreground from background and one to separate vesicles within the foreground.

Arguments:
  • vesicles: The vesicle segmentation.
  • edge_map: Volume with high intensities for vesicle membrane. You can use edge_filter to compute this based on the tomogram.
  • foreground_erosion: By how many pixels the foreground should be eroded in the seeds.
  • background_erosion: By how many pixels the background should be eroded in the seeds.
  • fit_to_outer_boundary: Whether to fit the seeds to the outer membrane by applying a second edge filter to edge_map.
  • return_seeds: Whether to return the seeds used for the watershed.
  • compactness: The compactness parameter passed to the watershed function. Higher compactness leads to more regular sized vesicles.
Returns:

The refined vesicles.

def refine_individual_vesicle_shapes(
    vesicles: np.ndarray,
    edge_map: np.ndarray,
    foreground_erosion: int = 4,
    background_erosion: int = 8,
    n_threads: int = 8,
) -> np.ndarray:
    """Refine vesicle shapes by fitting vesicles to a boundary map.

    This function erodes the segmented vesicles, and then fits them
    to a boundary using a seeded watershed. This is done individually for each vesicle,
    slice by slice within the vesicle's padded bounding box.

    Args:
        vesicles: The vesicle segmentation, a 3D label image with 0 as background.
        edge_map: Volume with high intensities for vesicle membrane.
            You can use `edge_filter` to compute this based on the tomogram.
        foreground_erosion: By how many pixels the foreground should be eroded in the seeds.
            No erosion is applied for values <= 0.
        background_erosion: By how many pixels the vesicle is dilated before its complement
            is used as background seed. No dilation is applied for values <= 0.
        n_threads: Number of threads used to refine the vesicles in parallel.

    Returns:
        The refined vesicles.
    """
    refined_vesicles = np.zeros_like(vesicles)
    # Padding added around each bounding box along (z, y, x). There is no padding
    # along z because the fit is done independently for each z-slice.
    halo = (0, 12, 12)

    def fit_vesicle(prop):
        # Fit a single vesicle; returns (bounding box, label id, refined mask within the box).
        label_id = prop.label
        bb = tuple(
            slice(max(start - ha, 0), min(stop + ha, sh))
            for start, stop, ha, sh in zip(prop.bbox[:3], prop.bbox[3:], halo, vesicles.shape)
        )
        vesicle_sub = vesicles[bb]
        vesicle_mask = vesicle_sub == label_id
        # All other vesicles in the local bounding box become background seeds,
        # to avoid leaking into other vesicles.
        other_vesicles = (vesicle_sub != 0) & ~vesicle_mask
        hmap = edge_map[bb]

        # Do refinement in 2d to avoid effects caused by anisotropy.
        seg = np.zeros_like(vesicle_mask)
        for z, mask in enumerate(vesicle_mask):
            # scipy repeats erosion / dilation until convergence for iterations < 1,
            # so they are only applied for positive values.
            if foreground_erosion > 0:
                fg_seed = binary_erosion(mask, iterations=foreground_erosion)
            else:
                fg_seed = mask
            if not fg_seed.any():
                # Too small in this slice to get a foreground seed: keep the original mask.
                seg[z] = mask
                continue

            if background_erosion > 0:
                bg_seed = ~binary_dilation(mask, iterations=background_erosion)
            else:
                bg_seed = ~mask
            bg_seed |= other_vesicles[z]

            # Run seeded watershed to fit the shapes.
            seeds = fg_seed.astype("uint8") + 2 * bg_seed.astype("uint8")
            seg[z] = watershed(hmap[z], seeds) == 1

        return bb, label_id, seg

    props = regionprops(vesicles)
    with futures.ThreadPoolExecutor(n_threads) as tp:
        results = tp.map(fit_vesicle, props)
        # Write the results from this thread, in the order of `props`, instead of from the
        # workers: the padded bounding boxes of neighboring vesicles overlap, so concurrent
        # writes would make the result depend on thread scheduling.
        for bb, label_id, seg in tqdm(results, total=len(props), desc="refine vesicles"):
            refined_vesicles[bb][seg] = label_id

    return refined_vesicles

Refine vesicle shapes by fitting vesicles to a boundary map.

This function erodes the segmented vesicles, and then fits them to a boundary using a seeded watershed. This is done individually for each vesicle.

Arguments:
  • vesicles: The vesicle segmentation.
  • edge_map: Volume with high intensities for vesicle membrane. You can use edge_filter to compute this based on the tomogram.
  • foreground_erosion: By how many pixels the foreground should be eroded in the seeds.
  • background_erosion: By how many pixels the background should be eroded in the seeds.
Returns:

The refined vesicles.