synapse_net.inference.compartments

  1import time
  2from typing import Dict, List, Optional, Tuple, Union
  3
  4import numpy as np
  5import torch
  6
  7import bioimage_cpp as bic
  8import elf.segmentation as eseg
  9from elf.tracking.tracking_utils import compute_edges_from_overlap
 10from scipy.ndimage import distance_transform_edt, binary_closing
 11from skimage.measure import label, regionprops
 12from skimage.segmentation import watershed
 13from skimage.morphology import remove_small_holes, local_maxima
 14
 15from synapse_net.inference.util import get_prediction, _Scaler, _postprocess_seg_3d
 16
 17
 18def _segment_compartments_2d(
 19    boundaries,
 20    boundary_threshold=0.4,  # Threshold for the boundary distance computation.
 21    large_seed_distance=30,  # The distance threshold for computing large seeds (= components).
 22    distances=None,  # Pre-computed distances to take into account z-context.
 23):
 24    # Compoute distances if already not precomputed.
 25    if distances is None:
 26        distances = distance_transform_edt(boundaries < boundary_threshold).astype("float32")
 27        distances_z = distances
 28    else:
 29        # If the distances were pre-computed then compute them again in 2d.
 30        # This is needed for inserting small seeds from maxima, otherwise we will get spurious maxima.
 31        distances_z = distance_transform_edt(boundaries < boundary_threshold).astype("float32")
 32
 33    # Find the large seeds as connected components in the distances > large_seed_distance.
 34    seeds = label(distances > large_seed_distance)
 35
 36    # Remove to small large seeds.
 37    min_seed_area = 50
 38    ids, sizes = np.unique(seeds, return_counts=True)
 39    remove_ids = ids[sizes < min_seed_area]
 40    seeds[np.isin(seeds, remove_ids)] = 0
 41
 42    # Compute the small seeds = local maxima of the in-plane distance map
 43    small_seeds = label(local_maxima(distances_z, allow_borders=True))
 44
 45    # We only keep small seeds that don't intersect with a large seed.
 46    props = regionprops(small_seeds, seeds)
 47    keep_seeds = [prop.label for prop in props if prop.max_intensity == 0]
 48    keep_mask = np.isin(small_seeds, keep_seeds)
 49
 50    # Add up the small seeds we keep with the large seeds.
 51    all_seeds = seeds.copy()
 52    seed_offset = seeds.max()
 53    all_seeds[keep_mask] = (small_seeds[keep_mask] + seed_offset)
 54
 55    # Run watershed to get the segmentation.
 56    hmap = boundaries + (distances.max() - distances) / distances.max()
 57    raw_segmentation = watershed(hmap, markers=all_seeds)
 58
 59    # Thee are the large seed ids that we will keep.
 60    keep_ids = list(range(1, seed_offset + 1))
 61
 62    # Iterate over the ids, only keep large seeds and remove holes in their respective masks.
 63    props = regionprops(raw_segmentation)
 64    segmentation = np.zeros_like(raw_segmentation)
 65    for prop in props:
 66        if prop.label not in keep_ids:
 67            continue
 68
 69        # Get bounding box and mask.
 70        bb = tuple(slice(start, stop) for start, stop in zip(prop.bbox[:2], prop.bbox[2:]))
 71        mask = raw_segmentation[bb] == prop.label
 72
 73        # Fill small holes and apply closing.
 74        mask = remove_small_holes(mask, area_threshold=500)
 75        mask = np.logical_or(binary_closing(mask, iterations=4), mask)
 76        segmentation[bb][mask] = prop.label
 77
 78    return segmentation
 79
 80
 81def _merge_segmentation_3d(seg_2d, beta=0.5, min_z_extent=10):
 82    edges = compute_edges_from_overlap(seg_2d, verbose=False)
 83
 84    uv_ids = np.array([[edge["source"], edge["target"]] for edge in edges])
 85    overlaps = np.array([edge["score"] for edge in edges])
 86
 87    n_nodes = int(seg_2d.max() + 1)
 88    graph = bic.graph.UndirectedGraph(n_nodes)
 89    graph.insert_edges(uv_ids)
 90
 91    costs = eseg.multicut.compute_edge_costs(1.0 - overlaps)
 92    # set background weights to be maximally repulsive
 93    bg_edges = (uv_ids == 0).any(axis=1)
 94    costs[bg_edges] = -8.0
 95
 96    node_labels = eseg.multicut.multicut_decomposition(graph, costs, beta=beta)
 97    segmentation = node_labels[seg_2d]
 98
 99    if min_z_extent is not None and min_z_extent > 0:
100        props = regionprops(segmentation)
101        filter_ids = []
102        for prop in props:
103            box = prop.bbox
104            z_extent = box[3] - box[0]
105            if z_extent < min_z_extent:
106                filter_ids.append(prop.label)
107        if filter_ids:
108            segmentation[np.isin(segmentation, filter_ids)] = 0
109
110    return segmentation
111
112
def _segment_compartments_3d(
    prediction,
    boundary_threshold=0.4,
    n_slices_exclude=0,
    min_z_extent=10,
):
    """Segment compartments in a 3D boundary prediction.

    Each slice is segmented in 2D, using distances computed in 3D for the large seeds.
    The per-slice results are then merged across slices and post-processed.

    Args:
        prediction: The 3D boundary prediction.
        boundary_threshold: Threshold below which a voxel is treated as non-boundary
            for the distance computation, both in 3D and in the per-slice segmentation.
        n_slices_exclude: Number of slices at the start and at the end of the volume
            (along the first axis) that are not segmented and stay empty.
        min_z_extent: Minimal extent in slices of an object after merging.

    Returns:
        The 3D segmentation.
    """
    distances = distance_transform_edt(prediction < boundary_threshold).astype("float32")
    seg_2d = np.zeros(prediction.shape, dtype="uint32")

    n_slices = seg_2d.shape[0]
    offset = 0
    # Parallelize?
    for z in range(n_slices):
        if z < n_slices_exclude or z >= n_slices - n_slices_exclude:
            continue
        # Forward the threshold so that the in-plane distances of the 2d segmentation
        # are computed consistently with the 3d distances.
        seg_z = _segment_compartments_2d(
            prediction[z], boundary_threshold=boundary_threshold, distances=distances[z]
        )
        # Shift the ids of this slice so that they are unique across slices.
        seg_z[seg_z != 0] += offset
        offset = max(int(seg_z.max()), offset)
        seg_2d[z] = seg_z

    # min_z_extent must be passed by keyword: the second positional parameter
    # of _merge_segmentation_3d is beta.
    seg = _merge_segmentation_3d(seg_2d, min_z_extent=min_z_extent)
    seg = _postprocess_seg_3d(seg)

    return seg
144
145
def segment_compartments(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    verbose: bool = True,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
    n_slices_exclude: int = 0,
    boundary_threshold: float = 0.4,
    **kwargs,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment synaptic compartments in an input volume.

    Args:
        input_volume: The input volume to segment. Can be 2D or 3D.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        verbose: Whether to print timing information.
        return_predictions: Whether to return the boundary prediction used for
            the segmentation alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        mask: Accepted as an argument, but currently not used by this function.
        n_slices_exclude: Number of slices at the start and at the end of the volume
            that are excluded from the segmentation. Only relevant for 3D inputs.
        boundary_threshold: Threshold that determines when the prediction of the network
            is foreground for the segmentation. Need higher threshold than default for TEM.
        kwargs: Additional keyword arguments. Accepted, but currently ignored.

    Returns:
        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
        and the predictions if return_predictions is True.
    """
    if verbose:
        print("Segmenting compartments in volume of shape", input_volume.shape)

    # The scaler takes care of running the prediction at a different scale
    # and of mapping the results back to the original one.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    # Predict the boundaries. The model may return one or several channels;
    # the first channel is assumed to hold the boundary prediction.
    pred = get_prediction(input_volume, tiling=tiling, model_path=model_path, model=model, verbose=verbose)

    # Drop the channel axis if the prediction has one.
    has_channel_axis = pred.ndim != input_volume.ndim
    if has_channel_axis:
        assert pred.ndim == input_volume.ndim + 1
        pred = pred[0]

    # Dispatch to the 2d or 3d compartment segmentation.
    # We may want to expose some of the parameters here.
    start_time = time.time()
    is_2d = input_volume.ndim == 2
    if is_2d:
        seg = _segment_compartments_2d(pred, boundary_threshold=boundary_threshold)
    else:
        seg = _segment_compartments_3d(pred, n_slices_exclude=n_slices_exclude, boundary_threshold=boundary_threshold)
    if verbose:
        print("Run segmentation in", time.time() - start_time, "s")

    seg = scaler.rescale_output(seg, is_segmentation=True)
    if not return_predictions:
        return seg

    pred = scaler.rescale_output(pred, is_segmentation=False)
    return seg, pred
def segment_compartments( input_volume: numpy.ndarray, model_path: Optional[str] = None, model: Optional[torch.nn.modules.module.Module] = None, tiling: Optional[Dict[str, Dict[str, int]]] = None, verbose: bool = True, return_predictions: bool = False, scale: Optional[List[float]] = None, mask: Optional[numpy.ndarray] = None, n_slices_exclude: int = 0, boundary_threshold: float = 0.4, **kwargs) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]:
def segment_compartments(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    verbose: bool = True,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
    n_slices_exclude: int = 0,
    boundary_threshold: float = 0.4,
    **kwargs,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment synaptic compartments in an input volume.

    Args:
        input_volume: The input volume to segment. Can be 2D or 3D.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        verbose: Whether to print timing information.
        return_predictions: Whether to return the boundary prediction used for
            the segmentation alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        mask: Accepted as an argument, but currently not used by this function.
        n_slices_exclude: Number of slices at the start and at the end of the volume
            that are excluded from the segmentation. Only relevant for 3D inputs.
        boundary_threshold: Threshold that determines when the prediction of the network
            is foreground for the segmentation. Need higher threshold than default for TEM.
        kwargs: Additional keyword arguments. Accepted, but currently ignored.

    Returns:
        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
        and the predictions if return_predictions is True.
    """
    if verbose:
        print("Segmenting compartments in volume of shape", input_volume.shape)

    # The scaler takes care of running the prediction at a different scale
    # and of mapping the results back to the original one.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    # Predict the boundaries. The model may return one or several channels;
    # the first channel is assumed to hold the boundary prediction.
    pred = get_prediction(input_volume, tiling=tiling, model_path=model_path, model=model, verbose=verbose)

    # Drop the channel axis if the prediction has one.
    has_channel_axis = pred.ndim != input_volume.ndim
    if has_channel_axis:
        assert pred.ndim == input_volume.ndim + 1
        pred = pred[0]

    # Dispatch to the 2d or 3d compartment segmentation.
    # We may want to expose some of the parameters here.
    start_time = time.time()
    is_2d = input_volume.ndim == 2
    if is_2d:
        seg = _segment_compartments_2d(pred, boundary_threshold=boundary_threshold)
    else:
        seg = _segment_compartments_3d(pred, n_slices_exclude=n_slices_exclude, boundary_threshold=boundary_threshold)
    if verbose:
        print("Run segmentation in", time.time() - start_time, "s")

    seg = scaler.rescale_output(seg, is_segmentation=True)
    if not return_predictions:
        return seg

    pred = scaler.rescale_output(pred, is_segmentation=False)
    return seg, pred

Segment synaptic compartments in an input volume.

Arguments:
  • input_volume: The input volume to segment.
  • model_path: The path to the model checkpoint if model is not provided.
  • model: Pre-loaded model. Either model_path or model is required.
  • tiling: The tiling configuration for the prediction.
  • verbose: Whether to print timing information.
  • return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
  • scale: The scale factor to use for rescaling the input volume before prediction.
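  • n_slices_exclude: The number of slices at the start and at the end of the volume that are excluded from the segmentation (3D inputs only).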
  • boundary_threshold: Threshold that determines when the prediction of the network is foreground for the segmentation. Need higher threshold than default for TEM.
Returns:

The segmentation mask as a numpy array, or a tuple containing the segmentation mask and the predictions if return_predictions is True.