synapse_net.inference.cristae
import time
from typing import Dict, List, Optional, Tuple, Union

import elf.parallel as parallel
import numpy as np
import torch

from synapse_net.inference.util import get_prediction, _Scaler


def _run_segmentation(
    foreground, verbose, min_size,
    # blocking shapes for parallel computation
    block_shape=(128, 256, 256),
):
    """Turn a foreground prediction into a binary mask without small objects.

    The foreground is thresholded at 0.5, its connected components are labeled
    block-wise, and every component with fewer than `min_size` elements is removed.

    Args:
        foreground: The foreground prediction to threshold.
        verbose: Whether to print timing information.
        min_size: Components with fewer elements than this are discarded.
        block_shape: The block shape passed to the parallel label and unique computations.

    Returns:
        An array that is 1 where a sufficiently large component remains and 0 elsewhere.
    """
    # Label the connected components of the thresholded foreground.
    # (No watershed is involved here, only connected component labeling.)
    t0 = time.time()
    seg = parallel.label(foreground > 0.5, block_shape=block_shape, verbose=verbose)
    if verbose:
        print("Compute connected components in", time.time() - t0, "s")

    # Size filter: reset all components below the minimum size to background.
    t0 = time.time()
    ids, sizes = parallel.unique(seg, return_counts=True, block_shape=block_shape, verbose=verbose)
    filter_ids = ids[sizes < min_size]
    seg[np.isin(seg, filter_ids)] = 0
    if verbose:
        print("Size filter in", time.time() - t0, "s")

    # Collapse the remaining component ids into a single binary mask.
    seg = np.where(seg > 0, 1, 0)
    return seg


def segment_cristae(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    min_size: int = 500,
    verbose: bool = True,
    distance_based_segmentation: bool = False,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment cristae in an input volume.

    Args:
        input_volume: The input volume to segment. Expects two 3D volumes: raw and mitochondria.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        min_size: The minimum size of a crista to be considered.
        verbose: Whether to print timing information.
        distance_based_segmentation: Whether to use distance-based segmentation.
            NOTE: this argument is currently not read by this function and has no effect.
        return_predictions: Whether to return the predictions (foreground, boundaries)
            alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        mask: An optional mask that is used to restrict the segmentation.

    Returns:
        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
        and the predictions if return_predictions is True.
    """
    if verbose:
        print("Segmenting cristae in volume of shape", input_volume.shape)
    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    # Run prediction and segmentation.
    if mask is not None:
        # The mask has to be brought to the same scale as the input volume.
        mask = scaler.scale_input(mask, is_segmentation=True)
    pred = get_prediction(
        input_volume, model_path=model_path, model=model, mask=mask,
        tiling=tiling, with_channels=True, verbose=verbose
    )
    # The first two channels are unpacked as (foreground, boundaries);
    # only the foreground channel is used for the segmentation.
    foreground, _ = pred[:2]
    seg = _run_segmentation(foreground, verbose=verbose, min_size=min_size)
    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        pred = scaler.rescale_output(pred, is_segmentation=False)
        return seg, pred
    return seg
def
segment_cristae( input_volume: numpy.ndarray, model_path: Optional[str] = None, model: Optional[torch.nn.modules.module.Module] = None, tiling: Optional[Dict[str, Dict[str, int]]] = None, min_size: int = 500, verbose: bool = True, distance_based_segmentation: bool = False, return_predictions: bool = False, scale: Optional[List[float]] = None, mask: Optional[numpy.ndarray] = None) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]:
def segment_cristae(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    min_size: int = 500,
    verbose: bool = True,
    distance_based_segmentation: bool = False,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment cristae in an input volume.

    Args:
        input_volume: The input volume to segment. Expects two 3D volumes: raw and mitochondria.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        min_size: The minimum size of a crista to be considered.
        verbose: Whether to print timing information.
        distance_based_segmentation: Whether to use distance-based segmentation.
            NOTE: this argument is currently not read by this function and has no effect.
        return_predictions: Whether to return the predictions (foreground, boundaries)
            alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
        mask: An optional mask that is used to restrict the segmentation.

    Returns:
        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
        and the predictions if return_predictions is True.
    """
    if verbose:
        print("Segmenting cristae in volume of shape", input_volume.shape)
    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    # Run prediction and segmentation.
    if mask is not None:
        # The mask has to be brought to the same scale as the input volume.
        mask = scaler.scale_input(mask, is_segmentation=True)
    pred = get_prediction(
        input_volume, model_path=model_path, model=model, mask=mask,
        tiling=tiling, with_channels=True, verbose=verbose
    )
    # The first two channels are unpacked as (foreground, boundaries);
    # only the foreground channel is used for the segmentation.
    foreground, _ = pred[:2]
    seg = _run_segmentation(foreground, verbose=verbose, min_size=min_size)
    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        pred = scaler.rescale_output(pred, is_segmentation=False)
        return seg, pred
    return seg
Segment cristae in an input volume.
Arguments:
- input_volume: The input volume to segment. Expects two 3D volumes: raw and mitochondria.
- model_path: The path to the model checkpoint if `model` is not provided.
- model: Pre-loaded model. Either `model_path` or `model` is required.
- tiling: The tiling configuration for the prediction.
- min_size: The minimum size of a crista to be considered.
- verbose: Whether to print timing information.
- distance_based_segmentation: Whether to use distance-based segmentation.
- return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
- scale: The scale factor to use for rescaling the input volume before prediction.
- mask: An optional mask that is used to restrict the segmentation.
Returns:
The segmentation mask as a numpy array, or a tuple containing the segmentation mask and the predictions if return_predictions is True.