
  1import multiprocessing as mp
  2from concurrent import futures
  3from functools import partial
  4from typing import List, Dict, Optional
  6import numpy as np
  8from scipy.ndimage import binary_erosion, binary_dilation
  9from skimage import img_as_ubyte
 10from skimage.filters import gaussian, rank, sato, sobel
 11from skimage.measure import regionprops
 12from skimage.morphology import disk
 13from skimage.segmentation import watershed
 14from tqdm import tqdm
 17    import vigra
 18except ImportError:
 19    vigra = None
 21FILTERS = ("sobel", "laplace", "ggm", "structure-tensor", "sato")
 24def _sato_filter(raw, sigma, max_window=16):
 25    if raw.ndim != 2:
 26        raise NotImplementedError("The sato filter is only implemented for 2D data.")
 27    hmap = sato(raw)
 28    hmap = gaussian(hmap, sigma=sigma)
 29    hmap -= hmap.min()
 30    hmap /= hmap.max()
 31    hmap = rank.autolevel(img_as_ubyte(hmap), disk(max_window)).astype("float") / 255
 32    return hmap
 35def edge_filter(
 36    data: np.ndarray,
 37    sigma: float,
 38    method: str = "sato",
 39    per_slice: bool = True,
 40    n_threads: Optional[int] = None,
 41) -> np.ndarray:
 42    """Find edges in the image data.
 44    Args:
 45        data: The input data.
 46        sigma: The smoothing factor applied before the edge filter.
 47        method: The method for finding edges. The following methods are supported:
 48            - "sobel": Edges are found by smoothing the data and then applying a sobel filter.
 49            - "laplace": Edges are found with a laplacian of gaussian filter.
 50            - "ggm": Edges are found with a gaussian gradient magnitude filter.
 51            - "structure-tensor": Edges are found based on the 2nd eigenvalue of the structure tensor.
 52            - "sato": Edges are found with a sato-filter, followed by smoothing and leveling.
 53        per_slice: Compute the filter per slice instead of for the whole volume.
 54        n_threads: Number of threads for parallel computation over the slices.
 56    Returns:
 57        Edge filter response.
 58    """
 59    if method not in FILTERS:
 60        raise ValueError(f"Invalid edge filter method: {method}. Expect one of {FILTERS}.")
 61    if method in FILTERS[1:] and vigra is None:
 62        raise ValueError(f"Filter {method} requires vigra.")
 64    if per_slice and data.ndim == 3:
 65        n_threads = mp.cpu_count() if n_threads is None else n_threads
 66        filter_func = partial(edge_filter, sigma=sigma, method=method, per_slice=False)
 67        with futures.ThreadPoolExecutor(n_threads) as tp:
 68            edge_map = list(, data))
 69        edge_map = np.stack(edge_map)
 70        return edge_map
 72    if method == "sobel":
 73        edge_map = gaussian(data, sigma=sigma)
 74        edge_map = sobel(edge_map)
 75    elif method == "laplace":
 76        edge_map = vigra.filters.laplacianOfGaussian(data.astype("float32"), sigma)
 77    elif method == "ggm":
 78        edge_map = vigra.filters.gaussianGradientMagnitude(data.astype("float32"), sigma)
 79    elif method == "structure-tensor":
 80        inner_scale, outer_scale = sigma, sigma * 0.5
 81        edge_map = vigra.filters.structureTensorEigenvalues(
 82            data.astype("float32"), innerScale=inner_scale, outerScale=outer_scale
 83        )[..., 1]
 84    elif method == "sato":
 85        edge_map = _sato_filter(data, sigma)
 87    return edge_map
 90def check_filters(
 91    data: np.ndarray,
 92    filters: List[str] = FILTERS,
 93    sigmas: List[float] = [2.0, 4.0],
 94    show: bool = True,
 95) -> Dict[str, np.ndarray]:
 96    """Apply different edge filters to the input data.
 98    Args:
 99        data: The input data volume.
100        filters: The names of edge filters to apply.
101            The filter names must match `method` in `edge_filter`.
102        sigmas: The sigma values to use for the filters.
103        show: Whether to show the filter responses in napari.
105    Returns:
106        Dictionary with the filter responses.
107    """
109    n_filters = len(filters) * len(sigmas)
110    pbar = tqdm(total=n_filters, desc="Compute filters")
112    responses = {}
113    for filter_ in filters:
114        for sigma in sigmas:
115            name = f"{filter_}_{sigma}"
116            responses[name] = edge_filter(data, sigma, method=filter_)
117            pbar.update(1)
119    if show:
120        import napari
122        v = napari.Viewer()
123        v.add_image(data)
124        for name, response in responses.items():
125            v.add_image(response, name=name)
128    return responses
def refine_vesicle_shapes(
    vesicles: np.ndarray,
    edge_map: np.ndarray,
    foreground_erosion: int = 2,
    background_erosion: int = 6,
    fit_to_outer_boundary: bool = False,
    return_seeds: bool = False,
    compactness: float = 1.0,
) -> np.ndarray:
    """Refine vesicle shapes by fitting vesicles to a boundary map.

    This function erodes the segmented vesicles, and then fits them
    to a boundary using a seeded watershed. This is done with two watersheds,
    one to separate foreground from background and one to separate vesicles within
    the foreground.

    Args:
        vesicles: The vesicle segmentation (a 2D or 3D label image).
        edge_map: Volume with high intensities for vesicle membrane.
            You can use `edge_filter` to compute this based on the tomogram.
        foreground_erosion: By how many pixels the foreground should be eroded in the seeds.
            A value of 0 disables the erosion.
        background_erosion: By how many pixels the background should be eroded in the seeds.
            A value of 0 disables the erosion. The outer border of the data is
            always used as background seed.
        fit_to_outer_boundary: Whether to fit the seeds to the outer membrane by
            applying a second edge filter to `edge_map`.
        return_seeds: Whether to return the seeds used for the watershed.
        compactness: The compactness parameter passed to the watershed function.
            Higher compactness leads to more regular sized vesicles.

    Returns:
        The refined vesicles. If `return_seeds` is True, a tuple of the refined vesicles
        and the seeds of the foreground / background watershed
        (1 = foreground seed, 2 = background seed).
    """
    fg = vesicles != 0
    bg = ~fg

    # scipy repeats the erosion until convergence for iterations < 1, which would
    # remove the seeds entirely; treat 0 explicitly as "no erosion".
    fg_seeds = binary_erosion(fg, iterations=foreground_erosion) if foreground_erosion > 0 else fg
    bg_seeds = binary_erosion(bg, iterations=background_erosion) if background_erosion > 0 else bg

    # Add a 1-pixel wide frame along every face of the data to the background seeds,
    # without overwriting foreground seeds that touch the border.
    border = np.zeros(vesicles.shape, dtype=bool)
    for axis in range(vesicles.ndim):
        index = [slice(None)] * vesicles.ndim
        for position in (0, -1):
            index[axis] = position
            border[tuple(index)] = True
    bg_seeds = bg_seeds | (border & ~fg_seeds)

    if fit_to_outer_boundary:
        outer_edge_map = edge_filter(edge_map, sigma=2)
    else:
        outer_edge_map = edge_map

    # fg_seeds and bg_seeds are disjoint, so the seeds only take the values 0, 1 and 2.
    seeds = fg_seeds.astype("uint8") + 2 * bg_seeds.astype("uint8")
    # First watershed: separate the foreground (label 1) from the background (label 2).
    foreground = watershed(outer_edge_map, seeds, compactness=compactness) == 1
    # Second watershed: split the foreground into the individual vesicles.
    refined_vesicles = watershed(edge_map, vesicles, mask=foreground, compactness=compactness)

    if return_seeds:
        return refined_vesicles, seeds
    return refined_vesicles
def refine_individual_vesicle_shapes(
    vesicles: np.ndarray,
    edge_map: np.ndarray,
    foreground_erosion: int = 4,
    background_erosion: int = 8,
    n_threads: int = 8,
) -> np.ndarray:
    """Refine vesicle shapes by fitting vesicles to a boundary map.

    This function erodes the segmented vesicles, and then fits them
    to a boundary using a seeded watershed. This is done individually for each vesicle,
    in 2D for each slice along the first axis to avoid effects caused by anisotropy.

    Args:
        vesicles: The vesicle segmentation, a 3D label volume.
        edge_map: Volume with high intensities for vesicle membrane.
            You can use `edge_filter` to compute this based on the tomogram.
        foreground_erosion: By how many pixels the foreground should be eroded in the seeds.
            A value of 0 disables the erosion.
        background_erosion: By how many pixels the vesicle mask is dilated before its
            complement is used as background seed. A value of 0 disables the dilation.
        n_threads: Number of threads used to refine the vesicles in parallel.

    Returns:
        The refined vesicles.

    Raises:
        ValueError: If `vesicles` is not 3D.
    """
    if vesicles.ndim != 3:
        raise ValueError(f"Expected a 3D vesicle segmentation, got {vesicles.ndim}D data.")

    # Context added around each vesicle's bounding box. There is no halo along the
    # first axis, because the shapes are fitted independently per slice.
    halo = (0, 12, 12)

    def fit_vesicle(prop):
        """Fit one vesicle; return its (bounding box, refined mask, label id)."""
        label_id = prop.label
        bb = tuple(
            slice(max(start - ha, 0), min(stop + ha, sh))
            for start, stop, ha, sh in zip(prop.bbox[:3], prop.bbox[3:], halo, vesicles.shape)
        )
        vesicle_sub = vesicles[bb]
        hmap = edge_map[bb]
        vesicle_mask = vesicle_sub == label_id
        # All other vesicles in the local bbox are part of the background seed,
        # to avoid leaking into other vesicles.
        other_vesicles = (vesicle_sub != 0) & ~vesicle_mask

        seg = np.zeros_like(vesicle_mask)
        for z, mask in enumerate(vesicle_mask):
            # scipy repeats erosion / dilation until convergence for iterations < 1,
            # so a value of 0 has to be handled explicitly as "no change".
            if foreground_erosion > 0:
                fg_seed = binary_erosion(mask, iterations=foreground_erosion)
            else:
                fg_seed = mask
            if not fg_seed.any():
                # Nothing is left after erosion in this slice: keep the original mask.
                seg[z] = mask
                continue
            if background_erosion > 0:
                grown = binary_dilation(mask, iterations=background_erosion)
            else:
                grown = mask
            bg_seed = ~grown | other_vesicles[z]
            # Run seeded watershed to fit the shapes.
            seeds = fg_seed.astype("uint8") + 2 * bg_seed.astype("uint8")
            seg[z] = watershed(hmap[z], seeds) == 1
        return bb, seg, label_id

    props = regionprops(vesicles)
    refined_vesicles = np.zeros_like(vesicles)
    with futures.ThreadPoolExecutor(n_threads) as tp:
        fitted = tp.map(fit_vesicle, props)
        # Write the results from the main thread in label order: the halo-extended
        # boxes of neighboring vesicles overlap, so writing from the worker threads
        # would make the result depend on thread scheduling.
        for bb, seg, label_id in tqdm(fitted, total=len(props), desc="refine vesicles"):
            refined_vesicles[bb][seg] = label_id

    return refined_vesicles
