synapse_net.inference.postprocessing.membranes
"""Post-processing that derives membrane (boundary) segmentations from boundary predictions."""
from typing import Optional

import numpy as np

from scipy.ndimage import distance_transform_edt
from skimage.measure import regionprops

from elf.parallel import label
from synapse_net.distance_measurements import compute_geodesic_distances


def segment_membrane_next_to_object(
    boundary_prediction: np.ndarray,
    object_segmentation: np.ndarray,
    n_slices_exclude: int,
    radius: int = 25,
    n_fragments: int = 1,
):
    """Derive boundary segmentation from boundary predictions by
    selecting the largest boundary fragment(s) close to the object.

    The boundary prediction is split into connected fragments. Fragments whose
    minimal distance to the object is smaller than `radius` are the candidates;
    if no fragment is that close, all fragments are candidates. The
    `n_fragments` largest candidates are kept.

    Args:
        boundary_prediction: Binary prediction for boundaries in the tomogram.
        object_segmentation: The object segmentation (0 is treated as background).
            Must have the same shape as `boundary_prediction`.
        n_slices_exclude: The number of slices to exclude at the top and bottom
            (along the first axis) to avoid segmentation errors caused by imaging
            artifacts there. A value of 0 keeps all slices.
        radius: Only fragments whose minimal distance to the object is smaller
            than this radius are considered (unless no fragment is that close).
        n_fragments: The number of boundary fragments to keep.

    Returns:
        A uint8 array with the shape of the input that is 1 for voxels of the kept
        boundary fragments and 0 elsewhere, including the excluded slices.
    """
    assert boundary_prediction.shape == object_segmentation.shape

    original_shape = boundary_prediction.shape

    # Cut away the excluded top / bottom slices. The stop index is computed
    # explicitly because `np.s_[n:-n]` selects nothing for n == 0.
    slice_mask = np.s_[n_slices_exclude:original_shape[0] - n_slices_exclude]
    boundary_prediction = boundary_prediction[slice_mask]
    object_segmentation = object_segmentation[slice_mask]

    # Label the boundary predictions into connected fragments.
    boundary_segmentation = label(boundary_prediction, block_shape=(32, 256, 256))

    # Distance of every voxel to the closest object voxel.
    object_dist = distance_transform_edt(object_segmentation == 0)

    # Find the minimal distance to the object and the size of each fragment.
    ids, distances, sizes = [], [], []
    for prop in regionprops(boundary_segmentation):
        z0, y0, x0, z1, y1, x1 = prop.bbox
        bb = np.s_[z0:z1, y0:y1, x0:x1]
        fragment_mask = boundary_segmentation[bb] == prop.label
        ids.append(prop.label)
        distances.append(object_dist[bb][fragment_mask].min())
        sizes.append(prop.area)

    ids, distances, sizes = np.array(ids), np.array(distances), np.array(sizes)

    # Restrict to fragments within `radius` of the object, if there are any.
    close = distances < radius
    if close.any():
        ids, sizes = ids[close], sizes[close]

    # Keep the `n_fragments` largest candidate fragments.
    keep_ids = ids[np.argsort(sizes)[::-1][:n_fragments]]

    # Write the kept fragments into an output of the full (uncropped) shape;
    # basic slicing returns a view, so this writes into the output array.
    full_boundary_segmentation = np.zeros(original_shape, dtype="uint8")
    full_boundary_segmentation[slice_mask] = np.isin(boundary_segmentation, keep_ids)

    return full_boundary_segmentation


def segment_membrane_distance_based(
    boundary_prediction: np.ndarray,
    reference_segmentation: np.ndarray,
    n_slices_exclude: int,
    max_distance: float,
    resolution: Optional[float] = None,
):
    """Derive boundary segmentation from boundary predictions by selecting the
    boundary voxels that are geodesically close to each reference object.

    For every object in the reference segmentation, the boundary voxel closest
    to it (Euclidean distance in voxels) is used as a seed; all boundary voxels
    whose geodesic distance to that seed, as computed by
    `compute_geodesic_distances`, is smaller than `max_distance` are kept.

    Args:
        boundary_prediction: Binary prediction for boundaries in the tomogram.
        reference_segmentation: The reference object segmentation. Label 0 must be
            the background. Must have the same shape as `boundary_prediction`.
        n_slices_exclude: The number of slices to exclude at the top and bottom
            (along the first axis). A value of 0 keeps all slices.
        max_distance: Threshold on the geodesic distance to the seed point.
        resolution: Passed on to `compute_geodesic_distances`; presumably the voxel
            size used to scale the geodesic distances — confirm there.

    Returns:
        A uint8 array with the shape of the input that is 1 for the selected
        boundary voxels and 0 elsewhere, including the excluded slices.
    """
    assert boundary_prediction.shape == reference_segmentation.shape

    original_shape = boundary_prediction.shape

    # Cut away the excluded top / bottom slices. The stop index is computed
    # explicitly because `np.s_[n:-n]` selects nothing for n == 0.
    slice_mask = np.s_[n_slices_exclude:original_shape[0] - n_slices_exclude]
    boundary_prediction = boundary_prediction[slice_mask]
    reference_segmentation = reference_segmentation[slice_mask]

    full_boundary_segmentation = np.zeros(original_shape, dtype="uint8")
    # View into the output restricted to the kept slices.
    output = full_boundary_segmentation[slice_mask]

    # Without any boundary voxel there is no valid seed point
    # (argmin below would return an arbitrary non-boundary voxel).
    if not np.any(boundary_prediction):
        return full_boundary_segmentation

    # Get the unique objects in the reference segmentation; 0 must be background.
    reference_ids = np.unique(reference_segmentation)
    assert reference_ids.size == 0 or reference_ids[0] == 0
    reference_ids = reference_ids[1:]

    for seg_id in reference_ids:
        # First, find the boundary voxel closest to this object.
        ref_dist = distance_transform_edt(reference_segmentation != seg_id)
        ref_dist[boundary_prediction == 0] = np.inf
        closest_point = np.unravel_index(np.argmin(ref_dist), ref_dist.shape)

        # Then compute the geodesic distance to this point on the membrane and threshold it.
        boundary_segmentation = compute_geodesic_distances(
            boundary_prediction, closest_point, resolution
        ) < max_distance
        output[boundary_segmentation] = 1

    return full_boundary_segmentation
def
segment_membrane_next_to_object( boundary_prediction: <built-in function array>, object_segmentation: <built-in function array>, n_slices_exclude: int, radius: int = 25, n_fragments: int = 1):
def segment_membrane_next_to_object(
    boundary_prediction: np.ndarray,
    object_segmentation: np.ndarray,
    n_slices_exclude: int,
    radius: int = 25,
    n_fragments: int = 1,
):
    """Derive boundary segmentation from boundary predictions by
    selecting the largest boundary fragment(s) close to the object.

    The boundary prediction is split into connected fragments. Fragments whose
    minimal distance to the object is smaller than `radius` are the candidates;
    if no fragment is that close, all fragments are candidates. The
    `n_fragments` largest candidates are kept.

    Args:
        boundary_prediction: Binary prediction for boundaries in the tomogram.
        object_segmentation: The object segmentation (0 is treated as background).
            Must have the same shape as `boundary_prediction`.
        n_slices_exclude: The number of slices to exclude at the top and bottom
            (along the first axis) to avoid segmentation errors caused by imaging
            artifacts there. A value of 0 keeps all slices.
        radius: Only fragments whose minimal distance to the object is smaller
            than this radius are considered (unless no fragment is that close).
        n_fragments: The number of boundary fragments to keep.

    Returns:
        A uint8 array with the shape of the input that is 1 for voxels of the kept
        boundary fragments and 0 elsewhere, including the excluded slices.
    """
    assert boundary_prediction.shape == object_segmentation.shape

    original_shape = boundary_prediction.shape

    # Cut away the excluded top / bottom slices. The stop index is computed
    # explicitly because `np.s_[n:-n]` selects nothing for n == 0.
    slice_mask = np.s_[n_slices_exclude:original_shape[0] - n_slices_exclude]
    boundary_prediction = boundary_prediction[slice_mask]
    object_segmentation = object_segmentation[slice_mask]

    # Label the boundary predictions into connected fragments.
    boundary_segmentation = label(boundary_prediction, block_shape=(32, 256, 256))

    # Distance of every voxel to the closest object voxel.
    object_dist = distance_transform_edt(object_segmentation == 0)

    # Find the minimal distance to the object and the size of each fragment.
    ids, distances, sizes = [], [], []
    for prop in regionprops(boundary_segmentation):
        z0, y0, x0, z1, y1, x1 = prop.bbox
        bb = np.s_[z0:z1, y0:y1, x0:x1]
        fragment_mask = boundary_segmentation[bb] == prop.label
        ids.append(prop.label)
        distances.append(object_dist[bb][fragment_mask].min())
        sizes.append(prop.area)

    ids, distances, sizes = np.array(ids), np.array(distances), np.array(sizes)

    # Restrict to fragments within `radius` of the object, if there are any.
    close = distances < radius
    if close.any():
        ids, sizes = ids[close], sizes[close]

    # Keep the `n_fragments` largest candidate fragments.
    keep_ids = ids[np.argsort(sizes)[::-1][:n_fragments]]

    # Write the kept fragments into an output of the full (uncropped) shape;
    # basic slicing returns a view, so this writes into the output array.
    full_boundary_segmentation = np.zeros(original_shape, dtype="uint8")
    full_boundary_segmentation[slice_mask] = np.isin(boundary_segmentation, keep_ids)

    return full_boundary_segmentation
Derive a boundary segmentation from boundary predictions by selecting the largest boundary fragment(s) close to the object.
Arguments:
- boundary_prediction: Binary prediction for boundaries in the tomogram.
- object_segmentation: The object segmentation.
- n_slices_exclude: The number of slices to exclude at the top and bottom, to avoid segmentation errors caused by imaging artifacts in these slices.
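- radius: Only membrane fragments whose minimal distance to the object is smaller than this radius are considered; if no fragment is that close, all fragments are considered.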
- n_fragments: The number of boundary fragments to keep.
def
segment_membrane_distance_based( boundary_prediction: <built-in function array>, reference_segmentation: <built-in function array>, n_slices_exclude: int, max_distance: float, resolution: Optional[float] = None):
def segment_membrane_distance_based(
    boundary_prediction: np.ndarray,
    reference_segmentation: np.ndarray,
    n_slices_exclude: int,
    max_distance: float,
    resolution: Optional[float] = None,
):
    """Derive boundary segmentation from boundary predictions by selecting the
    boundary voxels that are geodesically close to each reference object.

    For every object in the reference segmentation, the boundary voxel closest
    to it (Euclidean distance in voxels) is used as a seed; all boundary voxels
    whose geodesic distance to that seed, as computed by
    `compute_geodesic_distances`, is smaller than `max_distance` are kept.

    Args:
        boundary_prediction: Binary prediction for boundaries in the tomogram.
        reference_segmentation: The reference object segmentation. Label 0 must be
            the background. Must have the same shape as `boundary_prediction`.
        n_slices_exclude: The number of slices to exclude at the top and bottom
            (along the first axis). A value of 0 keeps all slices.
        max_distance: Threshold on the geodesic distance to the seed point.
        resolution: Passed on to `compute_geodesic_distances`; presumably the voxel
            size used to scale the geodesic distances — confirm there.

    Returns:
        A uint8 array with the shape of the input that is 1 for the selected
        boundary voxels and 0 elsewhere, including the excluded slices.
    """
    assert boundary_prediction.shape == reference_segmentation.shape

    original_shape = boundary_prediction.shape

    # Cut away the excluded top / bottom slices. The stop index is computed
    # explicitly because `np.s_[n:-n]` selects nothing for n == 0.
    slice_mask = np.s_[n_slices_exclude:original_shape[0] - n_slices_exclude]
    boundary_prediction = boundary_prediction[slice_mask]
    reference_segmentation = reference_segmentation[slice_mask]

    full_boundary_segmentation = np.zeros(original_shape, dtype="uint8")
    # View into the output restricted to the kept slices.
    output = full_boundary_segmentation[slice_mask]

    # Without any boundary voxel there is no valid seed point
    # (argmin below would return an arbitrary non-boundary voxel).
    if not np.any(boundary_prediction):
        return full_boundary_segmentation

    # Get the unique objects in the reference segmentation; 0 must be background.
    reference_ids = np.unique(reference_segmentation)
    assert reference_ids.size == 0 or reference_ids[0] == 0
    reference_ids = reference_ids[1:]

    for seg_id in reference_ids:
        # First, find the boundary voxel closest to this object.
        ref_dist = distance_transform_edt(reference_segmentation != seg_id)
        ref_dist[boundary_prediction == 0] = np.inf
        closest_point = np.unravel_index(np.argmin(ref_dist), ref_dist.shape)

        # Then compute the geodesic distance to this point on the membrane and threshold it.
        boundary_segmentation = compute_geodesic_distances(
            boundary_prediction, closest_point, resolution
        ) < max_distance
        output[boundary_segmentation] = 1

    return full_boundary_segmentation