synapse_net.inference.active_zone
import time
from typing import Dict, List, Optional, Tuple, Union

import elf.parallel as parallel
import numpy as np
import torch

from skimage.segmentation import find_boundaries
from synapse_net.inference.util import get_prediction, _Scaler


def find_intersection_boundary(segmented_AZ: np.ndarray, segmented_compartment: np.ndarray) -> np.ndarray:
    """Find the intersection of the boundaries of each object in segmented_compartment with segmented_AZ.

    Args:
        segmented_AZ: 3D array representing the active zone (AZ). Every non-zero value counts as AZ.
        segmented_compartment: 3D array representing the compartment, with multiple labels.
            Label 0 is treated as background and skipped.

    Returns:
        Integer array (1 for intersecting points, 0 elsewhere) with the cumulative intersection of the
        outer boundaries of all segmented_compartment labels with segmented_AZ.

    Raises:
        ValueError: If the two arrays do not have the same shape.
    """
    # Convert the AZ to a boolean mask once, instead of implicitly inside every loop iteration.
    az_mask = np.asarray(segmented_AZ, dtype=bool)
    if az_mask.shape != np.shape(segmented_compartment):
        raise ValueError(
            f"Shape mismatch: segmented_AZ has shape {az_mask.shape}, "
            f"segmented_compartment has shape {np.shape(segmented_compartment)}."
        )

    # Accumulator for the intersections of all compartment boundaries.
    cumulative_intersection = np.zeros_like(az_mask)

    # Loop through each foreground label (0 is the background).
    labels = np.unique(segmented_compartment)
    labels = labels[labels != 0]

    for label in labels:
        label_mask = segmented_compartment == label
        # mode='outer' marks the voxels just outside the object, i.e. the layer surrounding this compartment.
        boundary_compartment = find_boundaries(label_mask, mode="outer")
        # Accumulate in place to avoid allocating a new volume for every label.
        cumulative_intersection |= boundary_compartment & az_mask

    return cumulative_intersection.astype(int)


def _run_segmentation(
    foreground, verbose, min_size,
    # blocking shapes for parallel computation
    block_shape=(128, 256, 256),
):
    """Turn a foreground prediction into a binary segmentation.

    Thresholds the foreground at 0.5, labels connected components blockwise, removes components
    with fewer than `min_size` voxels and returns an integer array (1 = foreground, 0 = background).
    """
    # Get the segmentation via connected components of the thresholded foreground.
    t0 = time.time()
    seg = parallel.label(foreground > 0.5, block_shape=block_shape, verbose=verbose)
    if verbose:
        print("Compute connected components in", time.time() - t0, "s")

    # Size filter: remove all components smaller than min_size.
    t0 = time.time()
    ids, sizes = parallel.unique(seg, return_counts=True, block_shape=block_shape, verbose=verbose)
    filter_ids = ids[sizes < min_size]
    seg[np.isin(seg, filter_ids)] = 0
    if verbose:
        print("Size filter in", time.time() - t0, "s")
    seg = np.where(seg > 0, 1, 0)
    return seg


def segment_active_zone(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    min_size: int = 500,
    verbose: bool = True,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
    compartment: Optional[np.ndarray] = None,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment active zones in an input volume.

    Args:
        input_volume: The input volume to segment.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        min_size: The minimal size (in voxels) of a connected component to be kept.
            The size filter is applied at the prediction scale, i.e. before rescaling the output.
        verbose: Whether to print timing information.
        return_predictions: Whether to also return the (rescaled) network predictions.
        scale: The scale factor to use for rescaling the input volume before prediction.
        mask: An optional mask that is used to restrict the segmentation.
        compartment: Pass a compartment segmentation, to intersect the boundaries of the
            compartments with the active zone prediction. It is compared against the segmentation
            after rescaling, so it must have the same shape as that segmentation (presumably the
            shape of the original input volume). Cannot be combined with `return_predictions`.

    Returns:
        The foreground mask as a numpy array.
        If `return_predictions` is set: a tuple of the foreground mask and the predictions.
        If `compartment` is given: a tuple of the foreground mask and the boundary intersection
        computed by `find_intersection_boundary`.

    Raises:
        ValueError: If both `return_predictions` and `compartment` are passed, or if the
            compartment shape does not match the segmentation shape.
    """
    # Validate the argument combination before running the (expensive) prediction.
    if return_predictions and compartment is not None:
        raise ValueError(
            "Returning the predictions and the compartment intersection together is currently not supported."
        )

    if verbose:
        print("Segmenting AZ in volume of shape", input_volume.shape)
    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    # Rescale the mask if it was given and run prediction.
    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)
    pred = get_prediction(input_volume, model_path=model_path, model=model, tiling=tiling, mask=mask, verbose=verbose)

    # Run segmentation and rescale the result if necessary.
    foreground = pred[0]
    if verbose:
        print(f"shape {foreground.shape}")

    segmentation = _run_segmentation(foreground, verbose=verbose, min_size=min_size)
    segmentation = scaler.rescale_output(segmentation, is_segmentation=True)

    if return_predictions:
        pred = scaler.rescale_output(pred, is_segmentation=False)
        return segmentation, pred

    if compartment is not None:
        # Use the compartment passed by the caller. The segmentation has already been rescaled
        # back to the input resolution above, so the compartment is not rescaled here.
        if np.shape(compartment) != np.shape(segmentation):
            raise ValueError(
                f"The compartment shape {np.shape(compartment)} does not match "
                f"the segmentation shape {np.shape(segmentation)}."
            )
        intersection = find_intersection_boundary(segmentation, compartment)
        return segmentation, intersection

    return segmentation
def
find_intersection_boundary( segmented_AZ: numpy.ndarray, segmented_compartment: numpy.ndarray) -> numpy.ndarray:
def find_intersection_boundary(segmented_AZ: np.ndarray, segmented_compartment: np.ndarray) -> np.ndarray:
    """Find the intersection of the boundaries of each object in segmented_compartment with segmented_AZ.

    Args:
        segmented_AZ: 3D array representing the active zone (AZ). Every non-zero value counts as AZ.
        segmented_compartment: 3D array representing the compartment, with multiple labels.
            Label 0 is treated as background and skipped.

    Returns:
        Integer array (1 for intersecting points, 0 elsewhere) with the cumulative intersection of the
        outer boundaries of all segmented_compartment labels with segmented_AZ.

    Raises:
        ValueError: If the two arrays do not have the same shape.
    """
    # Convert the AZ to a boolean mask once, instead of implicitly inside every loop iteration.
    az_mask = np.asarray(segmented_AZ, dtype=bool)
    if az_mask.shape != np.shape(segmented_compartment):
        raise ValueError(
            f"Shape mismatch: segmented_AZ has shape {az_mask.shape}, "
            f"segmented_compartment has shape {np.shape(segmented_compartment)}."
        )

    # Accumulator for the intersections of all compartment boundaries.
    cumulative_intersection = np.zeros_like(az_mask)

    # Loop through each foreground label (0 is the background).
    labels = np.unique(segmented_compartment)
    labels = labels[labels != 0]

    for label in labels:
        label_mask = segmented_compartment == label
        # mode='outer' marks the voxels just outside the object, i.e. the layer surrounding this compartment.
        boundary_compartment = find_boundaries(label_mask, mode="outer")
        # Accumulate in place to avoid allocating a new volume for every label.
        cumulative_intersection |= boundary_compartment & az_mask

    return cumulative_intersection.astype(int)
Find the intersection of the boundaries of each object in segmented_compartment with segmented_AZ.
Arguments:
- segmented_AZ: 3D array representing the active zone (AZ).
- segmented_compartment: 3D array representing the compartment, with multiple labels.
Returns:
Array with the cumulative intersection of all boundaries of segmented_compartment labels with segmented_AZ.
def
segment_active_zone( input_volume: numpy.ndarray, model_path: Optional[str] = None, model: Optional[torch.nn.modules.module.Module] = None, tiling: Optional[Dict[str, Dict[str, int]]] = None, min_size: int = 500, verbose: bool = True, return_predictions: bool = False, scale: Optional[List[float]] = None, mask: Optional[numpy.ndarray] = None, compartment: Optional[numpy.ndarray] = None) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]:
def segment_active_zone(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    min_size: int = 500,
    verbose: bool = True,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
    compartment: Optional[np.ndarray] = None,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment active zones in an input volume.

    Args:
        input_volume: The input volume to segment.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        min_size: The minimal size (in voxels) of a connected component to be kept.
            The size filter is applied at the prediction scale, i.e. before rescaling the output.
        verbose: Whether to print timing information.
        return_predictions: Whether to also return the (rescaled) network predictions.
        scale: The scale factor to use for rescaling the input volume before prediction.
        mask: An optional mask that is used to restrict the segmentation.
        compartment: Pass a compartment segmentation, to intersect the boundaries of the
            compartments with the active zone prediction. It is compared against the segmentation
            after rescaling, so it must have the same shape as that segmentation (presumably the
            shape of the original input volume). Cannot be combined with `return_predictions`.

    Returns:
        The foreground mask as a numpy array.
        If `return_predictions` is set: a tuple of the foreground mask and the predictions.
        If `compartment` is given: a tuple of the foreground mask and the boundary intersection
        computed by `find_intersection_boundary`.

    Raises:
        ValueError: If both `return_predictions` and `compartment` are passed, or if the
            compartment shape does not match the segmentation shape.
    """
    # Validate the argument combination before running the (expensive) prediction.
    if return_predictions and compartment is not None:
        raise ValueError(
            "Returning the predictions and the compartment intersection together is currently not supported."
        )

    if verbose:
        print("Segmenting AZ in volume of shape", input_volume.shape)
    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    # Rescale the mask if it was given and run prediction.
    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)
    pred = get_prediction(input_volume, model_path=model_path, model=model, tiling=tiling, mask=mask, verbose=verbose)

    # Run segmentation and rescale the result if necessary.
    foreground = pred[0]
    if verbose:
        print(f"shape {foreground.shape}")

    segmentation = _run_segmentation(foreground, verbose=verbose, min_size=min_size)
    segmentation = scaler.rescale_output(segmentation, is_segmentation=True)

    if return_predictions:
        pred = scaler.rescale_output(pred, is_segmentation=False)
        return segmentation, pred

    if compartment is not None:
        # Use the compartment passed by the caller. The segmentation has already been rescaled
        # back to the input resolution above, so the compartment is not rescaled here.
        if np.shape(compartment) != np.shape(segmentation):
            raise ValueError(
                f"The compartment shape {np.shape(compartment)} does not match "
                f"the segmentation shape {np.shape(segmentation)}."
            )
        intersection = find_intersection_boundary(segmentation, compartment)
        return segmentation, intersection

    return segmentation
Segment active zones in an input volume.
Arguments:
- input_volume: The input volume to segment.
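- model_path: The path to the model checkpoint if `model` is not provided.
- model: Pre-loaded model. Either `model_path` or `model` is required.
- tiling: The tiling configuration for the prediction.
- min_size: The minimal size (in voxels) of a connected component to be kept in the segmentation.
- verbose: Whether to print timing information.
- return_predictions: Whether to also return the network predictions.
- scale: The scale factor to use for rescaling the input volume before prediction.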
- mask: An optional mask that is used to restrict the segmentation.
- compartment: Pass a compartment segmentation, to intersect the boundaries of the compartments with the active zone prediction.
Returns:
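The foreground mask as a numpy array. If `return_predictions` is set, a tuple of the foreground mask and the predictions is returned instead; if `compartment` is given, a tuple of the foreground mask and the intersection of the compartment boundaries with the active zone.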