synapse_net.inference.compartments
```python
import time
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import vigra
import torch

import elf.segmentation as eseg
import nifty
from elf.tracking.tracking_utils import compute_edges_from_overlap
from scipy.ndimage import distance_transform_edt, binary_closing
from skimage.measure import label, regionprops
from skimage.segmentation import watershed
from skimage.morphology import remove_small_holes

from synapse_net.inference.util import get_prediction, _Scaler, _postprocess_seg_3d


def _segment_compartments_2d(
    boundaries,
    boundary_threshold=0.4,  # Threshold for the boundary distance computation.
    large_seed_distance=30,  # The distance threshold for computing large seeds (= components).
    distances=None,  # Pre-computed distances to take into account z-context.
):
    # Compute the distances if they were not precomputed.
    if distances is None:
        distances = distance_transform_edt(boundaries < boundary_threshold).astype("float32")
        distances_z = distances
    else:
        # If the distances were pre-computed then compute them again in 2d.
        # This is needed for inserting small seeds from maxima, otherwise we will get spurious maxima.
        distances_z = distance_transform_edt(boundaries < boundary_threshold).astype("float32")

    # Find the large seeds as connected components in the distances > large_seed_distance.
    seeds = label(distances > large_seed_distance)

    # Remove large seeds that are too small.
    min_seed_area = 50
    ids, sizes = np.unique(seeds, return_counts=True)
    remove_ids = ids[sizes < min_seed_area]
    seeds[np.isin(seeds, remove_ids)] = 0

    # Compute the small seeds = local maxima of the in-plane distance map.
    small_seeds = vigra.analysis.localMaxima(distances_z, marker=np.nan, allowAtBorder=True, allowPlateaus=True)
    small_seeds = label(np.isnan(small_seeds))

    # We only keep small seeds that don't intersect with a large seed.
    props = regionprops(small_seeds, seeds)
    keep_seeds = [prop.label for prop in props if prop.max_intensity == 0]
    keep_mask = np.isin(small_seeds, keep_seeds)

    # Add up the small seeds we keep with the large seeds.
    all_seeds = seeds.copy()
    seed_offset = seeds.max()
    all_seeds[keep_mask] = (small_seeds[keep_mask] + seed_offset)

    # Run watershed to get the segmentation.
    hmap = boundaries + (distances.max() - distances) / distances.max()
    raw_segmentation = watershed(hmap, markers=all_seeds)

    # These are the large seed ids that we will keep.
    keep_ids = list(range(1, seed_offset + 1))

    # Iterate over the ids, only keep large seeds and remove holes in their respective masks.
    props = regionprops(raw_segmentation)
    segmentation = np.zeros_like(raw_segmentation)
    for prop in props:
        if prop.label not in keep_ids:
            continue

        # Get bounding box and mask.
        bb = tuple(slice(start, stop) for start, stop in zip(prop.bbox[:2], prop.bbox[2:]))
        mask = raw_segmentation[bb] == prop.label

        # Fill small holes and apply closing.
        mask = remove_small_holes(mask, area_threshold=500)
        mask = np.logical_or(binary_closing(mask, iterations=4), mask)
        segmentation[bb][mask] = prop.label

    return segmentation


def _merge_segmentation_3d(seg_2d, beta=0.5, min_z_extent=10):
    edges = compute_edges_from_overlap(seg_2d, verbose=False)

    uv_ids = np.array([[edge["source"], edge["target"]] for edge in edges])
    overlaps = np.array([edge["score"] for edge in edges])

    n_nodes = int(seg_2d.max() + 1)
    graph = nifty.graph.undirectedGraph(n_nodes)
    graph.insertEdges(uv_ids)

    costs = eseg.multicut.compute_edge_costs(1.0 - overlaps)
    # Set the background weights to be maximally repulsive.
    bg_edges = (uv_ids == 0).any(axis=1)
    costs[bg_edges] = -8.0

    node_labels = eseg.multicut.multicut_decomposition(graph, costs, beta=beta)
    segmentation = nifty.tools.take(node_labels, seg_2d)

    if min_z_extent is not None and min_z_extent > 0:
        props = regionprops(segmentation)
        filter_ids = []
        for prop in props:
            box = prop.bbox
            z_extent = box[3] - box[0]
            if z_extent < min_z_extent:
                filter_ids.append(prop.label)
        if filter_ids:
            segmentation[np.isin(segmentation, filter_ids)] = 0

    return segmentation


def _segment_compartments_3d(
    prediction,
    boundary_threshold=0.4,
    n_slices_exclude=0,
    min_z_extent=10,
):
    distances = distance_transform_edt(prediction < boundary_threshold).astype("float32")
    seg_2d = np.zeros(prediction.shape, dtype="uint32")

    offset = 0
    # Parallelize?
    for z in range(seg_2d.shape[0]):
        if z < n_slices_exclude or z >= seg_2d.shape[0] - n_slices_exclude:
            continue
        seg_z = _segment_compartments_2d(prediction[z], distances=distances[z])
        seg_z[seg_z != 0] += offset
        offset = max(int(seg_z.max()), offset)
        seg_2d[z] = seg_z

    # Pass min_z_extent as a keyword argument so that it is not mistaken for beta.
    seg = _merge_segmentation_3d(seg_2d, min_z_extent=min_z_extent)
    seg = _postprocess_seg_3d(seg)

    # import napari
    # v = napari.Viewer()
    # v.add_image(prediction)
    # v.add_image(distances)
    # v.add_labels(seg_2d)
    # v.add_labels(seg)
    # napari.run()

    return seg


def segment_compartments(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    verbose: bool = True,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
    n_slices_exclude: int = 0,
    boundary_threshold: float = 0.4,
    **kwargs,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment synaptic compartments in an input volume.

    Args:
        input_volume: The input volume to segment.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        verbose: Whether to print timing information.
        return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        n_slices_exclude: The number of slices at the start and end of the volume that are excluded from the segmentation.
        boundary_threshold: Threshold that determines when the prediction of the network is treated as foreground
            for the segmentation. A higher threshold than the default is typically needed for TEM data.

    Returns:
        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
        and the predictions if return_predictions is True.
    """
    if verbose:
        print("Segmenting compartments in volume of shape", input_volume.shape)

    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    # Run prediction. Support models with a single or multiple channels,
    # assuming that the first channel is the boundary prediction.
    pred = get_prediction(input_volume, tiling=tiling, model_path=model_path, model=model, verbose=verbose)

    # Remove channel axis if necessary.
    if pred.ndim != input_volume.ndim:
        assert pred.ndim == input_volume.ndim + 1
        pred = pred[0]

    # Run the compartment segmentation.
    # We may want to expose some of the parameters here.
    t0 = time.time()
    if input_volume.ndim == 2:
        seg = _segment_compartments_2d(pred, boundary_threshold=boundary_threshold)
    else:
        seg = _segment_compartments_3d(pred, n_slices_exclude=n_slices_exclude, boundary_threshold=boundary_threshold)
    if verbose:
        print("Run segmentation in", time.time() - t0, "s")

    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        pred = scaler.rescale_output(pred, is_segmentation=False)
        return seg, pred
    return seg
```
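The per-slice segmentation in `_segment_compartments_2d` combines two seed types before running a watershed on the boundary map: large seeds from connected components of the thresholded distance transform, and small seeds from in-plane distance maxima that do not overlap a large seed. Below is a minimal, self-contained sketch of that seeding idea on a synthetic boundary image; the toy data and the use of skimage's `peak_local_max` (instead of `vigra.analysis.localMaxima`) are illustrative simplifications and not part of this module.

```python
# Minimal sketch of the distance-based seeding + watershed idea used per slice.
import numpy as np
from scipy.ndimage import distance_transform_edt
from skimage.feature import peak_local_max
from skimage.measure import label
from skimage.segmentation import watershed

# Synthetic "boundary probability" map: two compartments separated by a membrane.
boundaries = np.zeros((128, 128), dtype="float32")
boundaries[:, 62:66] = 1.0  # vertical membrane splitting the image
boundaries[:4, :] = 1.0     # treat the image border as boundary
boundaries[-4:, :] = 1.0
boundaries[:, :4] = 1.0
boundaries[:, -4:] = 1.0

# Distance to the nearest boundary pixel (boundary_threshold = 0.4).
distances = distance_transform_edt(boundaries < 0.4).astype("float32")

# Large seeds: connected components that are far away from any boundary.
large_seeds = label(distances > 20)

# Small seeds: distance maxima that do not fall inside a large seed.
seeds = large_seeds.copy()
next_id = int(large_seeds.max()) + 1
for y, x in peak_local_max(distances, min_distance=10):
    if large_seeds[y, x] == 0:
        seeds[y, x] = next_id
        next_id += 1

# Watershed on a height map that is low inside compartments and high on boundaries.
hmap = boundaries + (distances.max() - distances) / distances.max()
segmentation = watershed(hmap, markers=seeds)
print("Number of segments (incl. border region):", len(np.unique(segmentation)))
```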
```python
def segment_compartments(
    input_volume: numpy.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.modules.module.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    verbose: bool = True,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[numpy.ndarray] = None,
    n_slices_exclude: int = 0,
    boundary_threshold: float = 0.4,
    **kwargs,
) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]:
```
Segment synaptic compartments in an input volume.
Arguments:
- input_volume: The input volume to segment.
- model_path: The path to the model checkpoint if `model` is not provided.
- model: Pre-loaded model. Either `model_path` or `model` is required.
- tiling: The tiling configuration for the prediction.
- verbose: Whether to print timing information.
- return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
- scale: The scale factor to use for rescaling the input volume before prediction.
- n_slices_exclude: The number of slices at the start and end of the volume that are excluded from the segmentation.
- boundary_threshold: Threshold that determines when the prediction of the network is treated as foreground for the segmentation. A higher threshold than the default is typically needed for TEM data.
Returns:
The segmentation mask as a numpy array, or a tuple containing the segmentation mask and the predictions if return_predictions is True.
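A possible call to this function, for a tomogram stored in an HDF5 file, is sketched below. Reading the data with h5py, the file name, the dataset key "raw", and the checkpoint path are assumptions for illustration; only `segment_compartments` and the parameters documented above come from this module.

```python
# Hypothetical usage sketch; paths and dataset key are placeholders.
import h5py
from synapse_net.inference.compartments import segment_compartments

with h5py.File("tomogram.h5", "r") as f:
    volume = f["raw"][:]

seg, pred = segment_compartments(
    volume,
    model_path="/path/to/compartment_checkpoint",  # or pass a pre-loaded `model` instead
    n_slices_exclude=4,          # leave the first and last 4 slices unsegmented
    boundary_threshold=0.9,      # e.g. a higher threshold for TEM data
    return_predictions=True,     # also return the network predictions
)
print(seg.shape, pred.shape)
```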