synapse_net.inference.mitochondria
import time
from typing import Dict, List, Optional, Tuple, Union

import elf.parallel as parallel
import numpy as np
import torch

from synapse_net.inference.util import apply_size_filter, get_prediction, _Scaler, _postprocess_seg_3d


def _run_segmentation(
    foreground, boundaries, verbose, min_size,
    # blocking shapes for parallel computation
    block_shape=(128, 256, 256),
    halo=(48, 48, 48),
    seed_distance=6,
    boundary_threshold=0.25,
    area_threshold=5000,
):
    """Compute an instance segmentation from foreground and boundary predictions via seeded watershed.

    Args:
        foreground: The foreground prediction.
        boundaries: The boundary prediction.
        verbose: Whether to print timing information.
        min_size: The minimum object size, passed to `apply_size_filter`.
        block_shape: The block shape for the parallel distance transform, labeling, watershed and size filter.
        halo: The halo for the parallel distance transform and seeded watershed.
        seed_distance: Foreground pixels whose distance-transform value exceeds this become watershed seeds.
        boundary_threshold: Boundary values below this are the input of the distance transform;
            values above it are set to the maximum of the height map.
        area_threshold: Passed to `_postprocess_seg_3d`.

    Returns:
        The instance segmentation.
    """
    t0 = time.time()
    dist = parallel.distance_transform(
        boundaries < boundary_threshold, halo=halo, verbose=verbose, block_shape=block_shape
    )
    if verbose:
        print("Compute distance transform in", time.time() - t0, "s")

    # Get the segmentation via seeded watershed.
    t0 = time.time()
    seeds = np.logical_and(foreground > 0.5, dist > seed_distance)
    seeds = parallel.label(seeds, block_shape=block_shape, verbose=verbose)
    if verbose:
        print("Compute connected components in", time.time() - t0, "s")

    t0 = time.time()
    # Height map: inverted distance transform normalized to [0, 1].
    dist_max = dist.max()
    if dist_max > 0:
        hmap = (dist_max - dist) / dist_max
    else:
        # The distance transform is zero everywhere; normalizing would divide 0 by 0 and
        # produce NaNs, so fall back to a flat height map.
        hmap = np.zeros_like(dist, dtype=np.float32)
    # Raise the height map at predicted boundaries so that the watershed does not cross them.
    hmap[boundaries > boundary_threshold] = (hmap + boundaries).max()
    mask = (foreground + boundaries) > 0.5

    seg = np.zeros_like(seeds)
    seg = parallel.seeded_watershed(
        hmap, seeds, block_shape=block_shape,
        out=seg, mask=mask, verbose=verbose, halo=halo,
    )
    if verbose:
        print("Compute watershed in", time.time() - t0, "s")

    seg = apply_size_filter(seg, min_size, verbose=verbose, block_shape=block_shape)
    seg = _postprocess_seg_3d(seg, area_threshold=area_threshold)
    return seg


def segment_mitochondria(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    min_size: int = 50000,
    verbose: bool = True,
    distance_based_segmentation: bool = False,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
    seed_distance: int = 6,
    ws_block_shape: Tuple[int, ...] = (128, 256, 256),
    ws_halo: Tuple[int, ...] = (48, 48, 48),
    boundary_threshold: float = 0.25,
    area_threshold: int = 5000,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment mitochondria in an input volume.

    Args:
        input_volume: The input volume to segment.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        min_size: The minimum size of a mitochondrion; smaller objects are removed by the size filter.
        verbose: Whether to print timing information.
        distance_based_segmentation: Whether to use distance-based segmentation.
            NOTE: this parameter is currently not used by this function.
        return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
            The outputs are passed through the scaler's `rescale_output` to undo this scaling.
        mask: An optional mask. It is rescaled like the input and passed to `get_prediction`,
            presumably to restrict where predictions are computed.
        seed_distance: Foreground pixels whose boundary distance-transform value exceeds this
            become part of a watershed seed.
        ws_block_shape: The block shape for the parallel distance transform, connected components,
            seeded watershed and size filter.
        ws_halo: The halo for the parallel distance transform and seeded watershed.
        boundary_threshold: The threshold applied to the boundary prediction for the distance transform
            and for raising the watershed height map at boundaries.
        area_threshold: The maximum area (in pixels) of holes to be removed or filled in the segmentation.
            This parameter is passed to `skimage.morphology.remove_small_holes`.

    Returns:
        The segmentation as a numpy array, or a tuple of the segmentation and the predictions
        if `return_predictions` is True.
    """
    if verbose:
        print("Segmenting mitochondria in volume of shape", input_volume.shape)
    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    # Rescale the mask if it was given and run prediction.
    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)
    pred = get_prediction(input_volume, model_path=model_path, model=model, tiling=tiling, mask=mask, verbose=verbose)

    # Run segmentation and rescale the result if necessary.
    foreground, boundaries = pred[:2]
    seg = _run_segmentation(foreground, boundaries, verbose=verbose, min_size=min_size, seed_distance=seed_distance,
                            block_shape=ws_block_shape, halo=ws_halo, boundary_threshold=boundary_threshold,
                            area_threshold=area_threshold)
    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        pred = scaler.rescale_output(pred, is_segmentation=False)
        return seg, pred
    return seg
def
segment_mitochondria( input_volume: numpy.ndarray, model_path: Optional[str] = None, model: Optional[torch.nn.modules.module.Module] = None, tiling: Optional[Dict[str, Dict[str, int]]] = None, min_size: int = 50000, verbose: bool = True, distance_based_segmentation: bool = False, return_predictions: bool = False, scale: Optional[List[float]] = None, mask: Optional[numpy.ndarray] = None, seed_distance: int = 6, ws_block_shape: Tuple[int, ...] = (128, 256, 256), ws_halo: Tuple[int, ...] = (48, 48, 48), boundary_threshold: float = 0.25, area_threshold: int = 5000) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]:
def segment_mitochondria(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    min_size: int = 50000,
    verbose: bool = True,
    distance_based_segmentation: bool = False,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
    seed_distance: int = 6,
    ws_block_shape: Tuple[int, ...] = (128, 256, 256),
    ws_halo: Tuple[int, ...] = (48, 48, 48),
    boundary_threshold: float = 0.25,
    area_threshold: int = 5000,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment mitochondria in an input volume.

    Args:
        input_volume: The input volume to segment.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        min_size: The minimum size of a mitochondrion; smaller objects are removed by the size filter.
        verbose: Whether to print timing information.
        distance_based_segmentation: Whether to use distance-based segmentation.
            NOTE: this parameter is currently not used by this function.
        return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
        scale: The scale factor to use for rescaling the input volume before prediction.
            The outputs are passed through the scaler's `rescale_output` to undo this scaling.
        mask: An optional mask. It is rescaled like the input and passed to `get_prediction`,
            presumably to restrict where predictions are computed.
        seed_distance: Foreground pixels whose boundary distance-transform value exceeds this
            become part of a watershed seed.
        ws_block_shape: The block shape for the parallel distance transform, connected components,
            seeded watershed and size filter.
        ws_halo: The halo for the parallel distance transform and seeded watershed.
        boundary_threshold: The threshold applied to the boundary prediction for the distance transform
            and for raising the watershed height map at boundaries.
        area_threshold: The maximum area (in pixels) of holes to be removed or filled in the segmentation.
            This parameter is passed to `skimage.morphology.remove_small_holes`.

    Returns:
        The segmentation as a numpy array, or a tuple of the segmentation and the predictions
        if `return_predictions` is True.
    """
    if verbose:
        print("Segmenting mitochondria in volume of shape", input_volume.shape)
    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    # Rescale the mask if it was given and run prediction.
    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)
    pred = get_prediction(input_volume, model_path=model_path, model=model, tiling=tiling, mask=mask, verbose=verbose)

    # The first two prediction channels are foreground and boundaries.
    foreground, boundaries = pred[:2]
    seg = _run_segmentation(foreground, boundaries, verbose=verbose, min_size=min_size, seed_distance=seed_distance,
                            block_shape=ws_block_shape, halo=ws_halo, boundary_threshold=boundary_threshold,
                            area_threshold=area_threshold)
    # Undo the input scaling on the segmentation.
    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        pred = scaler.rescale_output(pred, is_segmentation=False)
        return seg, pred
    return seg
Segment mitochondria in an input volume.
Arguments:
- input_volume: The input volume to segment.
- model_path: The path to the model checkpoint if `model` is not provided.
- model: Pre-loaded model. Either `model_path` or `model` is required.
- tiling: The tiling configuration for the prediction.
- min_size: The minimum size of a mitochondrion to be considered.
- verbose: Whether to print timing information.
- distance_based_segmentation: Whether to use distance-based segmentation. (Currently not used by the function.)
- return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
- scale: The scale factor to use for rescaling the input volume before prediction.
- mask: An optional mask that is used to restrict the segmentation.
- seed_distance: The distance threshold for the seeded watershed.
- ws_block_shape: The block shape for the seeded watershed.
- ws_halo: The halo for the seeded watershed.
- boundary_threshold: The threshold applied to the boundary prediction for the distance calculation.
- area_threshold: The maximum area (in pixels) of holes to be removed or filled in the segmentation.
This parameter is passed to `skimage.morphology.remove_small_holes`.
Returns:
The segmentation mask as a numpy array, or a tuple containing the segmentation mask and the predictions if return_predictions is True.