synapse_net.inference.util

  1import os
  2import time
  3import warnings
  4from glob import glob
  5from typing import Dict, List, Optional, Tuple, Union
  6
  7# # Suppress annoying import warnings.
  8# with warnings.catch_warnings():
  9#     warnings.simplefilter("ignore")
 10#     import bioimageio.core
 11
 12import imageio.v3 as imageio
 13import elf.parallel as parallel
 14import mrcfile
 15import numpy as np
 16import torch
 17import torch_em
 18# import xarray
 19
 20from elf.io import open_file
 21from numpy.typing import ArrayLike
 22from scipy.ndimage import binary_closing
 23from skimage.measure import regionprops
 24from skimage.morphology import remove_small_holes
 25from skimage.transform import rescale, resize
 26from torch_em.util.prediction import predict_with_halo
 27from tqdm import tqdm
 28
 29
 30#
 31# Utils for prediction.
 32#
 33
 34
 35class _Scaler:
 36    def __init__(self, scale, verbose):
 37        self.verbose = verbose
 38        self._original_shape = None
 39
 40        if scale is None:
 41            self.scale = None
 42            return
 43
 44        # Convert scale to a NumPy array (ensures consistency)
 45        scale = np.atleast_1d(scale).astype(np.float64)
 46
 47        # Validate scale values
 48        if not np.issubdtype(scale.dtype, np.number):
 49            raise TypeError(f"Scale contains non-numeric values: {scale}")
 50
 51        # Check if scaling is effectively identity (1.0 in all dimensions)
 52        if np.allclose(scale, 1.0, atol=1e-3):
 53            self.scale = None
 54        else:
 55            self.scale = scale
 56
 57    def scale_input(self, input_volume, is_segmentation=False):
 58        t0 = time.time()
 59        if self.scale is None:
 60            return input_volume
 61
 62        if self._original_shape is None:
 63            self._original_shape = input_volume.shape
 64        elif self._original_shape != input_volume.shape:
 65            raise RuntimeError(
 66                "Scaler was called with different input shapes. "
 67                "This is not supported, please create a new instance of the class for it."
 68            )
 69
 70        if is_segmentation:
 71            input_volume = rescale(
 72                input_volume, self.scale, preserve_range=True, order=0, anti_aliasing=False,
 73            ).astype(input_volume.dtype)
 74        else:
 75            input_volume = rescale(input_volume, self.scale, preserve_range=True).astype(input_volume.dtype)
 76
 77        if self.verbose:
 78            print("Rescaled volume from", self._original_shape, "to", input_volume.shape, "in", time.time() - t0, "s")
 79        return input_volume
 80
 81    def rescale_output(self, output, is_segmentation):
 82        t0 = time.time()
 83        if self.scale is None:
 84            return output
 85
 86        assert self._original_shape is not None
 87        out_shape = self._original_shape
 88        if output.ndim > len(out_shape):
 89            assert output.ndim == len(out_shape) + 1
 90            out_shape = (output.shape[0],) + out_shape
 91
 92        if is_segmentation:
 93            output = resize(output, out_shape, preserve_range=True, order=0, anti_aliasing=False).astype(output.dtype)
 94        else:
 95            output = resize(output, out_shape, preserve_range=True).astype(output.dtype)
 96
 97        if self.verbose:
 98            print("Resized prediction back to original shape", output.shape, "in", time.time() - t0, "s")
 99
100        return output
101
102
103def _preprocess(input_volume, with_channels, channels_to_standardize):
104    # We standardize the data for the whole volume beforehand.
105    # If we have channels then the standardization is done independently per channel.
106    if with_channels:
107        input_volume = input_volume.astype(np.float32, copy=False)
108        # TODO Check that this is the correct axis.
109        if channels_to_standardize is None:  # assume all channels
110            channels_to_standardize = range(input_volume.shape[0])
111        for ch in channels_to_standardize:
112            input_volume[ch] = torch_em.transform.raw.standardize(input_volume[ch])
113    else:
114        input_volume = torch_em.transform.raw.standardize(input_volume)
115    return input_volume
116
117
118def get_prediction(
119    input_volume: ArrayLike,  # [z, y, x]
120    tiling: Optional[Dict[str, Dict[str, int]]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
121    model_path: Optional[str] = None,
122    model: Optional[torch.nn.Module] = None,
123    verbose: bool = True,
124    with_channels: bool = False,
125    channels_to_standardize: Optional[List[int]] = None,
126    mask: Optional[ArrayLike] = None,
127    prediction: Optional[ArrayLike] = None,
128    devices: Optional[List[str]] = None,
129) -> ArrayLike:
130    """Run prediction on a given volume.
131
132    This function will automatically choose the correct prediction implementation,
133    depending on the model type.
134
135    Args:
136        input_volume: The input volume to predict on.
137        model_path: The path to the model checkpoint if 'model' is not provided.
138        model: Pre-loaded model. Either model_path or model is required.
139        tiling: The tiling configuration for the prediction.
140        verbose: Whether to print timing information.
141        with_channels: Whether to predict with channels.
142        channels_to_standardize: List of channels to standardize. If None, all channels are standardized.
143        mask: Optional binary mask. If given, the prediction will only be run in
144            the foreground region of the mask.
145        prediction: An array-like object for writing the prediction.
146            If not given, the prediction will be computed in memory.
147        devices: The devices for running prediction. If not given, the GPU will be used
148            if available, otherwise the CPU.
149
150    Returns:
151        The predicted volume.
152    """
153    # make sure either model path or model is passed
154    if model is None and model_path is None:
155        raise ValueError("Either 'model_path' or 'model' must be provided.")
156
157    if model is not None:
158        is_bioimageio = None
159    else:
160        is_bioimageio = model_path.endswith(".zip")
161
162    if tiling is None:
163        tiling = get_default_tiling()
164
165    # Normalize the whole input volume if it is a numpy array.
166    # Otherwise we have a zarr array or similar as input, and can't normalize it as a whole.
167    # Normalization will be applied later per block in this case.
168    if isinstance(input_volume, np.ndarray):
169        input_volume = _preprocess(input_volume, with_channels, channels_to_standardize)
170
171    # Run prediction with the bioimage.io library.
172    if is_bioimageio:
173        if mask is not None:
174            raise NotImplementedError
175        raise NotImplementedError
176
177    # Run prediction with the torch-em library.
178    else:
179        if model is None:
180            # torch_em expects the root folder of a checkpoint path instead of the checkpoint itself.
181            if model_path.endswith("best.pt"):
182                model_path = os.path.split(model_path)[0]
183        # print(f"tiling {tiling}")
184        # Create updated_tiling with the same structure
185        updated_tiling = {
186            "tile": {},
187            "halo": tiling["halo"]  # Keep the halo part unchanged
188        }
189        # Update tile dimensions
190        for dim in tiling["tile"]:
191            updated_tiling["tile"][dim] = tiling["tile"][dim] - 2 * tiling["halo"][dim]
192        # print(f"updated_tiling {updated_tiling}")
193        prediction = get_prediction_torch_em(
194            input_volume, updated_tiling, model_path, model, verbose, with_channels,
195            mask=mask, prediction=prediction, devices=devices,
196        )
197
198    return prediction
199
200
201def get_prediction_torch_em(
202    input_volume: ArrayLike,  # [z, y, x]
203    tiling: Dict[str, Dict[str, int]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
204    model_path: Optional[str] = None,
205    model: Optional[torch.nn.Module] = None,
206    verbose: bool = True,
207    with_channels: bool = False,
208    mask: Optional[ArrayLike] = None,
209    prediction: Optional[ArrayLike] = None,
210    devices: Optional[List[str]] = None,
211) -> np.ndarray:
212    """Run prediction using torch-em on a given volume.
213
214    Args:
215        input_volume: The input volume to predict on.
216        model_path: The path to the model checkpoint if 'model' is not provided.
217        model: Pre-loaded model. Either model_path or model is required.
218        tiling: The tiling configuration for the prediction.
219        verbose: Whether to print timing information.
220        with_channels: Whether to predict with channels.
221        mask: Optional binary mask. If given, the prediction will only be run in
222            the foreground region of the mask.
223        prediction: An array-like object for writing the prediction.
224            If not given, the prediction will be computed in memory.
225        devices: The devices for running prediction. If not given, the GPU will be used
226            if available, otherwise the CPU.
227
228    Returns:
229        The predicted volume.
230    """
231    # get block_shape and halo
232    block_shape = [tiling["tile"]["z"], tiling["tile"]["x"], tiling["tile"]["y"]]
233    halo = [tiling["halo"]["z"], tiling["halo"]["x"], tiling["halo"]["y"]]
234
235    t0 = time.time()
236    if devices is None:
237        devices = ["cuda" if torch.cuda.is_available() else "cpu"]
238
239    # Suppress warning when loading the model.
240    with warnings.catch_warnings():
241        warnings.simplefilter("ignore")
242        if model is None:
243            if os.path.isdir(model_path):  # Load the model from a torch_em checkpoint.
244                model = torch_em.util.load_model(checkpoint=model_path, device=devices[0])
245            else:  # Load the model directly from a serialized pytorch model.
246                model = torch.load(model_path, weights_only=False)
247
248    # Run prediction with the model.
249    with torch.no_grad():
250
251        # Deal with 2D segmentation case
252        if len(input_volume.shape) == 2:
253            block_shape = [block_shape[1], block_shape[2]]
254            halo = [halo[1], halo[2]]
255
256        if mask is not None:
257            if verbose:
258                print("Run prediction with mask.")
259            mask = mask.astype("bool")
260
261        preprocess = None if isinstance(input_volume, np.ndarray) else torch_em.transform.raw.standardize
262        prediction = predict_with_halo(
263            input_volume, model, gpu_ids=devices,
264            block_shape=block_shape, halo=halo,
265            preprocess=preprocess, with_channels=with_channels, mask=mask,
266            output=prediction,
267        )
268    if verbose:
269        print("Prediction time in", time.time() - t0, "s")
270    return prediction
271
272
273def _get_file_paths(input_path, ext=".mrc"):
274    if not os.path.exists(input_path):
275        raise Exception(f"Input path not found {input_path}")
276
277    if os.path.isfile(input_path):
278        input_files = [input_path]
279        input_root = None
280    else:
281        input_files = sorted(glob(os.path.join(input_path, "**", f"*{ext}"), recursive=True))
282        input_root = input_path
283
284    return input_files, input_root
285
286
287def _load_input(img_path, extra_files, i):
288    # Load the input data.
289    if os.path.splitext(img_path)[-1] == ".tif":
290        input_volume = imageio.imread(img_path)
291
292    else:
293        with open_file(img_path, "r") as f:
294            # Try to automatically derive the key with the raw data.
295            keys = list(f.keys())
296            if len(keys) == 1:
297                key = keys[0]
298            elif "data" in keys:
299                key = "data"
300            elif "raw" in keys:
301                key = "raw"
302            input_volume = f[key][:]
303
304    assert input_volume.ndim in (2, 3)
305    # For now we assume this is always tif.
306    if extra_files is not None:
307        extra_input = imageio.imread(extra_files[i])
308        assert extra_input.shape == input_volume.shape
309        input_volume = np.stack([input_volume, extra_input], axis=0)
310
311    return input_volume
312
313
314def _derive_scale(img_path, model_resolution):
315    try:
316        with mrcfile.open(img_path, "r") as f:
317            voxel_size = f.voxel_size
318            if len(model_resolution) == 2:
319                voxel_size = [voxel_size.y, voxel_size.x]
320            else:
321                voxel_size = [voxel_size.z, voxel_size.y, voxel_size.x]
322
323        assert len(voxel_size) == len(model_resolution)
324        # The voxel size is given in Angstrom and we need to translate it to nanometer.
325        voxel_size = [vsize / 10 for vsize in voxel_size]
326
327        # Compute the correct scale factor.
328        scale = tuple(vsize / res for vsize, res in zip(voxel_size, model_resolution))
329        print("Rescaling the data at", img_path, "by", scale, "to match the training voxel size", model_resolution)
330
331    except Exception:
332        warnings.warn(
333            f"The voxel size could not be read from the data for {img_path}. "
334            "This data will not be scaled for prediction."
335        )
336        scale = None
337
338    return scale
339
340
341def inference_helper(
342    input_path: str,
343    output_root: str,
344    segmentation_function: callable,
345    data_ext: str = ".mrc",
346    extra_input_path: Optional[str] = None,
347    extra_input_ext: str = ".tif",
348    mask_input_path: Optional[str] = None,
349    mask_input_ext: str = ".tif",
350    force: bool = False,
351    output_key: Optional[str] = None,
352    model_resolution: Optional[Tuple[float, float, float]] = None,
353    scale: Optional[Tuple[float, float, float]] = None,
354    allocate_output: bool = False,
355) -> None:
356    """Helper function to run segmentation for mrc files.
357
358    Args:
359        input_path: The path to the input data.
360            This can either be a folder, in which case all mrc files below the folder will be segmented,
361            or a single mrc file, in which case only this file will be segmented.
362        output_root: The path to the output directory where the segmentation results will be saved.
363        segmentation_function: The function performing the segmentation.
364            This function must take the input_volume as the only argument and must return only the segmentation.
365            If you want to pass additional arguments to this function then use 'functools.partial'.
366        data_ext: File extension for the image data. By default '.mrc' is used.
367        extra_input_path: Filepath to extra inputs that need to be concatenated to the raw data loaded from mrc.
368            This enables cristae segmentation with an extra mito channel.
369        extra_input_ext: File extension for the extra inputs (by default .tif).
370        mask_input_path: Filepath to mask(s) that will be used to restrict the segmentation.
371        mask_input_ext: File extension for the mask inputs (by default .tif).
372        force: Whether to rerun segmentation for output files that are already present.
373        output_key: Output key for the prediction. If given, the result will be written to an hdf5 file with this key; otherwise it will be written to a tif file.
374        model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction.
375            If given, the scaling factor will automatically be determined based on the voxel_size of the input data.
376        scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'.
377        allocate_output: Whether to allocate the output for the segmentation function.
378    """
379    if (scale is not None) and (model_resolution is not None):
380        raise ValueError("You must not provide both 'scale' and 'model_resolution' arguments.")
381
382    # Get the input files. If input_path is a folder then this will load all
383    # the mrc files beneath it. Otherwise we assume this is an mrc file already
384    # and just return the path to this mrc file.
385    input_files, input_root = _get_file_paths(input_path, data_ext)
386
387    # Load extra inputs if the extra_input_path was specified.
388    if extra_input_path is None:
389        extra_files = None
390    else:
391        extra_files, _ = _get_file_paths(extra_input_path, extra_input_ext)
392        assert len(input_files) == len(extra_files)
393
394    # Load the masks if they were specified.
395    if mask_input_path is None:
396        mask_files = None
397    else:
398        mask_files, _ = _get_file_paths(mask_input_path, mask_input_ext)
399        assert len(input_files) == len(mask_files)
400
401    for i, img_path in tqdm(enumerate(input_files), total=len(input_files), desc="Processing files"):
402        # Determine the output file name.
403        input_folder, input_name = os.path.split(img_path)
404
405        if output_key is None:
406            fname = os.path.splitext(input_name)[0] + "_prediction.tif"
407        else:
408            fname = os.path.splitext(input_name)[0] + "_prediction.h5"
409
410        if input_root is None:
411            output_path = os.path.join(output_root, fname)
412        else:  # If we have nested input folders then we preserve the folder structure in the output.
413            rel_folder = os.path.relpath(input_folder, input_root)
414            output_path = os.path.join(output_root, rel_folder, fname)
415
416        # Check if the output path is already present.
417        # If it is we skip the prediction, unless force was set to true.
418        if os.path.exists(output_path) and not force:
419            if output_key is None:
420                continue
421            else:
422                with open_file(output_path, "r") as f:
423                    if output_key in f:
424                        continue
425
426        # Load the input volume. If we have extra_files then this concatenates the
427        # data across a new first axis (= channel axis).
428        input_volume = _load_input(img_path, extra_files, i)
429        # Load the mask (if given).
430        mask = None if mask_files is None else imageio.imread(mask_files[i])
431
432        # Determine the scale factor:
433        # If neither the 'scale' nor the 'model_resolution' argument was passed then set it to None.
434        if scale is None and model_resolution is None:
435            this_scale = None
436        elif scale is not None:   # If 'scale' was passed then use it.
437            this_scale = scale
438        else:   # Otherwise 'model_resolution' was passed, use it to derive the scaling from the data
439            assert model_resolution is not None
440            this_scale = _derive_scale(img_path, model_resolution)
441
442        # Run the segmentation.
443        if allocate_output:
444            segmentation = np.zeros(input_volume.shape, dtype="uint32")
445            segmentation_function(input_volume, output=segmentation, mask=mask, scale=this_scale)
446        else:
447            segmentation = segmentation_function(input_volume, mask=mask, scale=this_scale)
448
449        # Write the result to tif or h5.
450        os.makedirs(os.path.split(output_path)[0], exist_ok=True)
451
452        if output_key is None:
453            imageio.imwrite(output_path, segmentation, compression="zlib")
454        else:
455            with open_file(output_path, "a") as f:
456                f.create_dataset(output_key, data=segmentation, compression="gzip")
457
458        print(f"Saved segmentation to {output_path}.")
459
460
461def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]:
462    """Determine the tile shape and halo depending on the available VRAM.
463
464    Args:
465        is_2d: Whether to return tiling settings for 2d inference.
466
467    Returns:
468        The default tiling settings for the available computational resources.
469    """
470    if is_2d:
471        tile = {"x": 768, "y": 768, "z": 1}
472        halo = {"x": 128, "y": 128, "z": 0}
473        return {"tile": tile, "halo": halo}
474
475    if torch.cuda.is_available():
476        # The default halo size.
477        halo = {"x": 64, "y": 64, "z": 16}
478
479        # Determine the GPU RAM and derive a suitable tiling.
480        vram = torch.cuda.get_device_properties(0).total_memory / 1e9
481
482        if vram >= 80:
483            tile = {"x": 640, "y": 640, "z": 80}
484        elif vram >= 40:
485            tile = {"x": 512, "y": 512, "z": 64}
486        elif vram >= 20:
487            tile = {"x": 352, "y": 352, "z": 48}
488        elif vram >= 10:
489            tile = {"x": 256, "y": 256, "z": 32}
490            halo = {"x": 64, "y": 64, "z": 8}  # Choose a smaller halo in z.
491        else:
492            raise NotImplementedError(f"Inference with a GPU with {vram} GB VRAM is not supported.")
493
494        tiling = {"tile": tile, "halo": halo}
495        print(f"Determined tile size for CUDA: {tiling}")
496
497    elif torch.backends.mps.is_available():  # Check for Apple Silicon (MPS)
498        tile = {"x": 256, "y": 256, "z": 16}
499        halo = {"x": 16, "y": 16, "z": 4}
500        tiling = {"tile": tile, "halo": halo}
501        print(f"Determined tile size for MPS: {tiling}")
502
503    # I am not sure what is reasonable on a cpu. For now choosing very small tiling.
504    # (This will not work well on a CPU in any case.)
505    else:
506        tiling = {
507            "tile": {"x": 96, "y": 96, "z": 16},
508            "halo": {"x": 16, "y": 16, "z": 4},
509        }
510        print(f"Determined default tiling for CPU: {tiling}")
511
512    return tiling
513
514
515def parse_tiling(
516    tile_shape: Tuple[int, int, int],
517    halo: Tuple[int, int, int],
518    is_2d: bool = False,
519) -> Dict[str, Dict[str, int]]:
520    """Helper function to parse tiling parameter input from the command line.
521
522    Args:
523        tile_shape: The tile shape. If None the default tile shape is used.
524        halo: The halo. If None the default halo is used.
525        is_2d: Whether to return tiling for a 2d model.
526
527    Returns:
528        The tiling specification.
529    """
530
531    default_tiling = get_default_tiling(is_2d=is_2d)
532
533    if tile_shape is None:
534        tile_shape = default_tiling["tile"]
535    else:
536        assert len(tile_shape) == 3
537        tile_shape = dict(zip("zyx", tile_shape))
538
539    if halo is None:
540        halo = default_tiling["halo"]
541    else:
542        assert len(halo) == 3
543        halo = dict(zip("zyx", halo))
544
545    tiling = {"tile": tile_shape, "halo": halo}
546    return tiling
547
548
549#
550# Utils for post-processing.
551#
552
553
554def apply_size_filter(
555    segmentation: np.ndarray,
556    min_size: int,
557    verbose: bool = False,
558    block_shape: Tuple[int, int, int] = (128, 256, 256),
559) -> np.ndarray:
560    """Apply size filter to the segmentation to remove small objects.
561
562    Args:
563        segmentation: The segmentation.
564        min_size: The minimal object size in pixels.
565        verbose: Whether to print runtimes.
566        block_shape: Block shape for parallelizing the operations.
567
568    Returns:
569        The size filtered segmentation.
570    """
571    if min_size == 0:
572        return segmentation
573    t0 = time.time()
574    if segmentation.ndim == 2 and len(block_shape) == 3:
575        block_shape_ = block_shape[1:]
576    else:
577        block_shape_ = block_shape
578    ids, sizes = parallel.unique(segmentation, return_counts=True, block_shape=block_shape_, verbose=verbose)
579    filter_ids = ids[sizes < min_size]
580    segmentation[np.isin(segmentation, filter_ids)] = 0
581    if verbose:
582        print("Size filter in", time.time() - t0, "s")
583    return segmentation
584
585
586def _postprocess_seg_3d(seg, area_threshold=1000, iterations=4, iterations_3d=8):
587    # Structure element for 2d closing in 3d.
588    structure_element = np.ones((3, 3))  # 3x3 structure for XY plane
589    structure_3d = np.zeros((1, 3, 3))  # Only applied in the XY plane
590    structure_3d[0] = structure_element
591
592    props = regionprops(seg)
593    for prop in props:
594        # Get bounding box and mask.
595        bb = tuple(slice(start, stop) for start, stop in zip(prop.bbox[:3], prop.bbox[3:]))
596        mask = seg[bb] == prop.label
597
598        # Fill small holes and apply closing.
599        mask = remove_small_holes(mask, area_threshold=area_threshold)
600        mask = np.logical_or(binary_closing(mask, iterations=iterations), mask)
601        mask = np.logical_or(binary_closing(mask, iterations=iterations_3d, structure=structure_3d), mask)
602        seg[bb][mask] = prop.label
603
604    return seg
605
606
607#
608# Utils for torch device.
609#
610
611def _get_default_device():
612    # Check whether we're in CI and use the CPU if we are.
613    # Otherwise the tests may run out of memory on a Mac if MPS is used.
614    if os.getenv("GITHUB_ACTIONS") == "true":
615        return "cpu"
616    # Use cuda enabled gpu if it's available.
617    if torch.cuda.is_available():
618        device = "cuda"
619    # As second priority use mps.
620    # See https://pytorch.org/docs/stable/notes/mps.html for details
621    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
622        device = "mps"
623    # Use the CPU as fallback.
624    else:
625        device = "cpu"
626    return device
627
628
629def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
630    """Get the torch device.
631
632    If no device is passed the default device for your system is used.
633    Otherwise it is checked whether the device you have passed is supported.
634
635    Args:
636        device: The input device.
637
638    Returns:
639        The device.
640    """
641    if device is None or device == "auto":
642        device = _get_default_device()
643    else:
644        device_type = device if isinstance(device, str) else device.type
645        if device_type.lower() == "cuda":
646            if not torch.cuda.is_available():
647                raise RuntimeError("PyTorch CUDA backend is not available.")
648        elif device_type.lower() == "mps":
649            if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
650                raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
651        elif device_type.lower() == "cpu":
652            pass  # cpu is always available
653        else:
654            raise RuntimeError(f"Unsupported device: {device}. Please choose from 'cpu', 'cuda', or 'mps'.")
655    return device
def get_prediction(input_volume: ArrayLike, tiling: Optional[Dict[str, Dict[str, int]]], model_path: Optional[str] = None, model: Optional[torch.nn.Module] = None, verbose: bool = True, with_channels: bool = False, channels_to_standardize: Optional[List[int]] = None, mask: Optional[ArrayLike] = None, prediction: Optional[ArrayLike] = None, devices: Optional[List[str]] = None) -> ArrayLike:

Run prediction on a given volume.

This function will automatically choose the correct prediction implementation, depending on the model type.

Arguments:
  • input_volume: The input volume to predict on.
  • model_path: The path to the model checkpoint if 'model' is not provided.
  • model: Pre-loaded model. Either model_path or model is required.
  • tiling: The tiling configuration for the prediction.
  • verbose: Whether to print timing information.
  • with_channels: Whether to predict with channels.
  • channels_to_standardize: List of channels to standardize. If None, all channels are standardized.
  • mask: Optional binary mask. If given, the prediction will only be run in the foreground region of the mask.
  • prediction: An array-like object for writing the prediction. If not given, the prediction will be computed in memory.
  • devices: The devices for running prediction. If not given, the GPU will be used if available, otherwise the CPU.
Returns:

The predicted volume.
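
Example: a minimal usage sketch. The checkpoint path and the input array are placeholders; any numpy array with a [z, y, x] layout and a valid torch_em checkpoint (or serialized model) can be used here.

    import numpy as np
    from synapse_net.inference.util import get_prediction

    # Placeholder input volume and checkpoint path.
    input_volume = np.random.rand(64, 512, 512).astype("float32")
    prediction = get_prediction(
        input_volume,
        tiling=None,  # None selects the default tiling for the available hardware
        model_path="/path/to/checkpoint/best.pt",
    )
    print(prediction.shape)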

def get_prediction_torch_em(input_volume: ArrayLike, tiling: Dict[str, Dict[str, int]], model_path: Optional[str] = None, model: Optional[torch.nn.Module] = None, verbose: bool = True, with_channels: bool = False, mask: Optional[ArrayLike] = None, prediction: Optional[ArrayLike] = None, devices: Optional[List[str]] = None) -> np.ndarray:

Run prediction using torch-em on a given volume.

Arguments:
  • input_volume: The input volume to predict on.
  • model_path: The path to the model checkpoint if 'model' is not provided.
  • model: Pre-loaded model. Either model_path or model is required.
  • tiling: The tiling configuration for the prediction.
  • verbose: Whether to print timing information.
  • with_channels: Whether to predict with channels.
  • mask: Optional binary mask. If given, the prediction will only be run in the foreground region of the mask.
  • prediction: An array-like object for writing the prediction. If not given, the prediction will be computed in memory.
  • devices: The devices for running prediction. If not given, the GPU will be used if available, otherwise the CPU.
Returns:

The predicted volume.
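
Example: calling the torch-em backend directly with an explicit tiling and a pre-loaded model. The model below is a toy placeholder; in practice a trained segmentation network, e.g. loaded from a torch_em checkpoint, would be passed.

    import numpy as np
    import torch
    from synapse_net.inference.util import get_default_tiling, get_prediction_torch_em

    # Placeholder model; a trained segmentation network is expected in practice.
    model = torch.nn.Sequential(torch.nn.Conv3d(1, 1, kernel_size=3, padding=1), torch.nn.Sigmoid())
    input_volume = np.random.rand(48, 384, 384).astype("float32")
    prediction = get_prediction_torch_em(
        input_volume, tiling=get_default_tiling(), model=model, verbose=True
    )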

def inference_helper(input_path: str, output_root: str, segmentation_function: callable, data_ext: str = '.mrc', extra_input_path: Optional[str] = None, extra_input_ext: str = '.tif', mask_input_path: Optional[str] = None, mask_input_ext: str = '.tif', force: bool = False, output_key: Optional[str] = None, model_resolution: Optional[Tuple[float, float, float]] = None, scale: Optional[Tuple[float, float, float]] = None, allocate_output: bool = False) -> None:

Helper function to run segmentation for mrc files.

Arguments:
  • input_path: The path to the input data. This can either be a folder, in which case all mrc files below the folder will be segmented, or a single mrc file, in which case only this file will be segmented.
  • output_root: The path to the output directory where the segmentation results will be saved.
  • segmentation_function: The function performing the segmentation. This function must take the input_volume as the only argument and must return only the segmentation. If you want to pass additional arguments to this function then use 'functools.partial'.
  • data_ext: File extension for the image data. By default '.mrc' is used.
  • extra_input_path: Filepath to extra inputs that need to be concatenated to the raw data loaded from mrc. This enables cristae segmentation with an extra mito channel.
  • extra_input_ext: File extension for the extra inputs (by default .tif).
  • mask_input_path: Filepath to mask(s) that will be used to restrict the segmentation.
  • mask_input_ext: File extension for the mask inputs (by default .tif).
  • force: Whether to rerun segmentation for output files that are already present.
  • output_key: Output key for the prediction. If given, the result will be written to an hdf5 file with this key; otherwise it will be written to a tif file.
  • model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction. If given, the scaling factor will automatically be determined based on the voxel_size of the input data.
  • scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'.
  • allocate_output: Whether to allocate the output for the segmentation function.
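
Example: a minimal sketch of batch processing a folder of mrc files. The segmentation function below is a stand-in stub; a real pipeline would pass its own segmentation routine, using 'functools.partial' to bind additional arguments. All paths are placeholders.

    import numpy as np
    from synapse_net.inference.util import inference_helper

    def example_segmentation(input_volume, mask=None, scale=None):
        # Stand-in segmentation: returns an empty label volume of the input shape.
        return np.zeros(input_volume.shape, dtype="uint32")

    inference_helper(
        input_path="/data/tomograms",        # folder with .mrc files (placeholder)
        output_root="/data/segmentations",   # results are written here as tif files
        segmentation_function=example_segmentation,
    )
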
def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]:

Determine the tile shape and halo depending on the available VRAM.

Arguments:
  • is_2d: Whether to return tiling settings for 2d inference.
Returns:

The default tiling settings for the available computational resources.
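
Example: the returned dictionary always has a "tile" and a "halo" entry with "x", "y" and "z" keys; the exact values depend on the detected hardware (for instance, a GPU with 40 GB or more of VRAM yields the 512 x 512 x 64 tile shown in the source above).

    from synapse_net.inference.util import get_default_tiling

    tiling = get_default_tiling()
    print(tiling["tile"], tiling["halo"])

    # Tiling for 2d inference uses a single z slice.
    tiling_2d = get_default_tiling(is_2d=True)
    assert tiling_2d["tile"]["z"] == 1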

def parse_tiling(tile_shape: Tuple[int, int, int], halo: Tuple[int, int, int], is_2d: bool = False) -> Dict[str, Dict[str, int]]:

Helper function to parse tiling parameter input from the command line.

Arguments:
  • tile_shape: The tile shape. If None the default tile shape is used.
  • halo: The halo. If None the default halo is used.
  • is_2d: Whether to return tiling for a 2d model.
Returns:

The tiling specification.
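
Example: values as they would typically arrive from command line arguments, given in (z, y, x) order; passing None for either argument falls back to the defaults from get_default_tiling.

    from synapse_net.inference.util import parse_tiling

    tiling = parse_tiling(tile_shape=(64, 512, 512), halo=(16, 64, 64))
    # -> {"tile": {"z": 64, "y": 512, "x": 512}, "halo": {"z": 16, "y": 64, "x": 64}}

    default_tiling = parse_tiling(tile_shape=None, halo=None)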

def apply_size_filter(segmentation: np.ndarray, min_size: int, verbose: bool = False, block_shape: Tuple[int, int, int] = (128, 256, 256)) -> np.ndarray:

Apply size filter to the segmentation to remove small objects.

Arguments:
  • segmentation: The segmentation.
  • min_size: The minimal object size in pixels.
  • verbose: Whether to print runtimes.
  • block_shape: Block shape for parallelizing the operations.
Returns:

The size filtered segmentation.
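
Example: a small synthetic segmentation with one object below and one above the size threshold. Note that the filter modifies the segmentation in place and also returns it; the block shape here is simply chosen to match the toy volume.

    import numpy as np
    from synapse_net.inference.util import apply_size_filter

    segmentation = np.zeros((8, 64, 64), dtype="uint32")
    segmentation[:2, :4, :4] = 1        # 32 voxels, removed by the filter
    segmentation[4:, 10:40, 10:40] = 2  # 3600 voxels, kept
    filtered = apply_size_filter(segmentation, min_size=500, block_shape=(8, 64, 64))
    assert not (filtered == 1).any() and (filtered == 2).any()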

def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:

Get the torch device.

If no device is passed the default device for your system is used. Otherwise it is checked whether the device you have passed is supported.

Arguments:
  • device: The input device.
Returns:

The device.
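
Example: typical usage for moving a model to the best available device; explicitly requested devices are validated and an error is raised if the backend is unavailable.

    import torch
    from synapse_net.inference.util import get_device

    device = get_device()           # "cuda", "mps" or "cpu", depending on the system
    model = torch.nn.Linear(4, 2).to(device)

    cpu_device = get_device("cpu")  # explicit requests are checked before being returned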