synapse_net.inference.active_zone

  1import time
  2from typing import Dict, List, Optional, Tuple, Union
  3
  4import elf.parallel as parallel
  5import numpy as np
  6import torch
  7
  8from skimage.segmentation import find_boundaries
  9from synapse_net.inference.util import get_prediction, _Scaler
 10
 11
 12def find_intersection_boundary(segmented_AZ: np.ndarray, segmented_compartment: np.ndarray) -> np.ndarray:
 13    """Find the intersection of the boundaries of each objects in segmented_compartment with segmented_AZ.
 14
 15    Args:
 16        segmented_AZ: 3D array representing the active zone (AZ).
 17        segmented_compartment: 3D array representing the compartment, with multiple labels.
 18
 19    Returns:
 20        Array with the cumulative intersection of all boundaries of segmented_compartment labels with segmented_AZ.
 21    """
 22    # Step 0: Initialize an empty array to accumulate intersections
 23    cumulative_intersection = np.zeros_like(segmented_AZ, dtype=bool)
 24
 25    # Step 1: Loop through each unique label in segmented_compartment (excluding 0 if it represents background)
 26    labels = np.unique(segmented_compartment)
 27    labels = labels[labels != 0]  # Exclude background label (0) if necessary
 28
 29    for label in labels:
 30        # Step 2: Create a binary mask for the current label
 31        label_mask = (segmented_compartment == label)
 32
 33        # Step 3: Find the boundary of the current label's compartment
 34        boundary_compartment = find_boundaries(label_mask, mode='outer')
 35
 36        # Step 4: Find the intersection with the AZ for this label's boundary
 37        intersection = np.logical_and(boundary_compartment, segmented_AZ)
 38
 39        # Step 5: Accumulate intersections for each label
 40        cumulative_intersection = np.logical_or(cumulative_intersection, intersection)
 41
 42    return cumulative_intersection.astype(int)  # Convert boolean array to int (1 for intersecting points, 0 elsewhere)
 43
 44
 45def _run_segmentation(
 46    foreground, verbose, min_size,
 47    # blocking shapes for parallel computation
 48    block_shape=(128, 256, 256),
 49):
 50
 51    # get the segmentation via seeded watershed
 52    t0 = time.time()
 53    seg = parallel.label(foreground > 0.5, block_shape=block_shape, verbose=verbose)
 54    if verbose:
 55        print("Compute connected components in", time.time() - t0, "s")
 56
 57    # size filter
 58    t0 = time.time()
 59    ids, sizes = parallel.unique(seg, return_counts=True, block_shape=block_shape, verbose=verbose)
 60    filter_ids = ids[sizes < min_size]
 61    seg[np.isin(seg, filter_ids)] = 0
 62    if verbose:
 63        print("Size filter in", time.time() - t0, "s")
 64    seg = np.where(seg > 0, 1, 0)
 65    return seg
 66
 67
 68def segment_active_zone(
 69    input_volume: np.ndarray,
 70    model_path: Optional[str] = None,
 71    model: Optional[torch.nn.Module] = None,
 72    tiling: Optional[Dict[str, Dict[str, int]]] = None,
 73    min_size: int = 500,
 74    verbose: bool = True,
 75    return_predictions: bool = False,
 76    scale: Optional[List[float]] = None,
 77    mask: Optional[np.ndarray] = None,
 78    compartment: Optional[np.ndarray] = None,
 79) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
 80    """Segment active zones in an input volume.
 81
 82    Args:
 83        input_volume: The input volume to segment.
 84        model_path: The path to the model checkpoint if `model` is not provided.
 85        model: Pre-loaded model. Either `model_path` or `model` is required.
 86        tiling: The tiling configuration for the prediction.
 87        verbose: Whether to print timing information.
 88        scale: The scale factor to use for rescaling the input volume before prediction.
 89        mask: An optional mask that is used to restrict the segmentation.
 90        compartment: Pass a compartment segmentation, to intersect the boundaries of the
 91            compartments with the active zone prediction.
 92
 93    Returns:
 94        The foreground mask as a numpy array.
 95    """
 96    if verbose:
 97        print("Segmenting AZ in volume of shape", input_volume.shape)
 98    # Create the scaler to handle prediction with a different scaling factor.
 99    scaler = _Scaler(scale, verbose)
100    input_volume = scaler.scale_input(input_volume)
101
102    # Rescale the mask if it was given and run prediction.
103    if mask is not None:
104        mask = scaler.scale_input(mask, is_segmentation=True)
105    pred = get_prediction(input_volume, model_path=model_path, model=model, tiling=tiling, mask=mask, verbose=verbose)
106
107    # Run segmentation and rescale the result if necessary.
108    foreground = pred[0]
109    print(f"shape {foreground.shape}")
110
111    segmentation = _run_segmentation(foreground, verbose=verbose, min_size=min_size)
112    segmentation = scaler.rescale_output(segmentation, is_segmentation=True)
113
114    # Returning prediciton and intersection currently not possible.
115    if return_predictions:
116        assert compartment is None
117        pred = scaler.rescale_output(pred, is_segmentation=False)
118        return segmentation, pred
119
120    if compartment is not None:
121        assert not return_predictions
122        compartment = scaler.scale_input(input_volume, is_segmentation=True)
123        intersection = find_intersection_boundary(segmentation, compartment)
124        return segmentation, intersection
125
126    return segmentation
def find_intersection_boundary( segmented_AZ: numpy.ndarray, segmented_compartment: numpy.ndarray) -> numpy.ndarray:
def find_intersection_boundary(segmented_AZ: np.ndarray, segmented_compartment: np.ndarray) -> np.ndarray:
    """Intersect the outer boundary of every compartment object with the active zone.

    For each non-zero label in ``segmented_compartment`` the outer boundary of that
    object is computed with ``find_boundaries(..., mode="outer")`` and intersected with
    the non-zero voxels of ``segmented_AZ``. The per-label intersections are combined
    with a logical OR.

    Args:
        segmented_AZ: 3D array representing the active zone (AZ); non-zero voxels are AZ.
        segmented_compartment: 3D array representing the compartment, with multiple labels.
            Label 0 is treated as background and skipped.

    Returns:
        Integer array with 1 where an AZ voxel lies on the outer boundary of at least one
        compartment object, and 0 elsewhere.
    """
    hits = np.zeros_like(segmented_AZ, dtype=bool)

    # Background (label 0) is not an object, so its boundary is never considered.
    object_ids = [object_id for object_id in np.unique(segmented_compartment) if object_id != 0]

    for object_id in object_ids:
        # Rim of voxels just outside the current object.
        outer_rim = find_boundaries(segmented_compartment == object_id, mode='outer')
        hits = np.logical_or(hits, np.logical_and(outer_rim, segmented_AZ))

    return hits.astype(int)

Find the intersection of the boundaries of each object in segmented_compartment with segmented_AZ.

Arguments:
  • segmented_AZ: 3D array representing the active zone (AZ).
  • segmented_compartment: 3D array representing the compartment, with multiple labels.
Returns:

Array with the cumulative intersection of all boundaries of segmented_compartment labels with segmented_AZ.

def segment_active_zone( input_volume: numpy.ndarray, model_path: Optional[str] = None, model: Optional[torch.nn.modules.module.Module] = None, tiling: Optional[Dict[str, Dict[str, int]]] = None, min_size: int = 500, verbose: bool = True, return_predictions: bool = False, scale: Optional[List[float]] = None, mask: Optional[numpy.ndarray] = None, compartment: Optional[numpy.ndarray] = None) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]:
 69def segment_active_zone(
 70    input_volume: np.ndarray,
 71    model_path: Optional[str] = None,
 72    model: Optional[torch.nn.Module] = None,
 73    tiling: Optional[Dict[str, Dict[str, int]]] = None,
 74    min_size: int = 500,
 75    verbose: bool = True,
 76    return_predictions: bool = False,
 77    scale: Optional[List[float]] = None,
 78    mask: Optional[np.ndarray] = None,
 79    compartment: Optional[np.ndarray] = None,
 80) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
 81    """Segment active zones in an input volume.
 82
 83    Args:
 84        input_volume: The input volume to segment.
 85        model_path: The path to the model checkpoint if `model` is not provided.
 86        model: Pre-loaded model. Either `model_path` or `model` is required.
 87        tiling: The tiling configuration for the prediction.
 88        verbose: Whether to print timing information.
 89        scale: The scale factor to use for rescaling the input volume before prediction.
 90        mask: An optional mask that is used to restrict the segmentation.
 91        compartment: Pass a compartment segmentation, to intersect the boundaries of the
 92            compartments with the active zone prediction.
 93
 94    Returns:
 95        The foreground mask as a numpy array.
 96    """
 97    if verbose:
 98        print("Segmenting AZ in volume of shape", input_volume.shape)
 99    # Create the scaler to handle prediction with a different scaling factor.
100    scaler = _Scaler(scale, verbose)
101    input_volume = scaler.scale_input(input_volume)
102
103    # Rescale the mask if it was given and run prediction.
104    if mask is not None:
105        mask = scaler.scale_input(mask, is_segmentation=True)
106    pred = get_prediction(input_volume, model_path=model_path, model=model, tiling=tiling, mask=mask, verbose=verbose)
107
108    # Run segmentation and rescale the result if necessary.
109    foreground = pred[0]
110    print(f"shape {foreground.shape}")
111
112    segmentation = _run_segmentation(foreground, verbose=verbose, min_size=min_size)
113    segmentation = scaler.rescale_output(segmentation, is_segmentation=True)
114
115    # Returning prediciton and intersection currently not possible.
116    if return_predictions:
117        assert compartment is None
118        pred = scaler.rescale_output(pred, is_segmentation=False)
119        return segmentation, pred
120
121    if compartment is not None:
122        assert not return_predictions
123        compartment = scaler.scale_input(input_volume, is_segmentation=True)
124        intersection = find_intersection_boundary(segmentation, compartment)
125        return segmentation, intersection
126
127    return segmentation

Segment active zones in an input volume.

Arguments:
  • input_volume: The input volume to segment.
  • model_path: The path to the model checkpoint if model is not provided.
  • model: Pre-loaded model. Either model_path or model is required.
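  • tiling: The tiling configuration for the prediction.
  • min_size: The minimal size of a connected component to be kept in the segmentation.
  • verbose: Whether to print timing information.
  • return_predictions: Whether to also return the network predictions. Cannot be combined with compartment.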
  • scale: The scale factor to use for rescaling the input volume before prediction.
  • mask: An optional mask that is used to restrict the segmentation.
  • compartment: Pass a compartment segmentation, to intersect the boundaries of the compartments with the active zone prediction.
Returns:

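The foreground mask as a numpy array. If return_predictions is set, a tuple of the mask and the predictions is returned; if compartment is given, a tuple of the mask and the intersection of the compartment boundaries with the active zone is returned.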