synapse_net.inference.compartments
import time
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch

import bioimage_cpp as bic
import elf.segmentation as eseg
from elf.tracking.tracking_utils import compute_edges_from_overlap
from scipy.ndimage import distance_transform_edt, binary_closing
from skimage.measure import label, regionprops
from skimage.segmentation import watershed
from skimage.morphology import remove_small_holes, local_maxima

from synapse_net.inference.util import get_prediction, _Scaler, _postprocess_seg_3d


def _segment_compartments_2d(
    boundaries,
    boundary_threshold=0.4,  # Threshold for the boundary distance computation.
    large_seed_distance=30,  # The distance threshold for computing large seeds (= components).
    distances=None,  # Pre-computed distances to take into account z-context.
):
    """Segment compartments in a single 2D boundary prediction.

    Large seeds are the connected components of `distances > large_seed_distance`
    (components smaller than 50 pixels are dropped). Small seeds are local maxima of the
    in-plane distance map that do not overlap a large seed. A seeded watershed then runs
    on the boundaries plus the inverted, normalized distances. Only segments grown from
    large seeds are kept; each has small holes filled and a binary closing applied.

    Args:
        boundaries: The 2D boundary prediction.
        boundary_threshold: Pixels with a boundary prediction below this value form the
            foreground of the distance transform.
        large_seed_distance: The distance threshold for computing the large seeds.
        distances: Optional pre-computed distances (e.g. one slice of a 3D distance transform).
            If given, they are used for the large seeds and the height map, while the small
            seeds are still derived from in-plane distances.

    Returns:
        The 2D segmentation, with 0 as background.
    """
    # The in-plane distances are always needed for the small seeds: using pre-computed
    # (z-context) distances for the maxima would produce spurious seeds.
    distances_z = distance_transform_edt(boundaries < boundary_threshold).astype("float32")
    if distances is None:
        distances = distances_z

    # Find the large seeds as connected components in the distances > large_seed_distance.
    seeds = label(distances > large_seed_distance)

    # Remove large seeds that are too small.
    min_seed_area = 50
    ids, sizes = np.unique(seeds, return_counts=True)
    remove_ids = ids[sizes < min_seed_area]
    seeds[np.isin(seeds, remove_ids)] = 0

    # Compute the small seeds = local maxima of the in-plane distance map.
    small_seeds = label(local_maxima(distances_z, allow_borders=True))

    # We only keep small seeds that don't intersect with a large seed.
    props = regionprops(small_seeds, seeds)
    keep_seeds = [prop.label for prop in props if prop.max_intensity == 0]
    keep_mask = np.isin(small_seeds, keep_seeds)

    # Add the kept small seeds to the large seeds, offset so that the ids don't collide:
    # large seeds keep the ids 1..seed_offset, small seeds get ids above seed_offset.
    all_seeds = seeds.copy()
    seed_offset = seeds.max()
    all_seeds[keep_mask] = small_seeds[keep_mask] + seed_offset

    # Height map: boundaries plus the inverted distances normalized to [0, 1].
    # Guard against an all-zero distance map (e.g. a slice that is entirely boundary),
    # which would otherwise divide by zero and yield a NaN height map.
    max_distance = distances.max()
    if max_distance > 0:
        hmap = boundaries + (max_distance - distances) / max_distance
    else:
        hmap = boundaries
    raw_segmentation = watershed(hmap, markers=all_seeds)

    # Only keep the segments grown from large seeds and remove holes in their masks.
    segmentation = np.zeros_like(raw_segmentation)
    for prop in regionprops(raw_segmentation):
        if prop.label > seed_offset:
            continue

        # Get bounding box and mask.
        bb = tuple(slice(start, stop) for start, stop in zip(prop.bbox[:2], prop.bbox[2:]))
        mask = raw_segmentation[bb] == prop.label

        # Fill small holes and apply closing; the OR ensures closing only adds pixels.
        mask = remove_small_holes(mask, area_threshold=500)
        mask = np.logical_or(binary_closing(mask, iterations=4), mask)
        segmentation[bb][mask] = prop.label

    return segmentation


def _merge_segmentation_3d(seg_2d, beta=0.5, min_z_extent=10):
    """Merge per-slice 2D segments into 3D objects with a multicut on their overlap graph.

    Args:
        seg_2d: The stacked per-slice segmentation (z first), with ids unique across slices.
        beta: Passed to the multicut decomposition.
        min_z_extent: Objects whose bounding box spans fewer slices along the first axis
            are removed. Pass None or 0 to disable this filter.

    Returns:
        The merged 3D segmentation.
    """
    edges = compute_edges_from_overlap(seg_2d, verbose=False)

    if len(edges) == 0:
        # No overlaps between slices: nothing to merge, every 2D segment stays its own object.
        # (The multicut path below would fail here because the edge array would be 1-D.)
        segmentation = seg_2d.copy()
    else:
        uv_ids = np.array([[edge["source"], edge["target"]] for edge in edges])
        overlaps = np.array([edge["score"] for edge in edges])

        n_nodes = int(seg_2d.max() + 1)
        graph = bic.graph.UndirectedGraph(n_nodes)
        graph.insert_edges(uv_ids)

        costs = eseg.multicut.compute_edge_costs(1.0 - overlaps)
        # Set background weights to be maximally repulsive.
        bg_edges = (uv_ids == 0).any(axis=1)
        costs[bg_edges] = -8.0

        node_labels = eseg.multicut.multicut_decomposition(graph, costs, beta=beta)
        segmentation = node_labels[seg_2d]

    if min_z_extent is not None and min_z_extent > 0:
        # bbox is (z0, y0, x0, z1, y1, x1) for 3D label images.
        filter_ids = [
            prop.label for prop in regionprops(segmentation)
            if prop.bbox[3] - prop.bbox[0] < min_z_extent
        ]
        if filter_ids:
            segmentation[np.isin(segmentation, filter_ids)] = 0

    return segmentation


def _segment_compartments_3d(
    prediction,
    boundary_threshold=0.4,
    n_slices_exclude=0,
    min_z_extent=10,
):
    """Segment compartments in a 3D boundary prediction.

    Each slice along the first axis is segmented in 2D using the 3D distance transform
    for z-context. The slice segments are then merged into 3D objects and post-processed.

    Args:
        prediction: The 3D boundary prediction, with z as the first axis.
        boundary_threshold: Pixels with a boundary prediction below this value form the
            foreground of the distance transforms (both 3D and in-plane).
        n_slices_exclude: The number of slices at the start and at the end of the first
            axis that are left unsegmented.
        min_z_extent: Minimal extent along z for merged objects to be kept.

    Returns:
        The 3D segmentation.
    """
    distances = distance_transform_edt(prediction < boundary_threshold).astype("float32")
    seg_2d = np.zeros(prediction.shape, dtype="uint32")

    n_slices = seg_2d.shape[0]
    offset = 0
    # Parallelize?
    for z in range(n_slices):
        if z < n_slices_exclude or z >= n_slices - n_slices_exclude:
            continue
        # Forward the threshold so the in-plane distances match the 3D distances.
        seg_z = _segment_compartments_2d(
            prediction[z], boundary_threshold=boundary_threshold, distances=distances[z]
        )
        # Offset the ids so that they are unique across slices.
        seg_z[seg_z != 0] += offset
        offset = max(int(seg_z.max()), offset)
        seg_2d[z] = seg_z

    # Pass min_z_extent by keyword: positionally it would be bound to `beta`.
    seg = _merge_segmentation_3d(seg_2d, min_z_extent=min_z_extent)
    seg = _postprocess_seg_3d(seg)
    return seg


def segment_compartments(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    verbose: bool = True,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
    n_slices_exclude: int = 0,
    boundary_threshold: float = 0.4,
    **kwargs,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment synaptic compartments in an input volume.

    Args:
        input_volume: The input volume to segment (2D or 3D).
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        verbose: Whether to print timing information.
        return_predictions: Whether to return the boundary prediction (the first channel for
            multi-channel models) alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        mask: Currently not used by this function.
        n_slices_exclude: The number of slices at the start and at the end of the z-axis that
            are excluded from the segmentation (3D inputs only).
        boundary_threshold: Threshold that determines when the prediction of the network is
            foreground for the segmentation. TEM data needs a higher threshold than the default.
        kwargs: Additional keyword arguments; currently not used.

    Returns:
        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
        and the boundary prediction if return_predictions is True.

    Raises:
        ValueError: If the prediction does not have the same number of dimensions as the
            input, or exactly one more (a leading channel axis).
    """
    if verbose:
        print("Segmenting compartments in volume of shape", input_volume.shape)

    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    # Run prediction. Support models with a single or multiple channels,
    # assuming that the first channel is the boundary prediction.
    pred = get_prediction(input_volume, tiling=tiling, model_path=model_path, model=model, verbose=verbose)

    # Remove channel axis if necessary.
    if pred.ndim != input_volume.ndim:
        if pred.ndim != input_volume.ndim + 1:
            raise ValueError(
                f"Expected a prediction with {input_volume.ndim} or {input_volume.ndim + 1} "
                f"dimensions, got {pred.ndim}."
            )
        pred = pred[0]

    # Run the compartment segmentation.
    # We may want to expose some of the parameters here.
    t0 = time.time()
    if input_volume.ndim == 2:
        seg = _segment_compartments_2d(pred, boundary_threshold=boundary_threshold)
    else:
        seg = _segment_compartments_3d(pred, n_slices_exclude=n_slices_exclude, boundary_threshold=boundary_threshold)
    if verbose:
        print("Run segmentation in", time.time() - t0, "s")

    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        pred = scaler.rescale_output(pred, is_segmentation=False)
        return seg, pred
    return seg
def
segment_compartments( input_volume: numpy.ndarray, model_path: Optional[str] = None, model: Optional[torch.nn.modules.module.Module] = None, tiling: Optional[Dict[str, Dict[str, int]]] = None, verbose: bool = True, return_predictions: bool = False, scale: Optional[List[float]] = None, mask: Optional[numpy.ndarray] = None, n_slices_exclude: int = 0, boundary_threshold: float = 0.4, **kwargs) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]:
def segment_compartments(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    verbose: bool = True,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
    n_slices_exclude: int = 0,
    boundary_threshold: float = 0.4,
    **kwargs,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment synaptic compartments in an input volume.

    Args:
        input_volume: The input volume to segment (2D or 3D).
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        verbose: Whether to print timing information.
        return_predictions: Whether to return the boundary prediction (the first channel for
            multi-channel models) alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        mask: Currently not used by this function.
        n_slices_exclude: The number of slices at the start and at the end of the z-axis that
            are excluded from the segmentation (3D inputs only).
        boundary_threshold: Threshold that determines when the prediction of the network is
            foreground for the segmentation. TEM data needs a higher threshold than the default.
        kwargs: Additional keyword arguments; currently not used.

    Returns:
        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
        and the boundary prediction if return_predictions is True.

    Raises:
        ValueError: If the prediction does not have the same number of dimensions as the
            input, or exactly one more (a leading channel axis).
    """
    if verbose:
        print("Segmenting compartments in volume of shape", input_volume.shape)

    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    # Run prediction. Support models with a single or multiple channels,
    # assuming that the first channel is the boundary prediction.
    pred = get_prediction(input_volume, tiling=tiling, model_path=model_path, model=model, verbose=verbose)

    # Remove channel axis if necessary. Raise instead of assert so the check
    # survives `python -O`.
    if pred.ndim != input_volume.ndim:
        if pred.ndim != input_volume.ndim + 1:
            raise ValueError(
                f"Expected a prediction with {input_volume.ndim} or {input_volume.ndim + 1} "
                f"dimensions, got {pred.ndim}."
            )
        pred = pred[0]

    # Run the compartment segmentation.
    # We may want to expose some of the parameters here.
    t0 = time.time()
    if input_volume.ndim == 2:
        seg = _segment_compartments_2d(pred, boundary_threshold=boundary_threshold)
    else:
        seg = _segment_compartments_3d(pred, n_slices_exclude=n_slices_exclude, boundary_threshold=boundary_threshold)
    if verbose:
        print("Run segmentation in", time.time() - t0, "s")

    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        pred = scaler.rescale_output(pred, is_segmentation=False)
        return seg, pred
    return seg
Segment synaptic compartments in an input volume.
Arguments:
- input_volume: The input volume to segment.
- model_path: The path to the model checkpoint if `model` is not provided.
- model: Pre-loaded model. Either `model_path` or `model` is required.
- tiling: The tiling configuration for the prediction.
- verbose: Whether to print timing information.
- return_predictions: Whether to return the boundary prediction (the first channel for multi-channel models) alongside the segmentation.
- scale: The scale factor to use for rescaling the input volume before prediction.
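- n_slices_exclude: The number of slices at the start and at the end of the z-axis that are excluded from the segmentation (3D inputs only).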
- boundary_threshold: Threshold that determines when the prediction of the network is foreground for the segmentation. TEM data needs a higher threshold than the default.
Returns:
The segmentation mask as a numpy array, or a tuple containing the segmentation mask and the predictions if return_predictions is True.