synapse_net.inference.mitochondria

  1import time
  2from typing import Dict, List, Optional, Tuple, Union
  3
  4import elf.parallel as parallel
  5import numpy as np
  6import torch
  7
  8from synapse_net.inference.util import apply_size_filter, get_prediction, _Scaler, _postprocess_seg_3d
  9
 10
 11def _run_segmentation(
 12    foreground, boundaries, verbose, min_size,
 13    # blocking shapes for parallel computation
 14    block_shape=(128, 256, 256),
 15    halo=(48, 48, 48),
 16    seed_distance=6,
 17    boundary_threshold=0.25,
 18    area_threshold=5000,
 19):
 20    t0 = time.time()
 21    dist = parallel.distance_transform(
 22        boundaries < boundary_threshold, halo=halo, verbose=verbose, block_shape=block_shape
 23    )
 24    if verbose:
 25        print("Compute distance transform in", time.time() - t0, "s")
 26
 27    # Get the segmentation via seeded watershed.
 28    t0 = time.time()
 29    seeds = np.logical_and(foreground > 0.5, dist > seed_distance)
 30    seeds = parallel.label(seeds, block_shape=block_shape, verbose=verbose)
 31    if verbose:
 32        print("Compute connected components in", time.time() - t0, "s")
 33
 34    t0 = time.time()
 35    hmap = (dist.max() - dist) / dist.max()
 36    hmap[boundaries > boundary_threshold] = (hmap + boundaries).max()
 37    mask = (foreground + boundaries) > 0.5
 38
 39    seg = np.zeros_like(seeds)
 40    seg = parallel.seeded_watershed(
 41        hmap, seeds, block_shape=block_shape,
 42        out=seg, mask=mask, verbose=verbose, halo=halo,
 43    )
 44    if verbose:
 45        print("Compute watershed in", time.time() - t0, "s")
 46
 47    seg = apply_size_filter(seg, min_size, verbose=verbose, block_shape=block_shape)
 48    seg = _postprocess_seg_3d(seg, area_threshold=area_threshold)
 49    return seg
 50
 51
 52def segment_mitochondria(
 53    input_volume: np.ndarray,
 54    model_path: Optional[str] = None,
 55    model: Optional[torch.nn.Module] = None,
 56    tiling: Optional[Dict[str, Dict[str, int]]] = None,
 57    min_size: int = 50000,
 58    verbose: bool = True,
 59    distance_based_segmentation: bool = False,
 60    return_predictions: bool = False,
 61    scale: Optional[List[float]] = None,
 62    mask: Optional[np.ndarray] = None,
 63    seed_distance: int = 6,
 64    ws_block_shape: Tuple[int, ...] = (128, 256, 256),
 65    ws_halo: Tuple[int, ...] = (48, 48, 48),
 66    boundary_threshold: float = 0.25,
 67    area_threshold: int = 5000,
 68) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
 69    """Segment mitochondria in an input volume.
 70
 71    Args:
 72        input_volume: The input volume to segment.
 73        model_path: The path to the model checkpoint if `model` is not provided.
 74        model: Pre-loaded model. Either `model_path` or `model` is required.
 75        tiling: The tiling configuration for the prediction.
 76        min_size: The minimum size of a mitochondria to be considered.
 77        verbose: Whether to print timing information.
 78        distance_based_segmentation: Whether to use distance-based segmentation.
 79        return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
 80        scale: The scale factor to use for rescaling the input volume before prediction.
 81        mask: An optional mask that is used to restrict the segmentation.
 82        seed_distance: The distance threshold for the seeded watershed.
 83        ws_block_shape: The block shape for the seeded watershed.
 84        ws_halo: The halo for the seeded watershed.
 85        boundary_threshold: The boundary threshold distance calculation.
 86        area_threshold: The maximum area (in pixels) of holes to be removed or filled in the segmentation.
 87            This parameter is passed to `skimage.morphology.remove_small_holes`.
 88
 89    Returns:
 90        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
 91        and the predictions if return_predictions is True.
 92    """
 93    if verbose:
 94        print("Segmenting mitochondria in volume of shape", input_volume.shape)
 95    # Create the scaler to handle prediction with a different scaling factor.
 96    scaler = _Scaler(scale, verbose)
 97    input_volume = scaler.scale_input(input_volume)
 98
 99    # Rescale the mask if it was given and run prediction.
100    if mask is not None:
101        mask = scaler.scale_input(mask, is_segmentation=True)
102    pred = get_prediction(input_volume, model_path=model_path, model=model, tiling=tiling, mask=mask, verbose=verbose)
103
104    # Run segmentation and rescale the result if necessary.
105    foreground, boundaries = pred[:2]
106    seg = _run_segmentation(foreground, boundaries, verbose=verbose, min_size=min_size, seed_distance=seed_distance,
107                            block_shape=ws_block_shape, halo=ws_halo, boundary_threshold=boundary_threshold,
108                            area_threshold=area_threshold)
109    seg = scaler.rescale_output(seg, is_segmentation=True)
110
111    if return_predictions:
112        pred = scaler.rescale_output(pred, is_segmentation=False)
113        return seg, pred
114    return seg
def segment_mitochondria( input_volume: numpy.ndarray, model_path: Optional[str] = None, model: Optional[torch.nn.modules.module.Module] = None, tiling: Optional[Dict[str, Dict[str, int]]] = None, min_size: int = 50000, verbose: bool = True, distance_based_segmentation: bool = False, return_predictions: bool = False, scale: Optional[List[float]] = None, mask: Optional[numpy.ndarray] = None, seed_distance: int = 6, ws_block_shape: Tuple[int, ...] = (128, 256, 256), ws_halo: Tuple[int, ...] = (48, 48, 48), boundary_threshold: float = 0.25, area_threshold: int = 5000) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]:
 53def segment_mitochondria(
 54    input_volume: np.ndarray,
 55    model_path: Optional[str] = None,
 56    model: Optional[torch.nn.Module] = None,
 57    tiling: Optional[Dict[str, Dict[str, int]]] = None,
 58    min_size: int = 50000,
 59    verbose: bool = True,
 60    distance_based_segmentation: bool = False,
 61    return_predictions: bool = False,
 62    scale: Optional[List[float]] = None,
 63    mask: Optional[np.ndarray] = None,
 64    seed_distance: int = 6,
 65    ws_block_shape: Tuple[int, ...] = (128, 256, 256),
 66    ws_halo: Tuple[int, ...] = (48, 48, 48),
 67    boundary_threshold: float = 0.25,
 68    area_threshold: int = 5000,
 69) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
 70    """Segment mitochondria in an input volume.
 71
 72    Args:
 73        input_volume: The input volume to segment.
 74        model_path: The path to the model checkpoint if `model` is not provided.
 75        model: Pre-loaded model. Either `model_path` or `model` is required.
 76        tiling: The tiling configuration for the prediction.
 77        min_size: The minimum size of a mitochondria to be considered.
 78        verbose: Whether to print timing information.
 79        distance_based_segmentation: Whether to use distance-based segmentation.
 80        return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
 81        scale: The scale factor to use for rescaling the input volume before prediction.
 82        mask: An optional mask that is used to restrict the segmentation.
 83        seed_distance: The distance threshold for the seeded watershed.
 84        ws_block_shape: The block shape for the seeded watershed.
 85        ws_halo: The halo for the seeded watershed.
 86        boundary_threshold: The boundary threshold distance calculation.
 87        area_threshold: The maximum area (in pixels) of holes to be removed or filled in the segmentation.
 88            This parameter is passed to `skimage.morphology.remove_small_holes`.
 89
 90    Returns:
 91        The segmentation mask as a numpy array, or a tuple containing the segmentation mask
 92        and the predictions if return_predictions is True.
 93    """
 94    if verbose:
 95        print("Segmenting mitochondria in volume of shape", input_volume.shape)
 96    # Create the scaler to handle prediction with a different scaling factor.
 97    scaler = _Scaler(scale, verbose)
 98    input_volume = scaler.scale_input(input_volume)
 99
100    # Rescale the mask if it was given and run prediction.
101    if mask is not None:
102        mask = scaler.scale_input(mask, is_segmentation=True)
103    pred = get_prediction(input_volume, model_path=model_path, model=model, tiling=tiling, mask=mask, verbose=verbose)
104
105    # Run segmentation and rescale the result if necessary.
106    foreground, boundaries = pred[:2]
107    seg = _run_segmentation(foreground, boundaries, verbose=verbose, min_size=min_size, seed_distance=seed_distance,
108                            block_shape=ws_block_shape, halo=ws_halo, boundary_threshold=boundary_threshold,
109                            area_threshold=area_threshold)
110    seg = scaler.rescale_output(seg, is_segmentation=True)
111
112    if return_predictions:
113        pred = scaler.rescale_output(pred, is_segmentation=False)
114        return seg, pred
115    return seg

Segment mitochondria in an input volume.

Arguments:
  • input_volume: The input volume to segment.
  • model_path: The path to the model checkpoint if model is not provided.
  • model: Pre-loaded model. Either model_path or model is required.
  • tiling: The tiling configuration for the prediction.
  • min_size: The minimum size of a mitochondrion to be considered.
  • verbose: Whether to print timing information.
  • distance_based_segmentation: Whether to use distance-based segmentation.
  • return_predictions: Whether to return the predictions (foreground, boundaries) alongside the segmentation.
  • scale: The scale factor to use for rescaling the input volume before prediction.
  • mask: An optional mask that is used to restrict the segmentation.
  • seed_distance: The distance threshold for the seeded watershed.
  • ws_block_shape: The block shape for the seeded watershed.
  • ws_halo: The halo for the seeded watershed.
  • boundary_threshold: The threshold applied to the boundary prediction for the distance calculation.
  • area_threshold: The maximum area (in pixels) of holes to be removed or filled in the segmentation. This parameter is passed to skimage.morphology.remove_small_holes.
Returns:

The segmentation mask as a numpy array, or a tuple containing the segmentation mask and the predictions if return_predictions is True.