synapse_net.inference.postprocessing.ribbon

 1import numpy as np
 2
 3import elf.parallel as parallel
 4from skimage.measure import regionprops
 5
 6
 7def segment_ribbon(
 8    ribbon_prediction: np.array,
 9    vesicle_segmentation: np.array,
10    n_slices_exclude: int,
11    n_ribbons: int,
12    max_vesicle_distance: int = 20,
13):
14    """Derive ribbon segmentation from ribbon predictions by
15    filtering out ribbons that don't have sufficient associated vesicles.
16
17    Args:
18        ribbon_prediction: Binary prediction for ribbons in the tomogram.
19        vesicle_segmentation: The vesicle segmentation.
20        n_slices_exclude: The number of slices to exclude on the top / bottom
21            in order to avoid segmentation errors due to imaging artifacts in top and bottom.
22        n_ribbons: The number of ribbons in the tomogram.
23        max_vesicle_distance: The maximal distance in pixels to associate a vesicle with a ribbon.
24    """
25    assert ribbon_prediction.shape == vesicle_segmentation.shape
26
27    original_shape = ribbon_prediction.shape
28
29    # Cut away the exclude mask.
30    slice_mask = np.s_[n_slices_exclude:-n_slices_exclude]
31    ribbon_prediction = ribbon_prediction[slice_mask]
32    vesicle_segmentation = vesicle_segmentation[slice_mask]
33
34    block_shape = (48, 384, 384)
35    # Label the ribbon predictions.
36    ribbon_segmentation = parallel.label(ribbon_prediction, block_shape=block_shape)
37
38    # Compute the distance to ribbon and the corresponding index.
39    halo = 3 * (max_vesicle_distance + 1,)
40    ribbon_dist, ribbon_idx = parallel.distance_transform(
41        ribbon_prediction == 0, return_indices=True, halo=halo, block_shape=block_shape
42    )
43
44    # Count the number of vesicles associated with each foreground object in the ribbon prediction.
45    vesicle_counts = {}
46    props = regionprops(vesicle_segmentation)
47    for prop in props:
48        bb = prop.bbox
49        bb = np.s_[bb[0]:bb[3], bb[1]:bb[4], bb[2]:bb[5]]
50
51        vesicle_mask = vesicle_segmentation[bb] == prop.label
52        dist, idx = ribbon_dist[bb].copy(), ribbon_idx[(slice(None),) + bb]
53        dist[~vesicle_mask] = np.inf
54
55        min_dist_point = np.argmin(dist)
56        min_dist_point = np.unravel_index(min_dist_point, vesicle_mask.shape)
57        min_dist = dist[min_dist_point]
58
59        if min_dist > max_vesicle_distance:
60            continue
61
62        ribbon_coord = tuple(idx_[min_dist_point] for idx_ in idx)
63        ribbon_id = ribbon_segmentation[ribbon_coord]
64        if ribbon_id == 0:
65            continue
66
67        if ribbon_id in vesicle_counts:
68            vesicle_counts[ribbon_id] += 1
69        else:
70            vesicle_counts[ribbon_id] = 1
71
72    # Create the output segmentation for the full output shape,
73    # keeping only the ribbons with sufficient number of associated vesicles.
74    full_ribbon_segmentation = np.zeros(original_shape, dtype="uint8")
75
76    if vesicle_counts:
77        ids = np.array(list(vesicle_counts.keys()))
78        counts = np.array(list(vesicle_counts.values()))
79    else:
80        print("No vesicles were matched to a ribbon")
81        print("Skipping postprocessing and returning the initial input")
82        full_ribbon_segmentation[slice_mask] = ribbon_prediction
83        return full_ribbon_segmentation
84
85    ids = ids[np.argsort(counts)[::-1]]
86
87    for output_id, ribbon_id in enumerate(ids[:n_ribbons], 1):
88        full_ribbon_segmentation[slice_mask][ribbon_segmentation == ribbon_id] = output_id
89
90    return full_ribbon_segmentation
def segment_ribbon( ribbon_prediction: <built-in function array>, vesicle_segmentation: <built-in function array>, n_slices_exclude: int, n_ribbons: int, max_vesicle_distance: int = 20):
 8def segment_ribbon(
 9    ribbon_prediction: np.array,
10    vesicle_segmentation: np.array,
11    n_slices_exclude: int,
12    n_ribbons: int,
13    max_vesicle_distance: int = 20,
14):
15    """Derive ribbon segmentation from ribbon predictions by
16    filtering out ribbons that don't have sufficient associated vesicles.
17
18    Args:
19        ribbon_prediction: Binary prediction for ribbons in the tomogram.
20        vesicle_segmentation: The vesicle segmentation.
21        n_slices_exclude: The number of slices to exclude on the top / bottom
22            in order to avoid segmentation errors due to imaging artifacts in top and bottom.
23        n_ribbons: The number of ribbons in the tomogram.
24        max_vesicle_distance: The maximal distance in pixels to associate a vesicle with a ribbon.
25    """
26    assert ribbon_prediction.shape == vesicle_segmentation.shape
27
28    original_shape = ribbon_prediction.shape
29
30    # Cut away the exclude mask.
31    slice_mask = np.s_[n_slices_exclude:-n_slices_exclude]
32    ribbon_prediction = ribbon_prediction[slice_mask]
33    vesicle_segmentation = vesicle_segmentation[slice_mask]
34
35    block_shape = (48, 384, 384)
36    # Label the ribbon predictions.
37    ribbon_segmentation = parallel.label(ribbon_prediction, block_shape=block_shape)
38
39    # Compute the distance to ribbon and the corresponding index.
40    halo = 3 * (max_vesicle_distance + 1,)
41    ribbon_dist, ribbon_idx = parallel.distance_transform(
42        ribbon_prediction == 0, return_indices=True, halo=halo, block_shape=block_shape
43    )
44
45    # Count the number of vesicles associated with each foreground object in the ribbon prediction.
46    vesicle_counts = {}
47    props = regionprops(vesicle_segmentation)
48    for prop in props:
49        bb = prop.bbox
50        bb = np.s_[bb[0]:bb[3], bb[1]:bb[4], bb[2]:bb[5]]
51
52        vesicle_mask = vesicle_segmentation[bb] == prop.label
53        dist, idx = ribbon_dist[bb].copy(), ribbon_idx[(slice(None),) + bb]
54        dist[~vesicle_mask] = np.inf
55
56        min_dist_point = np.argmin(dist)
57        min_dist_point = np.unravel_index(min_dist_point, vesicle_mask.shape)
58        min_dist = dist[min_dist_point]
59
60        if min_dist > max_vesicle_distance:
61            continue
62
63        ribbon_coord = tuple(idx_[min_dist_point] for idx_ in idx)
64        ribbon_id = ribbon_segmentation[ribbon_coord]
65        if ribbon_id == 0:
66            continue
67
68        if ribbon_id in vesicle_counts:
69            vesicle_counts[ribbon_id] += 1
70        else:
71            vesicle_counts[ribbon_id] = 1
72
73    # Create the output segmentation for the full output shape,
74    # keeping only the ribbons with sufficient number of associated vesicles.
75    full_ribbon_segmentation = np.zeros(original_shape, dtype="uint8")
76
77    if vesicle_counts:
78        ids = np.array(list(vesicle_counts.keys()))
79        counts = np.array(list(vesicle_counts.values()))
80    else:
81        print("No vesicles were matched to a ribbon")
82        print("Skipping postprocessing and returning the initial input")
83        full_ribbon_segmentation[slice_mask] = ribbon_prediction
84        return full_ribbon_segmentation
85
86    ids = ids[np.argsort(counts)[::-1]]
87
88    for output_id, ribbon_id in enumerate(ids[:n_ribbons], 1):
89        full_ribbon_segmentation[slice_mask][ribbon_segmentation == ribbon_id] = output_id
90
91    return full_ribbon_segmentation

Derive a ribbon segmentation from ribbon predictions by keeping only the ribbons with the most associated vesicles (ribbons without any associated vesicle are discarded).

Arguments:
  • ribbon_prediction: Binary prediction for ribbons in the tomogram.
  • vesicle_segmentation: The vesicle segmentation.
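  • n_slices_exclude: The number of slices to exclude at the top and bottom of the volume, to avoid segmentation errors caused by imaging artifacts there.
  • n_ribbons: The maximal number of ribbons to keep, i.e. the expected number of ribbons in the tomogram.
  • max_vesicle_distance: The maximal distance in pixels to associate a vesicle with a ribbon.

Returns:
  The ribbon segmentation, with the same shape as the input; the excluded slices are zero. If no vesicle could be associated with any ribbon, the ribbon prediction itself is returned instead (restricted to the non-excluded slices).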