synapse_net.inference.scalable_segmentation

import os
import tempfile
from typing import Dict, List, Optional

import elf.parallel as parallel
import numpy as np
import torch

from elf.io import open_file
from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper
from elf.wrapper.base import MultiTransformationWrapper
from elf.wrapper.resized_volume import ResizedVolume
from numpy.typing import ArrayLike
from synapse_net.inference.util import get_prediction


class SelectChannel(SimpleTransformationWrapper):
    """Wrapper to select a channel from an array-like dataset object.

    Args:
        volume: The array-like input dataset.
        channel: The channel that will be selected.
    """
    def __init__(self, volume: ArrayLike, channel: int):
        self.channel = channel
        super().__init__(volume, lambda x: x[self.channel], with_channels=True)

    @property
    def shape(self):
        return self._volume.shape[1:]

    @property
    def chunks(self):
        return self._volume.chunks[1:]

    @property
    def ndim(self):
        return self._volume.ndim - 1
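For illustration, a minimal usage sketch of SelectChannel; the file path and dataset name are assumptions for this example, not part of the module:

from elf.io import open_file

# Lazily select the foreground channel (channel 0) of a 4D (channel, z, y, x)
# prediction dataset without loading it into memory.
f = open_file("prediction.n5", mode="r")  # hypothetical prediction file
pred = f["pred"]  # shape (2, Z, Y, X)
foreground = SelectChannel(pred, 0)
print(foreground.shape, foreground.ndim)  # (Z, Y, X) and 3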
def _run_segmentation(pred, output, seeds, chunks, seed_threshold, min_size, verbose, original_shape):
    # Create wrappers for selecting the foreground and the boundary channel.
    foreground = SelectChannel(pred, 0)
    boundaries = SelectChannel(pred, 1)

    # Create a wrapper that subtracts the boundary channel from the foreground channel
    # and thresholds the result. Then compute the seeds as its connected components.
    seed_input = ThresholdWrapper(
        MultiTransformationWrapper(np.subtract, foreground, boundaries), seed_threshold
    )
    parallel.label(seed_input, seeds, verbose=verbose, block_shape=chunks)

    # Run a watershed to extend the seeds back to the boundaries,
    # restricted to the predicted foreground mask.
    mask = ThresholdWrapper(foreground, 0.5)

    # Resize the intermediate results back to the original shape
    # if the input was rescaled for prediction.
    if original_shape is not None:
        boundaries = ResizedVolume(boundaries, original_shape, order=1)
        seeds = ResizedVolume(seeds, original_shape, order=0)
        mask = ResizedVolume(mask, original_shape, order=0)

    parallel.seeded_watershed(
        boundaries, seeds=seeds, out=output, verbose=verbose, mask=mask, block_shape=chunks, halo=3 * (16,)
    )

    # Apply the size filter to remove objects smaller than min_size.
    if min_size > 0:
        parallel.size_filter(output, output, min_size=min_size, verbose=verbose, block_shape=chunks)
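To make the seed computation above concrete, here is an eager numpy equivalent on a small in-memory array; skimage.measure.label stands in for parallel.label in this sketch and is not used by the module itself:

import numpy as np
from skimage.measure import label

pred = np.random.rand(2, 32, 32, 32).astype("float32")  # (foreground, boundaries)
seed_mask = (pred[0] - pred[1]) > 0.5  # subtract boundaries, then threshold
seeds = label(seed_mask)               # connected components, as in parallel.label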
def scalable_segmentation(
    input_: ArrayLike,
    output: ArrayLike,
    model: torch.nn.Module,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    scale: Optional[List[float]] = None,
    seed_threshold: float = 0.5,
    min_size: int = 500,
    prediction: Optional[ArrayLike] = None,
    verbose: bool = True,
    mask: Optional[ArrayLike] = None,
    devices: Optional[List[str]] = None,
) -> None:
    """Run segmentation based on a prediction with a foreground and a boundary channel.

    This function first subtracts the boundary prediction from the foreground prediction,
    then applies a threshold, connected components, and a watershed to fit the components
    back to the foreground. All processing steps are implemented in a scalable fashion,
    so that the function scales to large input volumes.

    Args:
        input_: The input data.
        output: The array for storing the output segmentation.
            Can be a numpy array, a zarr array, or similar.
        model: The model for prediction.
        tiling: The tiling configuration for the prediction.
        scale: The scale factor to use for rescaling the input volume before prediction.
        seed_threshold: The threshold applied before computing connected components.
        min_size: The minimum size of a segmented object (e.g. a vesicle);
            smaller objects are filtered out.
        prediction: The array for storing the prediction.
            If given, this can be a numpy array, a zarr array, or similar.
            If not given, the prediction is stored in a temporary n5 array.
        verbose: Whether to print timing information.
        mask: A mask to restrict the segmentation. Not yet supported.
        devices: The devices for running prediction. If not given, the GPU
            will be used if available, otherwise the CPU.
    """
    if mask is not None:
        raise NotImplementedError
    # The model must predict a foreground and a boundary channel.
    assert model.out_channels == 2

    # Create a temporary directory for storing the predictions.
    chunks = (128,) * 3
    with tempfile.TemporaryDirectory() as tmp_dir:

        # Rescale the input if a scale factor different from 1 is given.
        if scale is None or np.allclose(scale, 1.0, atol=1e-3):
            original_shape = None
        else:
            original_shape = input_.shape
            new_shape = tuple(int(sh * sc) for sh, sc in zip(input_.shape, scale))
            input_ = ResizedVolume(input_, shape=new_shape, order=1)

        if prediction is None:
            # Create the dataset for storing the prediction.
            tmp_pred = os.path.join(tmp_dir, "prediction.n5")
            f = open_file(tmp_pred, mode="a")
            pred_shape = (2,) + input_.shape
            pred_chunks = (1,) + chunks
            prediction = f.create_dataset("pred", shape=pred_shape, dtype="float32", chunks=pred_chunks)
        else:
            assert prediction.shape[0] == 2
            assert prediction.shape[1:] == input_.shape

        # Create temporary storage for the seeds.
        tmp_seeds = os.path.join(tmp_dir, "seeds.n5")
        f = open_file(tmp_seeds, mode="a")
        seeds = f.create_dataset("seeds", shape=input_.shape, dtype="uint64", chunks=chunks)

        # Run prediction and segmentation.
        get_prediction(input_, prediction=prediction, tiling=tiling, model=model, verbose=verbose, devices=devices)
        _run_segmentation(prediction, output, seeds, chunks, seed_threshold, min_size, verbose, original_shape)
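For illustration, a hypothetical end-to-end usage sketch; the checkpoint path, input file, dataset names, and the zarr output layout are assumptions for this example, not part of the module:

import torch
import zarr
from elf.io import open_file
from synapse_net.inference.scalable_segmentation import scalable_segmentation

# Load a trained model with two output channels (foreground and boundaries).
model = torch.load("model.pt", weights_only=False)  # hypothetical checkpoint

# Open the input volume lazily and create a chunked output dataset.
f_in = open_file("tomogram.n5", mode="r")  # hypothetical input file
input_ = f_in["raw"]
f_out = zarr.open("segmentation.zarr", mode="a")
output = f_out.create_dataset("vesicles", shape=input_.shape, dtype="uint64", chunks=(128,) * 3)

scalable_segmentation(input_, output, model)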