synapse_net.inference.postprocessing.membranes

  1from typing import Optional
  2import numpy as np
  3
  4from scipy.ndimage import distance_transform_edt
  5from skimage.measure import regionprops
  6
  7from elf.parallel import label
  8from synapse_net.distance_measurements import compute_geodesic_distances
  9
 10
 11def segment_membrane_next_to_object(
 12    boundary_prediction: np.array,
 13    object_segmentation: np.array,
 14    n_slices_exclude: int,
 15    radius: int = 25,
 16    n_fragments: int = 1,
 17):
 18    """Derive boundary segmentation from boundary predictions by
 19    selecting large boundary fragment closest to the object.
 20
 21    Args:
 22        boundary_prediction: Binary prediction for boundaries in the tomogram.
 23        object_segmentation: The object segmentation.
 24        n_slices_exclude: The number of slices to exclude on the top / bottom
 25            in order to avoid segmentation errors due to imaging artifacts in top and bottom.
 26        radius: The radius for membrane fragments that are considered.
 27        n_fragments: The number of boundary fragments to keep.
 28    """
 29    assert boundary_prediction.shape == object_segmentation.shape
 30
 31    original_shape = boundary_prediction.shape
 32
 33    # Cut away the exclude mask.
 34    slice_mask = np.s_[n_slices_exclude:-n_slices_exclude]
 35    boundary_prediction = boundary_prediction[slice_mask]
 36    object_segmentation = object_segmentation[slice_mask]
 37
 38    # Label the boundary predictions.
 39    boundary_segmentation = label(boundary_prediction, block_shape=(32, 256, 256))
 40
 41    # Compute the distance to object and the corresponding index.
 42    object_dist = distance_transform_edt(object_segmentation == 0)
 43
 44    # Find the distances to the object and fragment size.
 45    ids = []
 46    distances = []
 47    sizes = []
 48
 49    props = regionprops(boundary_segmentation)
 50    for prop in props:
 51        bb = prop.bbox
 52        bb = np.s_[bb[0]:bb[3], bb[1]:bb[4], bb[2]:bb[5]]
 53
 54        label_id = prop.label
 55        boundary_mask = boundary_segmentation[bb] == label_id
 56        dist = object_dist[bb].copy()
 57        dist[~boundary_mask] = np.inf
 58
 59        min_dist = np.min(dist)
 60        size = prop.area
 61
 62        ids.append(prop.label)
 63        distances.append(min_dist)
 64        sizes.append(size)
 65
 66    ids, distances, sizes = np.array(ids), np.array(distances), np.array(sizes)
 67
 68    mask = distances < radius
 69    if mask.sum() > 0:
 70        ids, sizes = ids[mask], sizes[mask]
 71
 72    keep_ids = ids[np.argsort(sizes)[::-1][:n_fragments]]
 73
 74    # Create the output segmentation for the full output shape,
 75    # keeping only the boundary fragment closest to the PD.
 76    full_boundary_segmentation = np.zeros(original_shape, dtype="uint8")
 77    full_boundary_segmentation[slice_mask][np.isin(boundary_segmentation, keep_ids)] = 1
 78
 79    return full_boundary_segmentation
 80
 81
 82def segment_membrane_distance_based(
 83    boundary_prediction: np.array,
 84    reference_segmentation: np.array,
 85    n_slices_exclude: int,
 86    max_distance: float,
 87    resolution: Optional[float] = None,
 88):
 89    assert boundary_prediction.shape == reference_segmentation.shape
 90
 91    original_shape = boundary_prediction.shape
 92
 93    # Cut away the exclude mask.
 94    slice_mask = np.s_[n_slices_exclude:-n_slices_exclude]
 95    boundary_prediction = boundary_prediction[slice_mask]
 96    reference_segmentation = reference_segmentation[slice_mask]
 97
 98    # Get the unique objects in the reference segmentation.
 99    reference_ids = np.unique(reference_segmentation)
100    assert reference_ids[0] == 0
101    reference_ids = reference_ids[1:]
102
103    # Compute the boundary fragments close to the unique objects in the reference.
104    full_boundary_segmentation = np.zeros(original_shape, dtype="uint8")
105    for seg_id in reference_ids:
106
107        # First, we find the closest point on the membrane surface.
108        ref_dist = distance_transform_edt(reference_segmentation != seg_id)
109        ref_dist[boundary_prediction == 0] = np.inf
110        closest_membrane = np.argmin(ref_dist)
111        closest_point = np.unravel_index(closest_membrane, ref_dist.shape)
112
113        # Then we compute the geodesic distance to this point on the distance and threshold it.
114        boundary_segmentation = compute_geodesic_distances(
115            boundary_prediction, closest_point, resolution
116        ) < max_distance
117
118        # boundary_segmentation = np.logical_and(boundary_prediction > 0, pd_dist < max_distance)
119        full_boundary_segmentation[slice_mask][boundary_segmentation] = 1
120
121    return full_boundary_segmentation
def segment_membrane_next_to_object(boundary_prediction: np.array, object_segmentation: np.array, n_slices_exclude: int, radius: int = 25, n_fragments: int = 1):
def segment_membrane_next_to_object(
    boundary_prediction: np.ndarray,
    object_segmentation: np.ndarray,
    n_slices_exclude: int,
    radius: int = 25,
    n_fragments: int = 1,
):
    """Derive a boundary segmentation by keeping the largest boundary fragments close to an object.

    The boundary prediction is split into connected fragments. Fragments whose minimal
    Euclidean distance (in voxels) to the object is smaller than `radius` are candidates;
    if no fragment is that close, all fragments are candidates. The `n_fragments` largest
    candidates are kept.

    Args:
        boundary_prediction: Binary prediction for boundaries in the tomogram.
        object_segmentation: The object segmentation; all non-zero voxels count as object.
        n_slices_exclude: The number of slices along the first axis to exclude at the top
            and bottom, to avoid segmentation errors caused by imaging artifacts there.
            A value of 0 keeps all slices.
        radius: The maximal distance (in voxels) to the object for a fragment to be considered.
        n_fragments: The number of boundary fragments to keep.

    Returns:
        A uint8 array with the input shape that is 1 for voxels of the kept fragments
        and 0 elsewhere, including the excluded slices.
    """
    assert boundary_prediction.shape == object_segmentation.shape
    assert n_slices_exclude >= 0

    original_shape = boundary_prediction.shape

    # Cut away the excluded slices. The stop index is explicit because the
    # slice `n:-n` selects nothing for n == 0.
    slice_mask = np.s_[n_slices_exclude:original_shape[0] - n_slices_exclude]
    boundary_prediction = boundary_prediction[slice_mask]
    object_segmentation = object_segmentation[slice_mask]

    # Label the boundary predictions into connected fragments.
    boundary_segmentation = label(boundary_prediction, block_shape=(32, 256, 256))

    # Euclidean distance of every voxel to the closest object voxel.
    object_dist = distance_transform_edt(object_segmentation == 0)

    # Per fragment: its id, its size and its minimal distance to the object.
    # `prop.slice` is the fragment's bounding box and `prop.image` its mask inside
    # that box, so the minimum is taken over the fragment's own voxels only.
    props = regionprops(boundary_segmentation)
    ids = np.array([prop.label for prop in props], dtype="int64")
    sizes = np.array([prop.area for prop in props])
    distances = np.array([object_dist[prop.slice][prop.image].min() for prop in props])

    # Restrict to fragments within the radius, unless none of them is close enough.
    close = distances < radius
    if close.any():
        ids, sizes = ids[close], sizes[close]

    # Keep the largest fragments (sorted by size, descending).
    keep_ids = ids[np.argsort(sizes)[::-1][:n_fragments]]

    # Write the kept fragments into an output with the full (uncropped) shape;
    # the excluded slices stay 0.
    full_boundary_segmentation = np.zeros(original_shape, dtype="uint8")
    full_boundary_segmentation[slice_mask] = np.isin(boundary_segmentation, keep_ids)

    return full_boundary_segmentation

Derive a boundary segmentation from boundary predictions by keeping the largest boundary fragments close to the object.

Arguments:
  • boundary_prediction: Binary prediction for boundaries in the tomogram.
  • object_segmentation: The object segmentation.
  • n_slices_exclude: The number of slices to exclude at the top and bottom, to avoid segmentation errors caused by imaging artifacts there.
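  • radius: The maximal distance (in voxels) between a membrane fragment and the object for the fragment to be considered. If no fragment is that close, all fragments are considered.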
  • n_fragments: The number of boundary fragments to keep.
def segment_membrane_distance_based(boundary_prediction: np.array, reference_segmentation: np.array, n_slices_exclude: int, max_distance: float, resolution: Optional[float] = None):
 83def segment_membrane_distance_based(
 84    boundary_prediction: np.array,
 85    reference_segmentation: np.array,
 86    n_slices_exclude: int,
 87    max_distance: float,
 88    resolution: Optional[float] = None,
 89):
 90    assert boundary_prediction.shape == reference_segmentation.shape
 91
 92    original_shape = boundary_prediction.shape
 93
 94    # Cut away the exclude mask.
 95    slice_mask = np.s_[n_slices_exclude:-n_slices_exclude]
 96    boundary_prediction = boundary_prediction[slice_mask]
 97    reference_segmentation = reference_segmentation[slice_mask]
 98
 99    # Get the unique objects in the reference segmentation.
100    reference_ids = np.unique(reference_segmentation)
101    assert reference_ids[0] == 0
102    reference_ids = reference_ids[1:]
103
104    # Compute the boundary fragments close to the unique objects in the reference.
105    full_boundary_segmentation = np.zeros(original_shape, dtype="uint8")
106    for seg_id in reference_ids:
107
108        # First, we find the closest point on the membrane surface.
109        ref_dist = distance_transform_edt(reference_segmentation != seg_id)
110        ref_dist[boundary_prediction == 0] = np.inf
111        closest_membrane = np.argmin(ref_dist)
112        closest_point = np.unravel_index(closest_membrane, ref_dist.shape)
113
114        # Then we compute the geodesic distance to this point on the distance and threshold it.
115        boundary_segmentation = compute_geodesic_distances(
116            boundary_prediction, closest_point, resolution
117        ) < max_distance
118
119        # boundary_segmentation = np.logical_and(boundary_prediction > 0, pd_dist < max_distance)
120        full_boundary_segmentation[slice_mask][boundary_segmentation] = 1
121
122    return full_boundary_segmentation