synapse_net.inference.scalable_segmentation
import os
import tempfile
from typing import Dict, List, Optional

import elf.parallel as parallel
import numpy as np
import torch

from elf.io import open_file
from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper
from elf.wrapper.base import MultiTransformationWrapper
from elf.wrapper.resized_volume import ResizedVolume
from numpy.typing import ArrayLike
from synapse_net.inference.util import get_prediction


class SelectChannel(SimpleTransformationWrapper):
    """Wrapper to select a channel from an array-like dataset object.

    The channel is selected by indexing the data along its first axis, and the
    reported `shape`, `chunks` and `ndim` exclude that axis.

    Args:
        volume: The array-like input dataset.
        channel: The channel that will be selected.
    """
    def __init__(self, volume: np.typing.ArrayLike, channel: int):
        self.channel = channel
        super().__init__(volume, lambda x: x[self.channel], with_channels=True)

    @property
    def shape(self):
        # Drop the leading channel axis.
        return self._volume.shape[1:]

    @property
    def chunks(self):
        # Drop the leading channel axis.
        return self._volume.chunks[1:]

    @property
    def ndim(self):
        return self._volume.ndim - 1


def _run_segmentation(pred, output, seeds, chunks, seed_threshold, min_size, verbose, original_shape):
    """Compute the segmentation from a two-channel prediction with block-wise operations.

    Args:
        pred: The prediction. Channel 0 is used as foreground and channel 1 as boundaries.
        output: The array-like object the segmentation is written to.
        seeds: The array-like object for storing the connected components used as watershed seeds.
        chunks: The block shape used for all block-wise operations.
        seed_threshold: The threshold applied to foreground minus boundaries to obtain the seeds.
        min_size: The minimum object size. The size filter is skipped if this is not positive.
        verbose: Passed on to the `elf.parallel` functions.
        original_shape: The shape to which boundaries, seeds and mask are resized before the watershed.
            If None, no resizing is done.
    """
    # Create wrappers for selecting the foreground and the boundary channel.
    foreground = SelectChannel(pred, 0)
    boundaries = SelectChannel(pred, 1)

    # Create wrappers for subtracting the boundaries from the foreground and thresholding the result.
    # And then compute the seeds based on this.
    seed_input = ThresholdWrapper(
        MultiTransformationWrapper(np.subtract, foreground, boundaries), seed_threshold
    )
    parallel.label(seed_input, seeds, verbose=verbose, block_shape=chunks)

    # Run watershed to extend back from the seeds to the boundaries.
    mask = ThresholdWrapper(foreground, 0.5)

    # Resize if necessary.
    if original_shape is not None:
        boundaries = ResizedVolume(boundaries, original_shape, order=1)
        seeds = ResizedVolume(seeds, original_shape, order=0)
        mask = ResizedVolume(mask, original_shape, order=0)

    parallel.seeded_watershed(
        boundaries, seeds=seeds, out=output, verbose=verbose, mask=mask, block_shape=chunks, halo=3 * (16,)
    )

    # Run the size filter.
    if min_size > 0:
        parallel.size_filter(output, output, min_size=min_size, verbose=verbose, block_shape=chunks)


def scalable_segmentation(
    input_: ArrayLike,
    output: ArrayLike,
    model: torch.nn.Module,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    scale: Optional[List[float]] = None,
    seed_threshold: float = 0.5,
    min_size: int = 500,
    prediction: Optional[ArrayLike] = None,
    verbose: bool = True,
    mask: Optional[ArrayLike] = None,
    devices: Optional[List[str]] = None,
) -> None:
    """Run segmentation based on a prediction with foreground and boundary channel.

    This function first subtracts the boundary prediction from the foreground prediction,
    then applies a threshold, connected components, and a watershed to fit the components
    back to the foreground. All processing steps are implemented in a scalable fashion,
    so that the function runs for large input volumes.

    Args:
        input_: The input data.
        output: The array for storing the output segmentation.
            Can be a numpy array, a zarr array, or similar.
        model: The model for prediction. It must have two output channels.
        tiling: The tiling configuration for the prediction.
        scale: The scale factor to use for rescaling the input volume before prediction.
            Must contain one factor per axis of the input.
        seed_threshold: The threshold applied before computing connected components.
        min_size: The minimum size of a vesicle to be considered.
        prediction: The array for storing the prediction.
            If given, this can be a numpy array, a zarr array, or similar.
            If not given will be stored in a temporary n5 array.
        verbose: Whether to print timing information.
        mask: Currently not supported. Must be None, otherwise a NotImplementedError is raised.
        devices: The devices for running prediction. If not given will use the GPU
            if available, otherwise the CPU.

    Raises:
        NotImplementedError: If a mask is passed.
        ValueError: If the number of scale factors does not match the number of input axes.
    """
    if mask is not None:
        raise NotImplementedError("Passing a mask to scalable_segmentation is not supported yet.")
    assert model.out_channels == 2, (
        f"Expected a model with 2 output channels (foreground and boundaries), got {model.out_channels}."
    )

    # Create a temporary directory for storing the seeds and, if no array is passed, the predictions.
    chunks = (128,) * 3
    with tempfile.TemporaryDirectory() as tmp_dir:

        if scale is None or np.allclose(scale, 1.0, atol=1e-3):
            original_shape = None
        else:
            original_shape = input_.shape
            # zip would silently drop surplus entries, so reject a mismatch explicitly.
            if len(scale) != len(original_shape):
                raise ValueError(
                    f"Expected one scale factor per input axis, got {len(scale)} factors "
                    f"for an input with {len(original_shape)} axes."
                )
            new_shape = tuple(int(sh * sc) for sh, sc in zip(original_shape, scale))
            input_ = ResizedVolume(input_, shape=new_shape, order=1)

        if prediction is None:
            # Create the dataset for storing the prediction.
            tmp_pred = os.path.join(tmp_dir, "prediction.n5")
            f = open_file(tmp_pred, mode="a")
            pred_shape = (2,) + input_.shape
            pred_chunks = (1,) + chunks
            prediction = f.create_dataset("pred", shape=pred_shape, dtype="float32", chunks=pred_chunks)
        else:
            assert prediction.shape[0] == 2, (
                f"Expected a prediction array with 2 channels, got {prediction.shape[0]}."
            )
            assert prediction.shape[1:] == input_.shape, (
                f"The spatial shape of the prediction array {prediction.shape[1:]} "
                f"does not match the (rescaled) input shape {input_.shape}."
            )

        # Create temporary storage for the seeds.
        tmp_seeds = os.path.join(tmp_dir, "seeds.n5")
        f = open_file(tmp_seeds, mode="a")
        seeds = f.create_dataset("seeds", shape=input_.shape, dtype="uint64", chunks=chunks)

        # Run prediction and segmentation.
        get_prediction(input_, prediction=prediction, tiling=tiling, model=model, verbose=verbose, devices=devices)
        _run_segmentation(prediction, output, seeds, chunks, seed_threshold, min_size, verbose, original_shape)
class
SelectChannel(elf.wrapper.base.SimpleTransformationWrapper):
class SelectChannel(SimpleTransformationWrapper):
    """View on a single channel of a multi-channel array-like dataset.

    The channel is selected by indexing the data along its first axis, and the
    reported `shape`, `chunks` and `ndim` exclude that axis.

    Args:
        volume: The array-like input dataset.
        channel: The channel that will be selected.
    """
    def __init__(self, volume: np.typing.ArrayLike, channel: int):
        self.channel = channel

        def _pick_channel(data):
            # Look up the channel on every call instead of capturing its current value.
            return data[self.channel]

        super().__init__(volume, _pick_channel, with_channels=True)

    @property
    def shape(self):
        # Drop the leading channel axis.
        full_shape = self._volume.shape
        return full_shape[1:]

    @property
    def chunks(self):
        # Drop the leading channel axis.
        full_chunks = self._volume.chunks
        return full_chunks[1:]

    @property
    def ndim(self):
        full_ndim = self._volume.ndim
        return full_ndim - 1
Wrapper to select a channel from an array-like dataset object.
Arguments:
- volume: The array-like input dataset.
- channel: The channel that will be selected.
SelectChannel( volume: Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], bool, int, float, complex, str, bytes, numpy._typing._nested_sequence._NestedSequence[bool | int | float | complex | str | bytes]], channel: int)
def
scalable_segmentation( input_: Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], bool, int, float, complex, str, bytes, numpy._typing._nested_sequence._NestedSequence[bool | int | float | complex | str | bytes]], output: Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], bool, int, float, complex, str, bytes, numpy._typing._nested_sequence._NestedSequence[bool | int | float | complex | str | bytes]], model: torch.nn.modules.module.Module, tiling: Optional[Dict[str, Dict[str, int]]] = None, scale: Optional[List[float]] = None, seed_threshold: float = 0.5, min_size: int = 500, prediction: Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], bool, int, float, complex, str, bytes, numpy._typing._nested_sequence._NestedSequence[bool | int | float | complex | str | bytes], NoneType] = None, verbose: bool = True, mask: Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], bool, int, float, complex, str, bytes, numpy._typing._nested_sequence._NestedSequence[bool | int | float | complex | str | bytes], NoneType] = None, devices: Optional[List[str]] = None) -> None:
def scalable_segmentation(
    input_: ArrayLike,
    output: ArrayLike,
    model: torch.nn.Module,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    scale: Optional[List[float]] = None,
    seed_threshold: float = 0.5,
    min_size: int = 500,
    prediction: Optional[ArrayLike] = None,
    verbose: bool = True,
    mask: Optional[ArrayLike] = None,
    devices: Optional[List[str]] = None,
) -> None:
    """Run segmentation based on a prediction with foreground and boundary channel.

    This function first subtracts the boundary prediction from the foreground prediction,
    then applies a threshold, connected components, and a watershed to fit the components
    back to the foreground. All processing steps are implemented in a scalable fashion,
    so that the function runs for large input volumes.

    Args:
        input_: The input data.
        output: The array for storing the output segmentation.
            Can be a numpy array, a zarr array, or similar.
        model: The model for prediction. It must have two output channels.
        tiling: The tiling configuration for the prediction.
        scale: The scale factor to use for rescaling the input volume before prediction.
            Must contain one factor per axis of the input.
        seed_threshold: The threshold applied before computing connected components.
        min_size: The minimum size of a vesicle to be considered.
        prediction: The array for storing the prediction.
            If given, this can be a numpy array, a zarr array, or similar.
            If not given will be stored in a temporary n5 array.
        verbose: Whether to print timing information.
        mask: Currently not supported. Must be None, otherwise a NotImplementedError is raised.
        devices: The devices for running prediction. If not given will use the GPU
            if available, otherwise the CPU.

    Raises:
        NotImplementedError: If a mask is passed.
        ValueError: If the number of scale factors does not match the number of input axes.
    """
    if mask is not None:
        raise NotImplementedError("Passing a mask to scalable_segmentation is not supported yet.")
    assert model.out_channels == 2, (
        f"Expected a model with 2 output channels (foreground and boundaries), got {model.out_channels}."
    )

    # Create a temporary directory for storing the seeds and, if no array is passed, the predictions.
    chunks = (128,) * 3
    with tempfile.TemporaryDirectory() as tmp_dir:

        if scale is None or np.allclose(scale, 1.0, atol=1e-3):
            original_shape = None
        else:
            original_shape = input_.shape
            # zip would silently drop surplus entries, so reject a mismatch explicitly.
            if len(scale) != len(original_shape):
                raise ValueError(
                    f"Expected one scale factor per input axis, got {len(scale)} factors "
                    f"for an input with {len(original_shape)} axes."
                )
            new_shape = tuple(int(sh * sc) for sh, sc in zip(original_shape, scale))
            input_ = ResizedVolume(input_, shape=new_shape, order=1)

        if prediction is None:
            # Create the dataset for storing the prediction.
            tmp_pred = os.path.join(tmp_dir, "prediction.n5")
            f = open_file(tmp_pred, mode="a")
            pred_shape = (2,) + input_.shape
            pred_chunks = (1,) + chunks
            prediction = f.create_dataset("pred", shape=pred_shape, dtype="float32", chunks=pred_chunks)
        else:
            assert prediction.shape[0] == 2, (
                f"Expected a prediction array with 2 channels, got {prediction.shape[0]}."
            )
            assert prediction.shape[1:] == input_.shape, (
                f"The spatial shape of the prediction array {prediction.shape[1:]} "
                f"does not match the (rescaled) input shape {input_.shape}."
            )

        # Create temporary storage for the seeds.
        tmp_seeds = os.path.join(tmp_dir, "seeds.n5")
        f = open_file(tmp_seeds, mode="a")
        seeds = f.create_dataset("seeds", shape=input_.shape, dtype="uint64", chunks=chunks)

        # Run prediction and segmentation.
        get_prediction(input_, prediction=prediction, tiling=tiling, model=model, verbose=verbose, devices=devices)
        _run_segmentation(prediction, output, seeds, chunks, seed_threshold, min_size, verbose, original_shape)
Run segmentation based on a prediction with foreground and boundary channel.
This function first subtracts the boundary prediction from the foreground prediction, then applies a threshold, connected components, and a watershed to fit the components back to the foreground. All processing steps are implemented in a scalable fashion, so that the function runs for large input volumes.
Arguments:
- input_: The input data.
- output: The array for storing the output segmentation. Can be a numpy array, a zarr array, or similar.
- model: The model for prediction.
- tiling: The tiling configuration for the prediction.
- scale: The scale factor to use for rescaling the input volume before prediction.
- seed_threshold: The threshold applied before computing connected components.
- min_size: The minimum size of a vesicle to be considered.
- prediction: The array for storing the prediction. If given, this can be a numpy array, a zarr array, or similar. If not given, it will be stored in a temporary n5 array.
- verbose: Whether to print timing information.
- mask: Currently not supported; must be left as None, otherwise a NotImplementedError is raised.
- devices: The devices for running prediction. If not given will use the GPU if available, otherwise the CPU.