synapse_net.inference.compartments

import time
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import vigra
import torch

import elf.segmentation as eseg
import nifty
from elf.tracking.tracking_utils import compute_edges_from_overlap
from scipy.ndimage import distance_transform_edt, binary_closing
from skimage.measure import label, regionprops
from skimage.segmentation import watershed
from skimage.morphology import remove_small_holes

from synapse_net.inference.util import get_prediction, _Scaler, _postprocess_seg_3d


def _segment_compartments_2d(
    boundaries,
    boundary_threshold=0.4,  # Threshold for the boundary distance computation.
    large_seed_distance=30,  # The distance threshold for computing large seeds (= components).
    distances=None,  # Pre-computed distances to take into account z-context.
):
    # Compute the distances if they were not precomputed.
    if distances is None:
        distances = distance_transform_edt(boundaries < boundary_threshold).astype("float32")
        distances_z = distances
    else:
        # If the distances were pre-computed then compute them again in 2d.
        # This is needed for inserting small seeds from maxima, otherwise we will get spurious maxima.
        distances_z = distance_transform_edt(boundaries < boundary_threshold).astype("float32")

    # Find the large seeds as connected components in the distances > large_seed_distance.
    seeds = label(distances > large_seed_distance)

    # Remove large seeds that are too small.
    min_seed_area = 50
    ids, sizes = np.unique(seeds, return_counts=True)
    remove_ids = ids[sizes < min_seed_area]
    seeds[np.isin(seeds, remove_ids)] = 0

    # Compute the small seeds = local maxima of the in-plane distance map.
    small_seeds = vigra.analysis.localMaxima(distances_z, marker=np.nan, allowAtBorder=True, allowPlateaus=True)
    small_seeds = label(np.isnan(small_seeds))

    # We only keep small seeds that don't intersect with a large seed.
    props = regionprops(small_seeds, seeds)
    keep_seeds = [prop.label for prop in props if prop.max_intensity == 0]
    keep_mask = np.isin(small_seeds, keep_seeds)

    # Combine the small seeds we keep with the large seeds.
    all_seeds = seeds.copy()
    seed_offset = seeds.max()
    all_seeds[keep_mask] = (small_seeds[keep_mask] + seed_offset)

    # Run watershed to get the segmentation.
    hmap = boundaries + (distances.max() - distances) / distances.max()
    raw_segmentation = watershed(hmap, markers=all_seeds)

    # These are the large seed ids that we will keep.
    keep_ids = list(range(1, seed_offset + 1))

    # Iterate over the ids, only keep large seeds and remove holes in their respective masks.
    props = regionprops(raw_segmentation)
    segmentation = np.zeros_like(raw_segmentation)
    for prop in props:
        if prop.label not in keep_ids:
            continue

        # Get bounding box and mask.
        bb = tuple(slice(start, stop) for start, stop in zip(prop.bbox[:2], prop.bbox[2:]))
        mask = raw_segmentation[bb] == prop.label

        # Fill small holes and apply closing.
        mask = remove_small_holes(mask, area_threshold=500)
        mask = np.logical_or(binary_closing(mask, iterations=4), mask)
        segmentation[bb][mask] = prop.label

    return segmentation

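# A minimal usage sketch for the 2D helper, assuming a single boundary probability map
# with values in [0, 1]; the array name `boundary_map` and the input file below are
# purely illustrative:
#
#     boundary_map = np.load("boundary_prediction_2d.npy").astype("float32")
#     seg_slice = _segment_compartments_2d(boundary_map, boundary_threshold=0.4, large_seed_distance=30)
#     print("Number of compartments:", len(np.unique(seg_slice)) - 1)
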
def _merge_segmentation_3d(seg_2d, beta=0.5, min_z_extent=10):
    edges = compute_edges_from_overlap(seg_2d, verbose=False)

    uv_ids = np.array([[edge["source"], edge["target"]] for edge in edges])
    overlaps = np.array([edge["score"] for edge in edges])

    n_nodes = int(seg_2d.max() + 1)
    graph = nifty.graph.undirectedGraph(n_nodes)
    graph.insertEdges(uv_ids)

    costs = eseg.multicut.compute_edge_costs(1.0 - overlaps)
    # Set background weights to be maximally repulsive.
    bg_edges = (uv_ids == 0).any(axis=1)
    costs[bg_edges] = -8.0

    node_labels = eseg.multicut.multicut_decomposition(graph, costs, beta=beta)
    segmentation = nifty.tools.take(node_labels, seg_2d)

    if min_z_extent is not None and min_z_extent > 0:
        props = regionprops(segmentation)
        filter_ids = []
        for prop in props:
            box = prop.bbox
            z_extent = box[3] - box[0]
            if z_extent < min_z_extent:
                filter_ids.append(prop.label)
        if filter_ids:
            segmentation[np.isin(segmentation, filter_ids)] = 0

    return segmentation

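# A minimal sketch of the slice-merging step on a toy label stack in which the same
# object carries a different 2D label in every slice; shapes and values are illustrative
# only, and min_z_extent is disabled because the toy stack has just three slices:
#
#     toy_stack = np.zeros((3, 8, 8), dtype="uint32")
#     toy_stack[0, 2:6, 2:6] = 1
#     toy_stack[1, 2:6, 2:6] = 2
#     toy_stack[2, 2:6, 2:6] = 3
#     merged = _merge_segmentation_3d(toy_stack, beta=0.5, min_z_extent=0)
#     # The three overlapping 2D labels should now share a single 3D id.
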
def _segment_compartments_3d(
    prediction,
    boundary_threshold=0.4,
    n_slices_exclude=0,
    min_z_extent=10,
):
    distances = distance_transform_edt(prediction < boundary_threshold).astype("float32")
    seg_2d = np.zeros(prediction.shape, dtype="uint32")

    offset = 0
    # TODO: this loop over the z-slices could be parallelized.
    for z in range(seg_2d.shape[0]):
        if z < n_slices_exclude or z >= seg_2d.shape[0] - n_slices_exclude:
            continue
        seg_z = _segment_compartments_2d(prediction[z], distances=distances[z])
        seg_z[seg_z != 0] += offset
        # Keep the running offset even if this slice did not produce any segments.
        offset = max(offset, int(seg_z.max()))
        seg_2d[z] = seg_z

    seg = _merge_segmentation_3d(seg_2d, min_z_extent=min_z_extent)
    seg = _postprocess_seg_3d(seg)

    # import napari
    # v = napari.Viewer()
    # v.add_image(prediction)
    # v.add_image(distances)
    # v.add_labels(seg_2d)
    # v.add_labels(seg)
    # napari.run()

    return seg

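# A minimal usage sketch for the 3D pipeline, assuming a (z, y, x) boundary prediction
# volume with values in [0, 1]; the array name `boundary_pred` and the input file are
# illustrative assumptions:
#
#     boundary_pred = np.load("boundary_prediction_3d.npy").astype("float32")
#     seg_volume = _segment_compartments_3d(
#         boundary_pred, boundary_threshold=0.4, n_slices_exclude=2, min_z_extent=10
#     )
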
def segment_compartments(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    verbose: bool = True,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
    n_slices_exclude: int = 0,
    **kwargs,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment synaptic compartments in an input volume.

    Args:
        input_volume: The input volume to segment.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        verbose: Whether to print timing information.
        return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        n_slices_exclude: The number of slices at the top and the bottom of the volume
            that are excluded from the per-slice segmentation.

    Returns:
        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
        and the predictions if return_predictions is True.
    """
    if verbose:
        print("Segmenting compartments in volume of shape", input_volume.shape)

    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    # Run prediction. Support models with a single or multiple channels,
    # assuming that the first channel is the boundary prediction.
    pred = get_prediction(input_volume, tiling=tiling, model_path=model_path, model=model, verbose=verbose)

    # Remove the channel axis if necessary.
    if pred.ndim != input_volume.ndim:
        assert pred.ndim == input_volume.ndim + 1
        pred = pred[0]

    # Run the compartment segmentation.
    # We may want to expose some of the parameters here.
    t0 = time.time()
    if input_volume.ndim == 2:
        seg = _segment_compartments_2d(pred)
    else:
        seg = _segment_compartments_3d(pred, n_slices_exclude=n_slices_exclude)
    if verbose:
        print("Run segmentation in", time.time() - t0, "s")

    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        pred = scaler.rescale_output(pred, is_segmentation=False)
        return seg, pred
    return seg
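
A minimal end-to-end sketch of calling the public function. The input file, checkpoint path, and tiling values below are assumptions for illustration, not part of the module; the tiling dictionary is assumed here to contain "tile" and "halo" entries and should be adapted to the available GPU memory.

import numpy as np
from synapse_net.inference.compartments import segment_compartments

# Hypothetical tomogram and model checkpoint locations.
volume = np.load("tomogram.npy")
tiling = {"tile": {"x": 512, "y": 512, "z": 32}, "halo": {"x": 64, "y": 64, "z": 8}}

seg, pred = segment_compartments(
    volume,
    model_path="compartment_model.pt",
    tiling=tiling,
    n_slices_exclude=2,
    return_predictions=True,
)

With return_predictions=True the boundary prediction is returned alongside the segmentation, both rescaled back to the resolution of the input volume.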