synapse_net.inference.compartments
import time
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import vigra
import torch

import elf.segmentation as eseg
import nifty
from elf.tracking.tracking_utils import compute_edges_from_overlap
from scipy.ndimage import distance_transform_edt, binary_closing
from skimage.measure import label, regionprops
from skimage.segmentation import watershed
from skimage.morphology import remove_small_holes

from synapse_net.inference.util import get_prediction, _Scaler, _postprocess_seg_3d


def _segment_compartments_2d(
    boundaries,
    boundary_threshold=0.4,  # Threshold for the boundary distance computation.
    large_seed_distance=30,  # The distance threshold for computing large seeds (= components).
    distances=None,  # Pre-computed distances to take into account z-context.
):
    """Segment compartments in a 2d boundary prediction with a seeded watershed.

    Large seeds are connected components of the distance map above `large_seed_distance`
    (components smaller than 50 pixels are dropped). Small seeds are local maxima of the
    in-plane distance map that do not touch a large seed. Only regions grown from large
    seeds are kept; their masks get small holes filled and a binary closing applied.

    Args:
        boundaries: The 2d boundary prediction.
        boundary_threshold: Pixels with a boundary value below this threshold are the
            foreground of the distance transform.
        large_seed_distance: Distance above which pixels become part of a large seed.
        distances: Optional pre-computed distances, e.g. one slice of a 3d distance transform.

    Returns:
        The 2d segmentation, with 0 as background.
    """
    # Compute the in-plane distances. If distances were pre-computed, they are used for the
    # large seeds and the height map. The small seeds always come from the in-plane distances,
    # otherwise we would get spurious maxima.
    distances_z = distance_transform_edt(boundaries < boundary_threshold).astype("float32")
    if distances is None:
        distances = distances_z

    # Find the large seeds as connected components in the distances > large_seed_distance
    # and remove the ones that are too small.
    min_seed_area = 50
    seeds = label(distances > large_seed_distance)
    ids, sizes = np.unique(seeds, return_counts=True)
    seeds[np.isin(seeds, ids[sizes < min_seed_area])] = 0

    # Ids up to seed_offset belong to large seeds; small seeds are shifted above it.
    # Only large-seed regions are kept, so without any large seed the result is empty.
    seed_offset = int(seeds.max())
    if seed_offset == 0:
        return np.zeros_like(seeds)

    # Compute the small seeds = local maxima of the in-plane distance map.
    small_seeds = vigra.analysis.localMaxima(distances_z, marker=np.nan, allowAtBorder=True, allowPlateaus=True)
    small_seeds = label(np.isnan(small_seeds))

    # We only keep small seeds that don't intersect with a large seed.
    props = regionprops(small_seeds, seeds)
    keep_seeds = [prop.label for prop in props if prop.max_intensity == 0]
    keep_mask = np.isin(small_seeds, keep_seeds)

    # Add the small seeds we keep to the large seeds, offset so that the ids don't clash.
    all_seeds = seeds.copy()
    all_seeds[keep_mask] = small_seeds[keep_mask] + seed_offset

    # Run watershed on the boundaries plus the inverted, normalized distances.
    # max_distance > large_seed_distance here, because at least one large seed exists.
    max_distance = distances.max()
    hmap = boundaries + (max_distance - distances) / max_distance
    raw_segmentation = watershed(hmap, markers=all_seeds)

    # Only keep the regions grown from large seeds and remove holes in their masks.
    segmentation = np.zeros_like(raw_segmentation)
    for prop in regionprops(raw_segmentation):
        if prop.label > seed_offset:
            continue

        # Get bounding box and mask.
        bb = tuple(slice(start, stop) for start, stop in zip(prop.bbox[:2], prop.bbox[2:]))
        mask = raw_segmentation[bb] == prop.label

        # Fill small holes and apply closing.
        mask = remove_small_holes(mask, area_threshold=500)
        mask = np.logical_or(binary_closing(mask, iterations=4), mask)
        segmentation[bb][mask] = prop.label

    return segmentation


def _merge_segmentation_3d(seg_2d, beta=0.5, min_z_extent=10):
    """Merge per-slice 2d segments into 3d objects via a multicut on their slice overlaps.

    Args:
        seg_2d: Stack of 2d segmentations whose ids are unique across slices.
        beta: Passed to `multicut_decomposition`.
        min_z_extent: Objects spanning fewer slices along the first axis are removed.
            None or 0 disables the filter.

    Returns:
        The merged 3d segmentation.
    """
    edges = list(compute_edges_from_overlap(seg_2d, verbose=False))

    if edges:
        uv_ids = np.array([[edge["source"], edge["target"]] for edge in edges])
        overlaps = np.array([edge["score"] for edge in edges])

        n_nodes = int(seg_2d.max() + 1)
        graph = nifty.graph.undirectedGraph(n_nodes)
        graph.insertEdges(uv_ids)

        costs = eseg.multicut.compute_edge_costs(1.0 - overlaps)
        # Set background weights to be maximally repulsive.
        bg_edges = (uv_ids == 0).any(axis=1)
        costs[bg_edges] = -8.0

        # NOTE(review): confirm that `multicut_decomposition` actually uses `beta`;
        # `compute_edge_costs` also accepts a `beta` argument.
        node_labels = eseg.multicut.multicut_decomposition(graph, costs, beta=beta)
        segmentation = nifty.tools.take(node_labels, seg_2d)
    else:
        # No overlaps between slices: nothing to merge. Keep the 2d segments as they are;
        # building the graph from an empty (1d) edge array would fail.
        segmentation = seg_2d.copy()

    if min_z_extent is not None and min_z_extent > 0:
        filter_ids = [
            prop.label for prop in regionprops(segmentation)
            if prop.bbox[3] - prop.bbox[0] < min_z_extent
        ]
        if filter_ids:
            segmentation[np.isin(segmentation, filter_ids)] = 0

    return segmentation


def _segment_compartments_3d(
    prediction,
    boundary_threshold=0.4,
    n_slices_exclude=0,
    min_z_extent=10,
):
    """Segment compartments in a 3d boundary prediction.

    Each slice along the first axis is segmented in 2d, using distances from a 3d distance
    transform for z-context. The slices are then merged into 3d objects.

    Args:
        prediction: The 3d boundary prediction.
        boundary_threshold: Threshold for the boundary distance computation.
        n_slices_exclude: Number of slices at the start and at the end of the first axis
            that are not segmented.
        min_z_extent: Minimal extent along the first axis for objects to be kept.

    Returns:
        The 3d segmentation.
    """
    distances = distance_transform_edt(prediction < boundary_threshold).astype("float32")
    seg_2d = np.zeros(prediction.shape, dtype="uint32")

    offset = 0
    # TODO: parallelize over slices?
    for z in range(seg_2d.shape[0]):
        if z < n_slices_exclude or z >= seg_2d.shape[0] - n_slices_exclude:
            continue
        seg_z = _segment_compartments_2d(
            prediction[z], boundary_threshold=boundary_threshold, distances=distances[z]
        )
        seg_z[seg_z != 0] += offset
        # Never decrease the offset: an empty slice would otherwise reset it to 0,
        # and the ids of later slices would clash with those of earlier slices.
        offset = max(offset, int(seg_z.max()))
        seg_2d[z] = seg_z

    # Pass min_z_extent by keyword; positionally it would be consumed as `beta`.
    seg = _merge_segmentation_3d(seg_2d, min_z_extent=min_z_extent)
    seg = _postprocess_seg_3d(seg)
    return seg


def segment_compartments(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    verbose: bool = True,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
    n_slices_exclude: int = 0,
    **kwargs,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment synaptic compartments in an input volume.

    Args:
        input_volume: The input volume to segment. 2d inputs are segmented directly;
            other inputs are segmented slice-by-slice along the first axis and then merged in 3d.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        verbose: Whether to print timing information.
        return_predictions: Whether to return the boundary prediction alongside the segmentation.
            For multi-channel models only the first channel is returned.
        scale: The scale factor to use for rescaling the input volume before prediction.
        mask: Currently not used by this function.
        n_slices_exclude: The number of slices at the start and at the end of the first axis
            that are not segmented. Only used for 3d inputs.
        kwargs: Currently not used by this function.

    Returns:
        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
        and the boundary prediction if return_predictions is True.

    Raises:
        ValueError: If the prediction has neither the input's number of dimensions nor one more.
    """
    if verbose:
        print("Segmenting compartments in volume of shape", input_volume.shape)

    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    # Run prediction. Support models with a single or multiple channels,
    # assuming that the first channel is the boundary prediction.
    pred = get_prediction(input_volume, tiling=tiling, model_path=model_path, model=model, verbose=verbose)

    # Remove the channel axis if necessary. Explicit check instead of assert, so it survives `python -O`.
    if pred.ndim != input_volume.ndim:
        if pred.ndim != input_volume.ndim + 1:
            raise ValueError(
                f"Expected a prediction with {input_volume.ndim} or {input_volume.ndim + 1} dimensions, "
                f"got {pred.ndim}."
            )
        pred = pred[0]

    # Run the compartment segmentation.
    # We may want to expose some of the parameters here.
    t0 = time.perf_counter()
    if input_volume.ndim == 2:
        seg = _segment_compartments_2d(pred)
    else:
        seg = _segment_compartments_3d(pred, n_slices_exclude=n_slices_exclude)
    if verbose:
        print("Run segmentation in", time.perf_counter() - t0, "s")

    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        pred = scaler.rescale_output(pred, is_segmentation=False)
        return seg, pred
    return seg
def segment_compartments(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    verbose: bool = True,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
    n_slices_exclude: int = 0,
    **kwargs,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment synaptic compartments in an input volume.

    Args:
        input_volume: The input volume to segment. 2d inputs are segmented directly;
            other inputs are segmented slice-by-slice along the first axis and then merged in 3d.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        verbose: Whether to print timing information.
        return_predictions: Whether to return the boundary prediction alongside the segmentation.
            For multi-channel models only the first channel is returned.
        scale: The scale factor to use for rescaling the input volume before prediction.
        mask: Currently not used by this function.
        n_slices_exclude: The number of slices at the start and at the end of the first axis
            that are not segmented. Only used for 3d inputs.
        kwargs: Currently not used by this function.

    Returns:
        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
        and the boundary prediction if return_predictions is True.

    Raises:
        ValueError: If the prediction has neither the input's number of dimensions nor one more.
    """
    if verbose:
        print("Segmenting compartments in volume of shape", input_volume.shape)

    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    # Run prediction. Support models with a single or multiple channels,
    # assuming that the first channel is the boundary prediction.
    pred = get_prediction(input_volume, tiling=tiling, model_path=model_path, model=model, verbose=verbose)

    # Remove the channel axis if necessary. Explicit check instead of assert, so it survives `python -O`.
    if pred.ndim != input_volume.ndim:
        if pred.ndim != input_volume.ndim + 1:
            raise ValueError(
                f"Expected a prediction with {input_volume.ndim} or {input_volume.ndim + 1} dimensions, "
                f"got {pred.ndim}."
            )
        pred = pred[0]

    # Run the compartment segmentation.
    # We may want to expose some of the parameters here.
    t0 = time.perf_counter()
    if input_volume.ndim == 2:
        seg = _segment_compartments_2d(pred)
    else:
        seg = _segment_compartments_3d(pred, n_slices_exclude=n_slices_exclude)
    if verbose:
        print("Run segmentation in", time.perf_counter() - t0, "s")

    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        pred = scaler.rescale_output(pred, is_segmentation=False)
        return seg, pred
    return seg
Segment synaptic compartments in an input volume.
Arguments:
- input_volume: The input volume to segment.
- model_path: The path to the model checkpoint if `model` is not provided.
- model: Pre-loaded model. Either `model_path` or `model` is required.
- tiling: The tiling configuration for the prediction.
- verbose: Whether to print timing information.
- return_predictions: Whether to return the boundary prediction (the first channel of the model output) alongside the segmentation.
- scale: The scale factor to use for rescaling the input volume before prediction.
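- n_slices_exclude: The number of slices at the start and at the end of the first axis that are not segmented. Only used for 3D inputs.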
Returns:
The segmentation mask as a numpy array, or a tuple containing the segmentation mask and the predictions if return_predictions is True.