synapse_net.inference.mitochondria
import time
from typing import Callable, Dict, List, Optional, Tuple, Union

import elf.parallel as parallel
import numpy as np
import torch

from synapse_net.inference.util import apply_size_filter, get_prediction, _Scaler, _postprocess_seg_3d


def _run_segmentation(
    foreground, boundaries, verbose, min_size,
    # blocking shapes for parallel computation
    block_shape=(128, 256, 256),
    halo=(48, 48, 48),
    seed_distance=6,
    boundary_threshold=0.25,
    area_threshold=5000,
):
    """Compute a mitochondria instance segmentation from foreground and boundary predictions.

    Steps: distance transform of the region below the boundary threshold, seeds from
    foreground voxels far enough from boundaries, seeded watershed on the inverted and
    normalized distance, then size filtering and 3D post-processing.

    Args:
        foreground: The foreground prediction.
        boundaries: The boundary prediction (same shape as `foreground`).
        verbose: Whether to print timing information.
        min_size: The minimum object size, passed to `apply_size_filter`.
        block_shape: The block shape for the parallel computations.
        halo: The halo for the parallel distance transform and watershed.
        seed_distance: A foreground voxel becomes a seed only if its distance value exceeds this.
        boundary_threshold: Boundary values below this are treated as non-boundary
            for the distance transform.
        area_threshold: Passed to `_postprocess_seg_3d`.

    Returns:
        The segmentation as a label volume.
    """
    t0 = time.time()
    dist = parallel.distance_transform(
        boundaries < boundary_threshold, halo=halo, verbose=verbose, block_shape=block_shape
    )
    if verbose:
        print("Compute distance transform in", time.time() - t0, "s")

    # Get the segmentation via seeded watershed.
    t0 = time.time()
    seeds = np.logical_and(foreground > 0.5, dist > seed_distance)
    seeds = parallel.label(seeds, block_shape=block_shape, verbose=verbose)
    if verbose:
        print("Compute connected components in", time.time() - t0, "s")

    t0 = time.time()
    # Heightmap: inverted distance normalized to [0, 1]. If the distance map is all zero
    # (e.g. no voxel is below the boundary threshold) the original expression is 0 / 0,
    # which would fill the heightmap with NaN; use a flat heightmap instead.
    dist_max = dist.max()
    if dist_max > 0:
        hmap = (dist_max - dist) / dist_max
    else:
        hmap = np.zeros(dist.shape, dtype=np.result_type(dist.dtype, np.float32))
    # Raise voxels that look like boundary but not like foreground to the maximal height,
    # so they are flooded last by the watershed.
    # NOTE(review): foreground is compared against boundary_threshold here; confirm intended.
    hmap[np.logical_and(boundaries > boundary_threshold, foreground < boundary_threshold)] = (hmap + boundaries).max()
    mask = (foreground + boundaries) > 0.5

    seg = np.zeros_like(seeds)
    seg = parallel.seeded_watershed(
        hmap, seeds, block_shape=block_shape,
        out=seg, mask=mask, verbose=verbose, halo=halo,
    )
    if verbose:
        print("Compute watershed in", time.time() - t0, "s")

    seg = apply_size_filter(seg, min_size, verbose=verbose, block_shape=block_shape)
    seg = _postprocess_seg_3d(seg, area_threshold=area_threshold)
    return seg


def segment_mitochondria(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    min_size: int = 50000,
    verbose: bool = True,
    distance_based_segmentation: bool = False,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
    seed_distance: int = 6,
    ws_block_shape: Tuple[int, ...] = (128, 256, 256),
    ws_halo: Tuple[int, ...] = (48, 48, 48),
    boundary_threshold: float = 0.25,
    area_threshold: int = 5000,
    preprocess: Optional[Callable] = None,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment mitochondria in an input volume.

    Args:
        input_volume: The input volume to segment.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        min_size: The minimum size of a mitochondrion to be considered.
        verbose: Whether to print progress and timing information.
        distance_based_segmentation: Currently unused by this function; it is accepted
            for backward compatibility and has no effect.
        return_predictions: Whether to also return the predictions alongside the segmentation.
            The first two prediction channels are used as foreground and boundaries.
        scale: The scale factor to use for rescaling the input volume before prediction.
            The outputs are rescaled back afterwards.
        mask: An optional mask that is rescaled like the input and passed to the prediction
            to restrict it.
        seed_distance: The distance value a foreground voxel must exceed to become a watershed seed.
        ws_block_shape: The block shape for the parallel distance transform, labeling,
            watershed and size filter.
        ws_halo: The halo for the parallel distance transform and seeded watershed.
        boundary_threshold: The threshold applied to the boundary predictions before
            computing the distance transform.
        area_threshold: The maximum area (in pixels) of holes to be removed or filled in the segmentation.
            This parameter is passed to `skimage.morphology.remove_small_holes`.
        preprocess: An optional preprocessing function that is forwarded to the prediction.

    Returns:
        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
        and the predictions if return_predictions is True.
    """
    if verbose:
        print("Segmenting mitochondria in volume of shape", input_volume.shape)
    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    # Rescale the mask if it was given and run prediction.
    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)
    pred = get_prediction(
        input_volume, model_path=model_path, model=model, tiling=tiling, mask=mask, verbose=verbose,
        preprocess=preprocess
    )

    # Run segmentation and rescale the result if necessary.
    foreground, boundaries = pred[:2]
    seg = _run_segmentation(
        foreground, boundaries, verbose=verbose, min_size=min_size, seed_distance=seed_distance,
        block_shape=ws_block_shape, halo=ws_halo, boundary_threshold=boundary_threshold,
        area_threshold=area_threshold,
    )
    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        pred = scaler.rescale_output(pred, is_segmentation=False)
        return seg, pred
    return seg
def
segment_mitochondria( input_volume: numpy.ndarray, model_path: Optional[str] = None, model: Optional[torch.nn.modules.module.Module] = None, tiling: Optional[Dict[str, Dict[str, int]]] = None, min_size: int = 50000, verbose: bool = True, distance_based_segmentation: bool = False, return_predictions: bool = False, scale: Optional[List[float]] = None, mask: Optional[numpy.ndarray] = None, seed_distance: int = 6, ws_block_shape: Tuple[int, ...] = (128, 256, 256), ws_halo: Tuple[int, ...] = (48, 48, 48), boundary_threshold: float = 0.25, area_threshold: int = 5000, preprocess: Callable = None) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]:
def segment_mitochondria(
    input_volume: np.ndarray,
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    min_size: int = 50000,
    verbose: bool = True,
    distance_based_segmentation: bool = False,
    return_predictions: bool = False,
    scale: Optional[List[float]] = None,
    mask: Optional[np.ndarray] = None,
    seed_distance: int = 6,
    ws_block_shape: Tuple[int, ...] = (128, 256, 256),
    ws_halo: Tuple[int, ...] = (48, 48, 48),
    boundary_threshold: float = 0.25,
    area_threshold: int = 5000,
    preprocess: Optional[Callable] = None,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Segment mitochondria in an input volume.

    Args:
        input_volume: The input volume to segment.
        model_path: The path to the model checkpoint if `model` is not provided.
        model: Pre-loaded model. Either `model_path` or `model` is required.
        tiling: The tiling configuration for the prediction.
        min_size: The minimum size of a mitochondrion to be considered.
        verbose: Whether to print progress and timing information.
        distance_based_segmentation: Currently unused by this function; it is accepted
            for backward compatibility and has no effect.
        return_predictions: Whether to also return the predictions alongside the segmentation.
            The first two prediction channels are used as foreground and boundaries.
        scale: The scale factor to use for rescaling the input volume before prediction.
            The outputs are rescaled back afterwards.
        mask: An optional mask that is rescaled like the input and passed to the prediction
            to restrict it.
        seed_distance: The distance value a foreground voxel must exceed to become a watershed seed.
        ws_block_shape: The block shape for the parallel distance transform, labeling,
            watershed and size filter.
        ws_halo: The halo for the parallel distance transform and seeded watershed.
        boundary_threshold: The threshold applied to the boundary predictions before
            computing the distance transform.
        area_threshold: The maximum area (in pixels) of holes to be removed or filled in the segmentation.
            This parameter is passed to `skimage.morphology.remove_small_holes`.
        preprocess: An optional preprocessing function that is forwarded to the prediction.

    Returns:
        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
        and the predictions if return_predictions is True.
    """
    if verbose:
        print("Segmenting mitochondria in volume of shape", input_volume.shape)
    # Create the scaler to handle prediction with a different scaling factor.
    scaler = _Scaler(scale, verbose)
    input_volume = scaler.scale_input(input_volume)

    # Rescale the mask if it was given and run prediction.
    if mask is not None:
        mask = scaler.scale_input(mask, is_segmentation=True)
    pred = get_prediction(
        input_volume, model_path=model_path, model=model, tiling=tiling, mask=mask, verbose=verbose,
        preprocess=preprocess
    )

    # Run segmentation and rescale the result if necessary.
    foreground, boundaries = pred[:2]
    seg = _run_segmentation(
        foreground, boundaries, verbose=verbose, min_size=min_size, seed_distance=seed_distance,
        block_shape=ws_block_shape, halo=ws_halo, boundary_threshold=boundary_threshold,
        area_threshold=area_threshold,
    )
    seg = scaler.rescale_output(seg, is_segmentation=True)

    if return_predictions:
        pred = scaler.rescale_output(pred, is_segmentation=False)
        return seg, pred
    return seg
Segment mitochondria in an input volume.
Arguments:
- input_volume: The input volume to segment.
- model_path: The path to the model checkpoint if `model` is not provided.
- model: Pre-loaded model. Either `model_path` or `model` is required.
- tiling: The tiling configuration for the prediction.
- min_size: The minimum size of a mitochondrion to be considered.
- verbose: Whether to print timing information.
- distance_based_segmentation: Whether to use distance-based segmentation. Note: this flag is currently not used by the implementation.
- return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
- scale: The scale factor to use for rescaling the input volume before prediction.
- mask: An optional mask that is used to restrict the segmentation.
- seed_distance: The distance threshold for the seeded watershed.
- ws_block_shape: The block shape for the seeded watershed.
- ws_halo: The halo for the seeded watershed.
- boundary_threshold: The threshold applied to the boundary predictions before computing the distance transform.
- area_threshold: The maximum area (in pixels) of holes to be removed or filled in the segmentation.
This parameter is passed to `skimage.morphology.remove_small_holes`.
- preprocess: An optional preprocessing function that is forwarded to the prediction.
Returns:
The segmentation mask as a numpy array, or a tuple containing the segmentation mask and the predictions if return_predictions is True.