synapse_net.inference.postprocessing.ribbon
import numpy as np

import elf.parallel as parallel
from skimage.measure import regionprops


def segment_ribbon(
    ribbon_prediction: np.ndarray,
    vesicle_segmentation: np.ndarray,
    n_slices_exclude: int,
    n_ribbons: int,
    max_vesicle_distance: int = 20,
) -> np.ndarray:
    """Derive a ribbon segmentation from a ribbon prediction by keeping the
    ribbon candidates with the most associated vesicles.

    The ribbon prediction is split into connected components. Each vesicle is
    associated with the component containing the ribbon voxel nearest to it,
    if that distance is at most ``max_vesicle_distance``. The ``n_ribbons``
    components with the most associated vesicles are kept and labeled
    1, 2, ... in order of decreasing vesicle count.

    Args:
        ribbon_prediction: Binary prediction for ribbons in the tomogram.
        vesicle_segmentation: The vesicle segmentation (same shape as ``ribbon_prediction``).
        n_slices_exclude: The number of slices along the first axis to exclude at the
            top and bottom, to avoid segmentation errors caused by imaging artifacts there.
            Zero means that no slices are excluded.
        n_ribbons: The maximal number of ribbons to keep.
        max_vesicle_distance: The maximal distance in pixels to associate a vesicle with a ribbon.

    Returns:
        A uint8 segmentation with the full input shape; the excluded slices are zero.
        If no vesicle is associated with any ribbon, the unlabeled ribbon prediction
        (within the non-excluded slices) is returned instead.

    Raises:
        ValueError: If ``n_slices_exclude`` is negative or leaves no slices to process.
    """
    assert ribbon_prediction.shape == vesicle_segmentation.shape

    original_shape = ribbon_prediction.shape
    n_slices = original_shape[0]
    if n_slices_exclude < 0:
        raise ValueError(f"n_slices_exclude must be non-negative, got {n_slices_exclude}")
    if 2 * n_slices_exclude >= n_slices:
        raise ValueError(
            f"Excluding {n_slices_exclude} slices at the top and bottom leaves no slices "
            f"of the {n_slices} slices in the volume."
        )

    # Cut away the excluded slices. An explicit stop index is required:
    # `n:-n` would be the empty slice `0:0` for n_slices_exclude == 0.
    slice_mask = np.s_[n_slices_exclude:n_slices - n_slices_exclude]
    ribbon_prediction = ribbon_prediction[slice_mask]
    vesicle_segmentation = vesicle_segmentation[slice_mask]

    block_shape = (48, 384, 384)
    # Label the connected components of the ribbon prediction.
    ribbon_segmentation = parallel.label(ribbon_prediction, block_shape=block_shape)

    # Compute the distance to the nearest ribbon voxel and that voxel's coordinates.
    # The halo is derived from max_vesicle_distance, presumably so that distances
    # up to that value are correct across block boundaries.
    halo = 3 * (max_vesicle_distance + 1,)
    ribbon_dist, ribbon_idx = parallel.distance_transform(
        ribbon_prediction == 0, return_indices=True, halo=halo, block_shape=block_shape
    )

    # Count the number of vesicles associated with each ribbon component.
    vesicle_counts = {}
    for prop in regionprops(vesicle_segmentation):
        bb = prop.bbox
        bb = np.s_[bb[0]:bb[3], bb[1]:bb[4], bb[2]:bb[5]]

        # Restrict the distances to the voxels belonging to this vesicle.
        vesicle_mask = vesicle_segmentation[bb] == prop.label
        dist, idx = ribbon_dist[bb].copy(), ribbon_idx[(slice(None),) + bb]
        dist[~vesicle_mask] = np.inf

        # The vesicle voxel closest to any ribbon voxel.
        min_dist_point = np.unravel_index(np.argmin(dist), vesicle_mask.shape)
        if dist[min_dist_point] > max_vesicle_distance:
            continue

        # Look up the nearest ribbon voxel for that point and its component id.
        ribbon_coord = tuple(idx_[min_dist_point] for idx_ in idx)
        ribbon_id = ribbon_segmentation[ribbon_coord]
        if ribbon_id == 0:
            continue

        vesicle_counts[ribbon_id] = vesicle_counts.get(ribbon_id, 0) + 1

    # Create the output segmentation for the full shape; excluded slices stay zero.
    full_ribbon_segmentation = np.zeros(original_shape, dtype="uint8")
    # Basic slicing returns a view, so writing to it fills the full output.
    cropped_output = full_ribbon_segmentation[slice_mask]

    if not vesicle_counts:
        print("No vesicles were matched to a ribbon")
        print("Skipping postprocessing and returning the initial input")
        cropped_output[:] = ribbon_prediction
        return full_ribbon_segmentation

    # Keep the n_ribbons components with the most associated vesicles,
    # labeled 1, 2, ... in order of decreasing vesicle count.
    ids = np.array(list(vesicle_counts.keys()))
    counts = np.array(list(vesicle_counts.values()))
    ids = ids[np.argsort(counts)[::-1]]
    for output_id, ribbon_id in enumerate(ids[:n_ribbons], 1):
        cropped_output[ribbon_segmentation == ribbon_id] = output_id

    return full_ribbon_segmentation
def
segment_ribbon( ribbon_prediction: <built-in function array>, vesicle_segmentation: <built-in function array>, n_slices_exclude: int, n_ribbons: int, max_vesicle_distance: int = 20):
def segment_ribbon(
    ribbon_prediction: np.ndarray,
    vesicle_segmentation: np.ndarray,
    n_slices_exclude: int,
    n_ribbons: int,
    max_vesicle_distance: int = 20,
) -> np.ndarray:
    """Derive a ribbon segmentation from a ribbon prediction by keeping the
    ribbon candidates with the most associated vesicles.

    The ribbon prediction is split into connected components. Each vesicle is
    associated with the component containing the ribbon voxel nearest to it,
    if that distance is at most ``max_vesicle_distance``. The ``n_ribbons``
    components with the most associated vesicles are kept and labeled
    1, 2, ... in order of decreasing vesicle count.

    Args:
        ribbon_prediction: Binary prediction for ribbons in the tomogram.
        vesicle_segmentation: The vesicle segmentation (same shape as ``ribbon_prediction``).
        n_slices_exclude: The number of slices along the first axis to exclude at the
            top and bottom, to avoid segmentation errors caused by imaging artifacts there.
            Zero means that no slices are excluded.
        n_ribbons: The maximal number of ribbons to keep.
        max_vesicle_distance: The maximal distance in pixels to associate a vesicle with a ribbon.

    Returns:
        A uint8 segmentation with the full input shape; the excluded slices are zero.
        If no vesicle is associated with any ribbon, the unlabeled ribbon prediction
        (within the non-excluded slices) is returned instead.

    Raises:
        ValueError: If ``n_slices_exclude`` is negative or leaves no slices to process.
    """
    assert ribbon_prediction.shape == vesicle_segmentation.shape

    original_shape = ribbon_prediction.shape
    n_slices = original_shape[0]
    if n_slices_exclude < 0:
        raise ValueError(f"n_slices_exclude must be non-negative, got {n_slices_exclude}")
    if 2 * n_slices_exclude >= n_slices:
        raise ValueError(
            f"Excluding {n_slices_exclude} slices at the top and bottom leaves no slices "
            f"of the {n_slices} slices in the volume."
        )

    # Cut away the excluded slices. An explicit stop index is required:
    # `n:-n` would be the empty slice `0:0` for n_slices_exclude == 0.
    slice_mask = np.s_[n_slices_exclude:n_slices - n_slices_exclude]
    ribbon_prediction = ribbon_prediction[slice_mask]
    vesicle_segmentation = vesicle_segmentation[slice_mask]

    block_shape = (48, 384, 384)
    # Label the connected components of the ribbon prediction.
    ribbon_segmentation = parallel.label(ribbon_prediction, block_shape=block_shape)

    # Compute the distance to the nearest ribbon voxel and that voxel's coordinates.
    # The halo is derived from max_vesicle_distance, presumably so that distances
    # up to that value are correct across block boundaries.
    halo = 3 * (max_vesicle_distance + 1,)
    ribbon_dist, ribbon_idx = parallel.distance_transform(
        ribbon_prediction == 0, return_indices=True, halo=halo, block_shape=block_shape
    )

    # Count the number of vesicles associated with each ribbon component.
    vesicle_counts = {}
    for prop in regionprops(vesicle_segmentation):
        bb = prop.bbox
        bb = np.s_[bb[0]:bb[3], bb[1]:bb[4], bb[2]:bb[5]]

        # Restrict the distances to the voxels belonging to this vesicle.
        vesicle_mask = vesicle_segmentation[bb] == prop.label
        dist, idx = ribbon_dist[bb].copy(), ribbon_idx[(slice(None),) + bb]
        dist[~vesicle_mask] = np.inf

        # The vesicle voxel closest to any ribbon voxel.
        min_dist_point = np.unravel_index(np.argmin(dist), vesicle_mask.shape)
        if dist[min_dist_point] > max_vesicle_distance:
            continue

        # Look up the nearest ribbon voxel for that point and its component id.
        ribbon_coord = tuple(idx_[min_dist_point] for idx_ in idx)
        ribbon_id = ribbon_segmentation[ribbon_coord]
        if ribbon_id == 0:
            continue

        vesicle_counts[ribbon_id] = vesicle_counts.get(ribbon_id, 0) + 1

    # Create the output segmentation for the full shape; excluded slices stay zero.
    full_ribbon_segmentation = np.zeros(original_shape, dtype="uint8")
    # Basic slicing returns a view, so writing to it fills the full output.
    cropped_output = full_ribbon_segmentation[slice_mask]

    if not vesicle_counts:
        print("No vesicles were matched to a ribbon")
        print("Skipping postprocessing and returning the initial input")
        cropped_output[:] = ribbon_prediction
        return full_ribbon_segmentation

    # Keep the n_ribbons components with the most associated vesicles,
    # labeled 1, 2, ... in order of decreasing vesicle count.
    ids = np.array(list(vesicle_counts.keys()))
    counts = np.array(list(vesicle_counts.values()))
    ids = ids[np.argsort(counts)[::-1]]
    for output_id, ribbon_id in enumerate(ids[:n_ribbons], 1):
        cropped_output[ribbon_segmentation == ribbon_id] = output_id

    return full_ribbon_segmentation
Derive a ribbon segmentation from ribbon predictions by keeping the ribbons with the most associated vesicles (at most n_ribbons of them) and discarding all others.
Arguments:
- ribbon_prediction: Binary prediction for ribbons in the tomogram.
- vesicle_segmentation: The vesicle segmentation.
- n_slices_exclude: The number of slices to exclude at the top and bottom of the volume, to avoid segmentation errors caused by imaging artifacts there.
- n_ribbons: The maximal number of ribbons to keep; the ribbons with the most associated vesicles are kept.
- max_vesicle_distance: The maximal distance in pixels to associate a vesicle with a ribbon.