synapse_net.inference.postprocessing.presynaptic_density

 1import numpy as np
 2
 3from scipy.ndimage import distance_transform_edt
 4from skimage.measure import regionprops
 5
 6from elf.parallel import label
 7
 8
 9# TODO update for multiple ribbons and pds
def segment_presynaptic_density(
    presyn_prediction: np.ndarray,
    ribbon_segmentation: np.ndarray,
    n_slices_exclude: int,
    n_pds_per_ribbon: int = 1,
    max_distance_to_ribbon: int = 15,
):
    """Derive the presynaptic density (PD) segmentation from a prediction.

    The prediction is split into connected components and only components
    close to a ribbon are kept. Each kept component is written to the output
    with the id of the ribbon it was matched to.

    Args:
        presyn_prediction: Binary prediction for presynaptic densities in the tomogram.
        ribbon_segmentation: The ribbon segmentation. Must have the same shape
            as `presyn_prediction`; 0 is treated as background.
        n_slices_exclude: The number of slices to exclude on the top / bottom
            in order to avoid segmentation errors due to imaging artifacts in top and bottom.
            Pass 0 to use the full volume.
        n_pds_per_ribbon: The maximal number of PDs kept per ribbon when several
            PDs were matched to it. The PDs closest to the ribbon are kept.
        max_distance_to_ribbon: The maximal distance to associate a PD with a ribbon.

    Returns:
        A uint8 array with the shape of the inputs. Voxels of a kept PD carry
        the id of the matched ribbon, all other voxels are 0.

    Raises:
        ValueError: If `n_slices_exclude` is negative.
    """
    assert presyn_prediction.shape == ribbon_segmentation.shape
    if n_slices_exclude < 0:
        raise ValueError(f"n_slices_exclude must be >= 0, got {n_slices_exclude}")

    original_shape = ribbon_segmentation.shape

    # Cut away the excluded slices. A stop of `-0` would select nothing,
    # so an open-ended slice is used when no slices are excluded.
    stop = -n_slices_exclude if n_slices_exclude > 0 else None
    slice_mask = np.s_[n_slices_exclude:stop]
    presyn_prediction = presyn_prediction[slice_mask]
    ribbon_segmentation = ribbon_segmentation[slice_mask]

    # NOTE(review): ribbon ids above 255 do not fit into uint8; the dtype is
    # kept for backward compatibility with existing callers.
    full_presyn_segmentation = np.zeros(original_shape, dtype="uint8")

    # Without remaining slices or without any ribbon voxel there is nothing
    # to match against (the distance transform would have no target).
    if presyn_prediction.size == 0 or not np.any(ribbon_segmentation):
        print("No presynapse was found")
        return full_presyn_segmentation

    # Label the presyn predictions.
    presyn_segmentation = label(presyn_prediction, block_shape=(32, 256, 256))

    # Distance of every voxel to the closest ribbon voxel, and the
    # coordinates of that closest ribbon voxel.
    ribbon_dist, ribbon_idx = distance_transform_edt(ribbon_segmentation == 0, return_indices=True)

    # Associate presynaptic densities with ribbons:
    # ribbon id -> list of (presyn label, distance to the ribbon).
    ribbon_matches = {}
    ndim = presyn_segmentation.ndim
    for prop in regionprops(presyn_segmentation):
        bbox = prop.bbox
        bb = tuple(slice(start, end) for start, end in zip(bbox[:ndim], bbox[ndim:]))

        # Restrict the distances to the voxels of this component.
        presyn_mask = presyn_segmentation[bb] == prop.label
        dist = ribbon_dist[bb].copy()
        dist[~presyn_mask] = np.inf

        min_dist_point = np.unravel_index(np.argmin(dist), presyn_mask.shape)
        this_distance = dist[min_dist_point]
        if this_distance > max_distance_to_ribbon:
            continue

        # Look up which ribbon the closest ribbon voxel belongs to.
        idx = ribbon_idx[(slice(None),) + bb]
        ribbon_coord = tuple(idx_[min_dist_point] for idx_ in idx)
        ribbon_id = ribbon_segmentation[ribbon_coord]
        assert ribbon_id != 0

        ribbon_matches.setdefault(ribbon_id, []).append((prop.label, this_distance))

    # Write the matched presyns into the output (a view on the non-excluded
    # slices), keeping only the closest ones if a ribbon has several matches.
    output_view = full_presyn_segmentation[slice_mask]
    for ribbon_id, matches in ribbon_matches.items():
        presyn_ids = np.array([match[0] for match in matches])
        if len(matches) > 1:
            closest = np.argsort([match[1] for match in matches])[:n_pds_per_ribbon]
            presyn_ids = presyn_ids[closest]

        output_view[np.isin(presyn_segmentation, presyn_ids)] = ribbon_id

    if not full_presyn_segmentation.any():
        print("No presynapse was found")
    return full_presyn_segmentation
def segment_presynaptic_density(presyn_prediction: np.ndarray, ribbon_segmentation: np.ndarray, n_slices_exclude: int, n_pds_per_ribbon: int = 1, max_distance_to_ribbon: int = 15):
def segment_presynaptic_density(
    presyn_prediction: np.ndarray,
    ribbon_segmentation: np.ndarray,
    n_slices_exclude: int,
    n_pds_per_ribbon: int = 1,
    max_distance_to_ribbon: int = 15,
):
    """Derive the presynaptic density (PD) segmentation from a prediction.

    The prediction is split into connected components and only components
    close to a ribbon are kept. Each kept component is written to the output
    with the id of the ribbon it was matched to.

    Args:
        presyn_prediction: Binary prediction for presynaptic densities in the tomogram.
        ribbon_segmentation: The ribbon segmentation. Must have the same shape
            as `presyn_prediction`; 0 is treated as background.
        n_slices_exclude: The number of slices to exclude on the top / bottom
            in order to avoid segmentation errors due to imaging artifacts in top and bottom.
            Pass 0 to use the full volume.
        n_pds_per_ribbon: The maximal number of PDs kept per ribbon when several
            PDs were matched to it. The PDs closest to the ribbon are kept.
        max_distance_to_ribbon: The maximal distance to associate a PD with a ribbon.

    Returns:
        A uint8 array with the shape of the inputs. Voxels of a kept PD carry
        the id of the matched ribbon, all other voxels are 0.

    Raises:
        ValueError: If `n_slices_exclude` is negative.
    """
    assert presyn_prediction.shape == ribbon_segmentation.shape
    if n_slices_exclude < 0:
        raise ValueError(f"n_slices_exclude must be >= 0, got {n_slices_exclude}")

    original_shape = ribbon_segmentation.shape

    # Cut away the excluded slices. A stop of `-0` would select nothing,
    # so an open-ended slice is used when no slices are excluded.
    stop = -n_slices_exclude if n_slices_exclude > 0 else None
    slice_mask = np.s_[n_slices_exclude:stop]
    presyn_prediction = presyn_prediction[slice_mask]
    ribbon_segmentation = ribbon_segmentation[slice_mask]

    # NOTE(review): ribbon ids above 255 do not fit into uint8; the dtype is
    # kept for backward compatibility with existing callers.
    full_presyn_segmentation = np.zeros(original_shape, dtype="uint8")

    # Without remaining slices or without any ribbon voxel there is nothing
    # to match against (the distance transform would have no target).
    if presyn_prediction.size == 0 or not np.any(ribbon_segmentation):
        print("No presynapse was found")
        return full_presyn_segmentation

    # Label the presyn predictions.
    presyn_segmentation = label(presyn_prediction, block_shape=(32, 256, 256))

    # Distance of every voxel to the closest ribbon voxel, and the
    # coordinates of that closest ribbon voxel.
    ribbon_dist, ribbon_idx = distance_transform_edt(ribbon_segmentation == 0, return_indices=True)

    # Associate presynaptic densities with ribbons:
    # ribbon id -> list of (presyn label, distance to the ribbon).
    ribbon_matches = {}
    ndim = presyn_segmentation.ndim
    for prop in regionprops(presyn_segmentation):
        bbox = prop.bbox
        bb = tuple(slice(start, end) for start, end in zip(bbox[:ndim], bbox[ndim:]))

        # Restrict the distances to the voxels of this component.
        presyn_mask = presyn_segmentation[bb] == prop.label
        dist = ribbon_dist[bb].copy()
        dist[~presyn_mask] = np.inf

        min_dist_point = np.unravel_index(np.argmin(dist), presyn_mask.shape)
        this_distance = dist[min_dist_point]
        if this_distance > max_distance_to_ribbon:
            continue

        # Look up which ribbon the closest ribbon voxel belongs to.
        idx = ribbon_idx[(slice(None),) + bb]
        ribbon_coord = tuple(idx_[min_dist_point] for idx_ in idx)
        ribbon_id = ribbon_segmentation[ribbon_coord]
        assert ribbon_id != 0

        ribbon_matches.setdefault(ribbon_id, []).append((prop.label, this_distance))

    # Write the matched presyns into the output (a view on the non-excluded
    # slices), keeping only the closest ones if a ribbon has several matches.
    output_view = full_presyn_segmentation[slice_mask]
    for ribbon_id, matches in ribbon_matches.items():
        presyn_ids = np.array([match[0] for match in matches])
        if len(matches) > 1:
            closest = np.argsort([match[1] for match in matches])[:n_pds_per_ribbon]
            presyn_ids = presyn_ids[closest]

        output_view[np.isin(presyn_segmentation, presyn_ids)] = ribbon_id

    if not full_presyn_segmentation.any():
        print("No presynapse was found")
    return full_presyn_segmentation

Derive presynaptic density segmentation from predictions by only keeping a PD prediction close to the ribbon.

Arguments:
  • presyn_prediction: Binary prediction for presynaptic densities in the tomogram.
  • ribbon_segmentation: The ribbon segmentation.
  • n_slices_exclude: The number of slices to exclude on the top / bottom in order to avoid segmentation errors due to imaging artifacts in top and bottom.
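  • n_pds_per_ribbon: The maximal number of PDs kept per ribbon when several PDs are matched to it; the PDs closest to the ribbon are kept.
  • max_distance_to_ribbon: The maximal distance to associate a PD with a ribbon.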