synapse_net.inference.cristae
import time
from typing import Dict, List, Optional, Tuple, Union

import elf.parallel as parallel
import numpy as np
import torch

from synapse_net.inference.util import get_prediction, _Scaler


def _run_segmentation(
    foreground, verbose, min_size,
    # blocking shapes for parallel computation
    block_shape=(128, 256, 256),
):
    """Turn a foreground probability map into a binary cristae mask.

    The foreground is thresholded at 0.5 and split into connected components.
    Components with fewer than `min_size` voxels are removed, and the remaining
    components are merged into a single binary mask.

    Args:
        foreground: The foreground prediction.
        verbose: Whether to print timing information.
        min_size: Components with fewer voxels than this are removed.
        block_shape: Block shape used by the blockwise `elf.parallel` operations.

    Returns:
        A binary mask with 1 for cristae and 0 for background.
    """
    # Label the thresholded foreground via connected components.
    t0 = time.time()
    seg = parallel.label(foreground > 0.5, block_shape=block_shape, verbose=verbose)
    if verbose:
        print("Compute connected components in", time.time() - t0, "s")

    # Size filter: drop every component with fewer than min_size voxels.
    t0 = time.time()
    ids, sizes = parallel.unique(seg, return_counts=True, block_shape=block_shape, verbose=verbose)
    filter_ids = ids[sizes < min_size]
    seg[np.isin(seg, filter_ids)] = 0
    if verbose:
        print("Size filter in", time.time() - t0, "s")

    # Collapse the surviving component ids into a binary mask.
    seg = np.where(seg > 0, 1, 0)
    return seg


def segment_cristae(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    min_size: int = 500,
    verbose: bool = True,
    distance_based_segmentation: bool = False,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
    **kwargs
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment cristae in an input volume.

    The network receives two stacked input channels: the raw volume (channel 0)
    and the mitochondria segmentation (channel 1). The mitochondria segmentation
    must be passed via the `extra_segmentation` keyword argument.

    Args:
        input_volume: The raw 3D volume to segment.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        min_size: The minimum size (in voxels, measured at the prediction scale)
            of a cristae component to be kept.
        verbose: Whether to print timing information.
        distance_based_segmentation: Not currently used by this function;
            setting it has no effect.
        return_predictions: Whether to return the predictions alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        mask: An optional mask that is used to restrict the segmentation.
        kwargs: Additional keyword arguments:
            extra_segmentation (required): The mitochondria segmentation. It is
                stacked with `input_volume` as a second channel, so it is expected
                to have the same shape as `input_volume`.
            with_channels: Forwarded to `get_prediction` (default: True).
            channels_to_standardize: Forwarded to `get_prediction`
                (default: [0], i.e. only the raw channel).
            Any other keyword arguments are ignored.

    Returns:
        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
        and the predictions if return_predictions is True.

    Raises:
        KeyError: If `extra_segmentation` is not passed.
    """
    if "extra_segmentation" not in kwargs:
        # Fail with an actionable message instead of a bare KeyError('extra_segmentation').
        raise KeyError(
            "segment_cristae requires the mitochondria segmentation to be passed "
            "as the keyword argument 'extra_segmentation'."
        )
    mitochondria = kwargs.pop("extra_segmentation")
    with_channels = kwargs.pop("with_channels", True)
    channels_to_standardize = kwargs.pop("channels_to_standardize", [0])
    if verbose:
        print("Segmenting cristae in volume of shape", input_volume.shape)

    # Create the scaler to handle prediction with a different scaling factor,
    # and rescale each channel before stacking them as network input.
    scaler = _Scaler(scale, verbose)
    volume = scaler.scale_input(input_volume)
    mito_seg = scaler.scale_input(mitochondria, is_segmentation=True)
    network_input = np.stack([volume, mito_seg], axis=0)

    # Run prediction and segmentation.
    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)
    pred = get_prediction(
        network_input, model_path=model_path, model=model, mask=mask,
        tiling=tiling, with_channels=with_channels, channels_to_standardize=channels_to_standardize, verbose=verbose
    )
    # Only the foreground channel is used for the segmentation.
    foreground = pred[0]
    seg = _run_segmentation(foreground, verbose=verbose, min_size=min_size)
    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        pred = scaler.rescale_output(pred, is_segmentation=False)
        return seg, pred
    return seg
def
segment_cristae( input_volume: numpy.ndarray, model_path: Optional[str] = None, model: Optional[torch.nn.modules.module.Module] = None, tiling: Optional[Dict[str, Dict[str, int]]] = None, min_size: int = 500, verbose: bool = True, distance_based_segmentation: bool = False, return_predictions: bool = False, scale: Optional[List[float]] = None, mask: Optional[numpy.ndarray] = None, **kwargs) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]:
def segment_cristae(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    min_size: int = 500,
    verbose: bool = True,
    distance_based_segmentation: bool = False,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
    **kwargs
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment cristae in an input volume.

    The network receives two stacked input channels: the raw volume (channel 0)
    and the mitochondria segmentation (channel 1). The mitochondria segmentation
    must be passed via the `extra_segmentation` keyword argument.

    Args:
        input_volume: The raw 3D volume to segment.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        min_size: The minimum size (in voxels, measured at the prediction scale)
            of a cristae component to be kept.
        verbose: Whether to print timing information.
        distance_based_segmentation: Not currently used by this function;
            setting it has no effect.
        return_predictions: Whether to return the predictions alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        mask: An optional mask that is used to restrict the segmentation.
        kwargs: Additional keyword arguments:
            extra_segmentation (required): The mitochondria segmentation. It is
                stacked with `input_volume` as a second channel, so it is expected
                to have the same shape as `input_volume`.
            with_channels: Forwarded to `get_prediction` (default: True).
            channels_to_standardize: Forwarded to `get_prediction`
                (default: [0], i.e. only the raw channel).
            Any other keyword arguments are ignored.

    Returns:
        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
        and the predictions if return_predictions is True.

    Raises:
        KeyError: If `extra_segmentation` is not passed.
    """
    if "extra_segmentation" not in kwargs:
        # Fail with an actionable message instead of a bare KeyError('extra_segmentation').
        raise KeyError(
            "segment_cristae requires the mitochondria segmentation to be passed "
            "as the keyword argument 'extra_segmentation'."
        )
    mitochondria = kwargs.pop("extra_segmentation")
    with_channels = kwargs.pop("with_channels", True)
    channels_to_standardize = kwargs.pop("channels_to_standardize", [0])
    if verbose:
        print("Segmenting cristae in volume of shape", input_volume.shape)

    # Create the scaler to handle prediction with a different scaling factor,
    # and rescale each channel before stacking them as network input.
    scaler = _Scaler(scale, verbose)
    volume = scaler.scale_input(input_volume)
    mito_seg = scaler.scale_input(mitochondria, is_segmentation=True)
    network_input = np.stack([volume, mito_seg], axis=0)

    # Run prediction and segmentation.
    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)
    pred = get_prediction(
        network_input, model_path=model_path, model=model, mask=mask,
        tiling=tiling, with_channels=with_channels, channels_to_standardize=channels_to_standardize, verbose=verbose
    )
    # Only the foreground channel is used for the segmentation.
    foreground = pred[0]
    seg = _run_segmentation(foreground, verbose=verbose, min_size=min_size)
    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        pred = scaler.rescale_output(pred, is_segmentation=False)
        return seg, pred
    return seg
Segment cristae in an input volume.
Arguments:
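- input_volume: The raw 3D volume to segment. The matching mitochondria segmentation is passed separately via the `extra_segmentation` keyword argument; the two are stacked as input channels.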
- model_path: The path to the model checkpoint if `model` is not provided.
- model: Pre-loaded model. Either `model_path` or `model` is required.
- tiling: The tiling configuration for the prediction.
- min_size: The minimum size (in voxels) of a crista to be kept; smaller connected components are removed.
- verbose: Whether to print timing information.
- distance_based_segmentation: Not currently used by this function; setting it has no effect.
- return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
- scale: The scale factor to use for rescaling the input volume before prediction.
- mask: An optional mask that is used to restrict the segmentation.
Returns:
The segmentation mask as a numpy array, or a tuple containing the segmentation mask and the predictions if return_predictions is True.