synapse_net.inference.util

  1import os
  2import time
  3import warnings
  4from glob import glob
  5from typing import Dict, List, Optional, Tuple, Union
  6
  7# # Suppress annoying import warnings.
  8# with warnings.catch_warnings():
  9#     warnings.simplefilter("ignore")
 10#     import bioimageio.core
 11
 12import imageio.v3 as imageio
 13import elf.parallel as parallel
 14import mrcfile
 15import numpy as np
 16import torch
 17import torch_em
 18# import xarray
 19
 20from elf.io import open_file
 21from numpy.typing import ArrayLike
 22from scipy.ndimage import binary_closing
 23from skimage.measure import regionprops
 24from skimage.morphology import remove_small_holes
 25from skimage.transform import rescale, resize
 26from torch_em.util.prediction import predict_with_halo
 27from tqdm import tqdm
 28
 29
 30#
 31# Utils for prediction.
 32#
 33
 34
 35class _Scaler:
 36    def __init__(self, scale, verbose):
 37        self.verbose = verbose
 38        self._original_shape = None
 39
 40        if scale is None:
 41            self.scale = None
 42            return
 43
 44        # Convert scale to a NumPy array (ensures consistency)
 45        scale = np.atleast_1d(scale).astype(np.float64)
 46
 47        # Validate scale values
 48        if not np.issubdtype(scale.dtype, np.number):
 49            raise TypeError(f"Scale contains non-numeric values: {scale}")
 50
 51        # Check if scaling is effectively identity (1.0 in all dimensions)
 52        if np.allclose(scale, 1.0, atol=1e-3):
 53            self.scale = None
 54        else:
 55            self.scale = scale
 56
 57    def scale_input(self, input_volume, is_segmentation=False):
 58        t0 = time.time()
 59        if self.scale is None:
 60            return input_volume
 61
 62        if self._original_shape is None:
 63            self._original_shape = input_volume.shape
 64        elif self._original_shape != input_volume.shape:
 65            raise RuntimeError(
 66                "Scaler was called with different input shapes. "
 67                "This is not supported, please create a new instance of the class for it."
 68            )
 69
 70        if is_segmentation:
 71            input_volume = rescale(
 72                input_volume, self.scale, preserve_range=True, order=0, anti_aliasing=False,
 73            ).astype(input_volume.dtype)
 74        else:
 75            input_volume = rescale(input_volume, self.scale, preserve_range=True).astype(input_volume.dtype)
 76
 77        if self.verbose:
 78            print("Rescaled volume from", self._original_shape, "to", input_volume.shape, "in", time.time() - t0, "s")
 79        return input_volume
 80
 81    def rescale_output(self, output, is_segmentation):
 82        t0 = time.time()
 83        if self.scale is None:
 84            return output
 85
 86        assert self._original_shape is not None
 87        out_shape = self._original_shape
 88        if output.ndim > len(out_shape):
 89            assert output.ndim == len(out_shape) + 1
 90            out_shape = (output.shape[0],) + out_shape
 91
 92        if is_segmentation:
 93            output = resize(output, out_shape, preserve_range=True, order=0, anti_aliasing=False).astype(output.dtype)
 94        else:
 95            output = resize(output, out_shape, preserve_range=True).astype(output.dtype)
 96
 97        if self.verbose:
 98            print("Resized prediction back to original shape", output.shape, "in", time.time() - t0, "s")
 99
100        return output
101
102
def _preprocess(input_volume, with_channels, channels_to_standardize):
    """Standardize the full volume ahead of the prediction.

    Without channels the volume is standardized as a whole. With channels the volume
    is cast to float32 (without a copy if it already has this dtype, so the input may be
    modified in place) and each selected channel is standardized independently.
    """
    standardize = torch_em.transform.raw.standardize
    if not with_channels:
        return standardize(input_volume)

    volume = input_volume.astype(np.float32, copy=False)
    # TODO Check that this is the correct axis.
    # If no channels are selected, all channels along the first axis are standardized.
    if channels_to_standardize is None:
        channel_ids = range(volume.shape[0])
    else:
        channel_ids = channels_to_standardize

    for channel in channel_ids:
        volume[channel] = standardize(volume[channel])
    return volume
116
117
def get_prediction(
    input_volume: ArrayLike,  # [z, y, x]
    tiling: Optional[Dict[str, Dict[str, int]]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    verbose: bool = True,
    with_channels: bool = False,
    channels_to_standardize: Optional[List[int]] = None,
    mask: Optional[ArrayLike] = None,
    prediction: Optional[ArrayLike] = None,
    devices: Optional[List[str]] = None,
    preprocess: Optional[callable] = None,
) -> ArrayLike:
    """Run prediction on a given volume.

    This function will automatically choose the correct prediction implementation,
    depending on the model type.

    Args:
        input_volume: The input volume to predict on.
        tiling: The tiling configuration for the prediction.
            If None, the default tiling from 'get_default_tiling' is used.
        model_path: The path to the model checkpoint if 'model' is not provided.
        model: Pre-loaded model. Either model_path or model is required.
        verbose: Whether to print timing information.
        with_channels: Whether to predict with channels.
        channels_to_standardize: List of channels to standardize. Defaults to None,
            in which case all channels are standardized.
        mask: Optional binary mask. If given, the prediction will only be run in
            the foreground region of the mask.
        prediction: An array like object for writing the prediction.
            If not given, the prediction will be computed in memory.
        devices: The devices for running prediction. If not given will use the GPU
            if available, otherwise the CPU.
        preprocess: Optional preprocessing function that is passed on to the torch-em prediction.

    Returns:
        The predicted volume.

    Raises:
        ValueError: If neither 'model_path' nor 'model' is given.
        NotImplementedError: If 'model_path' points to a bioimage.io model (.zip file).
    """
    # Make sure that either the model path or the model is passed.
    if model is None and model_path is None:
        raise ValueError("Either 'model_path' or 'model' must be provided.")

    # Prediction with the bioimage.io library is not supported. Fail early, before any
    # expensive work (e.g. normalizing the whole volume) is done.
    is_bioimageio = model is None and model_path.endswith(".zip")
    if is_bioimageio:
        raise NotImplementedError(
            f"Prediction with bioimage.io models is not supported, got the model path {model_path}."
        )

    if tiling is None:
        tiling = get_default_tiling()

    # Normalize the whole input volume if it is a numpy array.
    # Otherwise we have a zarr array or similar as input, and can't normalize it en-block.
    # Normalization will be applied later per block in this case.
    if isinstance(input_volume, np.ndarray):
        input_volume = _preprocess(input_volume, with_channels, channels_to_standardize)

    # torch_em expects the root folder of a checkpoint path instead of the checkpoint itself.
    if model is None and model_path.endswith("best.pt"):
        model_path = os.path.split(model_path)[0]

    # Shrink the tile by the halo on both sides and keep the halo unchanged.
    # NOTE(review): presumably 'tile' is the full tile size including the halo, while the
    # torch-em prediction takes the inner block shape - confirm.
    halo = tiling["halo"]
    updated_tiling = {
        "tile": {dim: size - 2 * halo[dim] for dim, size in tiling["tile"].items()},
        "halo": halo,
    }

    # Run prediction with the torch-em library.
    return get_prediction_torch_em(
        input_volume, updated_tiling, model_path, model, verbose, with_channels,
        mask=mask, prediction=prediction, devices=devices, preprocess=preprocess,
    )
200
201
def get_prediction_torch_em(
    input_volume: ArrayLike,  # [z, y, x]
    tiling: Dict[str, Dict[str, int]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    verbose: bool = True,
    with_channels: bool = False,
    mask: Optional[ArrayLike] = None,
    prediction: Optional[ArrayLike] = None,
    devices: Optional[List[str]] = None,
    preprocess: Optional[callable] = None,
) -> np.ndarray:
    """Run prediction using torch-em on a given volume.

    Args:
        input_volume: The input volume to predict on.
        tiling: The tiling configuration for the prediction.
            The 'tile' and 'halo' entries are passed as block shape and halo to torch-em.
        model_path: The path to the model checkpoint if 'model' is not provided.
            Can either be a torch_em checkpoint folder or a file with a serialized pytorch model.
        model: Pre-loaded model. Either model_path or model is required.
        verbose: Whether to print timing information.
        with_channels: Whether to predict with channels.
        mask: Optional binary mask. If given, the prediction will only be run in
            the foreground region of the mask.
        prediction: An array like object for writing the prediction.
            If not given, the prediction will be computed in memory.
        devices: The devices for running prediction. If not given will use the GPU
            if available, otherwise the CPU.
        preprocess: Optional preprocessing function applied by the torch-em prediction.
            If not given, the data is standardized per block unless the input is a numpy array.

    Returns:
        The predicted volume.

    Raises:
        ValueError: If neither 'model_path' nor 'model' is given.
    """
    if model is None and model_path is None:
        raise ValueError("Either 'model_path' or 'model' must be provided.")

    # Get block_shape and halo. The volume is indexed as [z, y, x],
    # so both have to be given in the same axis order.
    axes = ("z", "y", "x")
    block_shape = [tiling["tile"][axis] for axis in axes]
    halo = [tiling["halo"][axis] for axis in axes]

    t0 = time.time()
    if devices is None:
        devices = ["cuda" if torch.cuda.is_available() else "cpu"]

    # Suppress warning when loading the model.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        if model is None:
            if os.path.isdir(model_path):  # Load the model from a torch_em checkpoint.
                model = torch_em.util.load_model(checkpoint=model_path, device=devices[0])
            else:  # Load the model directly from a serialized pytorch model.
                # 'map_location' makes sure that a model that was saved on a GPU
                # can also be loaded on a machine without that device.
                # NOTE(security): 'weights_only=False' unpickles arbitrary objects,
                # only load model files from a trusted source.
                model = torch.load(model_path, map_location=devices[0], weights_only=False)

    # Run prediction with the model.
    with torch.no_grad():

        # Deal with 2D segmentation case: drop the z-axis from block shape and halo.
        if len(input_volume.shape) == 2:
            block_shape = block_shape[1:]
            halo = halo[1:]

        if mask is not None:
            if verbose:
                print("Run prediction with mask.")
            mask = mask.astype("bool")

        # Numpy inputs are expected to be normalized en-block beforehand.
        # Other inputs (zarr or similar) are standardized per block by default.
        if preprocess is None and not isinstance(input_volume, np.ndarray):
            preprocess = torch_em.transform.raw.standardize

        prediction = predict_with_halo(
            input_volume, model, gpu_ids=devices,
            block_shape=block_shape, halo=halo,
            preprocess=preprocess, with_channels=with_channels, mask=mask,
            output=prediction,
        )
    if verbose:
        print("Prediction time in", time.time() - t0, "s")
    return prediction
276
277
def _get_file_paths(input_path, ext=".mrc"):
    """Collect the input files for a path that is either a single file or a folder.

    Args:
        input_path: Path to a file or to a folder.
        ext: The file extension used to select the files if 'input_path' is a folder.

    Returns:
        The list of file paths. For a folder these are all files with the given extension
            below it (searched recursively), in sorted order.
        The root folder, or None if 'input_path' is a single file.

    Raises:
        FileNotFoundError: If 'input_path' does not exist.
    """
    if not os.path.exists(input_path):
        # FileNotFoundError is a subclass of Exception, so code catching the latter still works.
        raise FileNotFoundError(f"Input path not found {input_path}")

    if os.path.isfile(input_path):
        return [input_path], None

    pattern = os.path.join(input_path, "**", f"*{ext}")
    return sorted(glob(pattern, recursive=True)), input_path
290
291
def _load_input(img_path, extra_files, i):
    """Load the input volume and optionally stack it with an extra input.

    Args:
        img_path: Path to the image data. Files with the extension '.tif' are read with imageio,
            everything else is opened with 'open_file'.
        extra_files: Optional list of file paths with extra inputs.
        i: The index of the extra input that belongs to 'img_path'.

    Returns:
        The input volume. If extra files are given, the volume and the extra input
        are stacked along a new first axis.

    Raises:
        ValueError: If the dataset holding the image data cannot be determined.
    """
    # Load the input data.
    if os.path.splitext(img_path)[-1] == ".tif":
        input_volume = imageio.imread(img_path)

    else:
        with open_file(img_path, "r") as f:
            # Try to automatically derive the key with the raw data.
            keys = list(f.keys())
            if len(keys) == 1:
                key = keys[0]
            elif "data" in keys:
                key = "data"
            elif "raw" in keys:
                key = "raw"
            else:
                # Without this branch 'key' would be unbound and the lookup below
                # would fail with an unhelpful UnboundLocalError.
                raise ValueError(
                    f"Could not determine the key for the image data in {img_path}. "
                    f"Expected a single dataset or one called 'data' or 'raw', but found {keys}."
                )
            input_volume = f[key][:]

    assert input_volume.ndim in (2, 3)
    # For now we assume this is always tif.
    if extra_files is not None:
        extra_input = imageio.imread(extra_files[i])
        assert extra_input.shape == input_volume.shape
        input_volume = np.stack([input_volume, extra_input], axis=0)

    return input_volume
317
318
def _derive_scale(img_path, model_resolution):
    """Derive the scale factor for a file from its voxel size and the model resolution.

    The voxel size is read with mrcfile. If anything goes wrong while reading it or
    computing the scale, a warning is issued and None is returned, so that the data is not scaled.
    The number of entries in 'model_resolution' (2 or otherwise 3) decides which axes are used.
    """
    try:
        with mrcfile.open(img_path, "r") as mrc:
            size = mrc.voxel_size
            if len(model_resolution) == 2:
                voxel_size_angstrom = [size.y, size.x]
            else:
                voxel_size_angstrom = [size.z, size.y, size.x]

        assert len(voxel_size_angstrom) == len(model_resolution)
        # The voxel size is given in Angstrom and we need to translate it to nanometer.
        voxel_size_nm = [value / 10 for value in voxel_size_angstrom]

        # Compute the correct scale factor.
        scale = tuple(value / resolution for value, resolution in zip(voxel_size_nm, model_resolution))
        print("Rescaling the data at", img_path, "by", scale, "to match the training voxel size", model_resolution)
        return scale

    except Exception:
        warnings.warn(
            f"The voxel size could not be read from the data for {img_path}. "
            "This data will not be scaled for prediction."
        )
        return None
344
345
def inference_helper(
    input_path: str,
    output_root: str,
    segmentation_function: callable,
    data_ext: str = ".mrc",
    extra_input_path: Optional[str] = None,
    extra_input_ext: str = ".tif",
    mask_input_path: Optional[str] = None,
    mask_input_ext: str = ".tif",
    force: bool = False,
    output_key: Optional[str] = None,
    model_resolution: Optional[Tuple[float, float, float]] = None,
    scale: Optional[Tuple[float, float, float]] = None,
    allocate_output: bool = False,
) -> None:
    """Helper function to run segmentation for mrc files.

    Args:
        input_path: The path to the input data.
            Can either be a folder. In this case all mrc files below the folder will be segmented.
            Or can be a single mrc file. In this case only this mrc file will be segmented.
        output_root: The path to the output directory where the segmentation results will be saved.
        segmentation_function: The function performing the segmentation.
            It is called with the input volume as only positional argument and the
            keyword arguments 'mask' and 'scale' (and 'output' if 'allocate_output' is set).
            If you want to pass additional arguments to this function then use 'functools.partial'.
        data_ext: File extension for the image data. By default '.mrc' is used.
        extra_input_path: Filepath to extra inputs that need to be concatenated to the raw data loaded from mrc.
            This enables cristae segmentation with an extra mito channel.
        extra_input_ext: File extension for the extra inputs (by default .tif).
        mask_input_path: Filepath to mask(s) that will be used to restrict the segmentation.
        mask_input_ext: File extension for the mask inputs (by default .tif).
        force: Whether to rerun segmentation for output files that are already present.
        output_key: Output key for the prediction. If None, the result is written to a tif file.
            Otherwise it is written to an hdf5 file under this key.
        model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction.
            If given, the scaling factor will automatically be determined based on the voxel_size of the input data.
        scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'.
        allocate_output: Whether to allocate the output for the segmentation function.

    Raises:
        ValueError: If both 'scale' and 'model_resolution' are given.
    """
    if (scale is not None) and (model_resolution is not None):
        raise ValueError("You must not provide both 'scale' and 'model_resolution' arguments.")

    # Get the input files. If input_path is a folder then this will load all
    # the mrc files beneath it. Otherwise we assume this is an mrc file already
    # and just return the path to this mrc file.
    input_files, input_root = _get_file_paths(input_path, data_ext)

    # Load extra inputs if the extra_input_path was specified.
    # Extra inputs and masks are matched to the input files by their position in the sorted file lists.
    if extra_input_path is None:
        extra_files = None
    else:
        extra_files, _ = _get_file_paths(extra_input_path, extra_input_ext)
        assert len(input_files) == len(extra_files)

    # Load the masks if they were specified.
    if mask_input_path is None:
        mask_files = None
    else:
        mask_files, _ = _get_file_paths(mask_input_path, mask_input_ext)
        assert len(input_files) == len(mask_files)

    # The output is a tif file if no output key is given, otherwise an hdf5 file.
    output_ext = "_prediction.tif" if output_key is None else "_prediction.h5"

    for i, img_path in tqdm(enumerate(input_files), total=len(input_files), desc="Processing files"):
        # Determine the output file name.
        input_folder, input_name = os.path.split(img_path)
        fname = os.path.splitext(input_name)[0] + output_ext

        if input_root is None:
            output_path = os.path.join(output_root, fname)
        else:  # If we have nested input folders then we preserve the folder structure in the output.
            rel_folder = os.path.relpath(input_folder, input_root)
            output_path = os.path.join(output_root, rel_folder, fname)

        # Check if the output path is already present.
        # If it is we skip the prediction, unless force was set to true.
        if os.path.exists(output_path) and not force:
            if output_key is None:
                continue
            with open_file(output_path, "r") as f:
                if output_key in f:
                    continue

        # Load the input volume. If we have extra_files then this concatenates the
        # data across a new first axis (= channel axis).
        input_volume = _load_input(img_path, extra_files, i)
        # Load the mask (if given).
        mask = None if mask_files is None else imageio.imread(mask_files[i])

        # Determine the scale factor: a fixed 'scale' is used as is, for 'model_resolution'
        # the scale is derived from the data, and otherwise no scaling is applied.
        if scale is not None:
            this_scale = scale
        elif model_resolution is not None:
            this_scale = _derive_scale(img_path, model_resolution)
        else:
            this_scale = None

        # Run the segmentation.
        if allocate_output:
            segmentation = np.zeros(input_volume.shape, dtype="uint32")
            segmentation_function(input_volume, output=segmentation, mask=mask, scale=this_scale)
        else:
            segmentation = segmentation_function(input_volume, mask=mask, scale=this_scale)

        # Write the result to tif or h5.
        os.makedirs(os.path.split(output_path)[0], exist_ok=True)

        if output_key is None:
            imageio.imwrite(output_path, segmentation, compression="zlib")
        else:
            with open_file(output_path, "a") as f:
                # If 'force' is set, the dataset may exist from a previous run.
                # Creating it again would fail, so it is removed first.
                if output_key in f:
                    del f[output_key]
                f.create_dataset(output_key, data=segmentation, compression="gzip")

        print(f"Saved segmentation to {output_path}.")
464
465
def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]:
    """Determine the tile shape and halo depending on the available VRAM.

    For 2d inference a fixed tiling is returned. Otherwise the tiling is chosen based on
    the memory of the first CUDA device, with fixed fallbacks for MPS and for the CPU.
    The chosen tiling is printed, except for the 2d case.

    Args:
        is_2d: Whether to return tiling settings for 2d inference.

    Returns:
        The default tiling settings for the available computational resources.

    Raises:
        NotImplementedError: If a CUDA device with less than 10 GB VRAM is used.
    """
    if is_2d:
        return {
            "tile": {"x": 768, "y": 768, "z": 1},
            "halo": {"x": 128, "y": 128, "z": 0},
        }

    if torch.cuda.is_available():
        # Determine the GPU RAM and derive a suitable tiling.
        vram = torch.cuda.get_device_properties(0).total_memory / 1e9

        # The presets are ordered from the largest to the smallest VRAM requirement (in GB).
        default_halo = {"x": 64, "y": 64, "z": 16}
        presets = (
            (80, {"x": 640, "y": 640, "z": 80}, default_halo),
            (40, {"x": 512, "y": 512, "z": 64}, default_halo),
            (20, {"x": 352, "y": 352, "z": 48}, default_halo),
            # Choose a smaller halo in z.
            (10, {"x": 256, "y": 256, "z": 32}, {"x": 64, "y": 64, "z": 8}),
        )
        for min_vram, tile, halo in presets:
            if vram >= min_vram:
                tiling = {"tile": tile, "halo": halo}
                break
        else:
            raise NotImplementedError(f"Infererence with a GPU with {vram} GB VRAM is not supported.")

        print(f"Determined tile size for CUDA: {tiling}")
        return tiling

    if torch.backends.mps.is_available():  # Check for Apple Silicon (MPS)
        tiling = {
            "tile": {"x": 256, "y": 256, "z": 16},
            "halo": {"x": 16, "y": 16, "z": 4},
        }
        print(f"Determined tile size for MPS: {tiling}")
        return tiling

    # I am not sure what is reasonable on a cpu. For now choosing very small tiling.
    # (This will not work well on a CPU in any case.)
    tiling = {
        "tile": {"x": 96, "y": 96, "z": 16},
        "halo": {"x": 16, "y": 16, "z": 4},
    }
    print(f"Determining default tiling for CPU: {tiling}")
    return tiling
518
519
def parse_tiling(
    tile_shape: Optional[Tuple[int, int, int]],
    halo: Optional[Tuple[int, int, int]],
    is_2d: bool = False,
) -> Dict[str, Dict[str, int]]:
    """Helper function to parse tiling parameter input from the command line.

    Args:
        tile_shape: The tile shape, given in the order (z, y, x). If None the default tile shape is used.
        halo: The halo, given in the order (z, y, x). If None the default halo is used.
        is_2d: Whether to return tiling for a 2d model.

    Returns:
        The tiling specification.
    """
    # Only determine the default tiling if it is needed. This avoids the hardware check
    # (which prints and may fail for unsupported GPUs) when the tiling is fully specified.
    default_tiling = None
    if tile_shape is None or halo is None:
        default_tiling = get_default_tiling(is_2d=is_2d)

    if tile_shape is None:
        tile_shape = default_tiling["tile"]
    else:
        assert len(tile_shape) == 3
        tile_shape = dict(zip("zyx", tile_shape))

    if halo is None:
        halo = default_tiling["halo"]
    else:
        assert len(halo) == 3
        halo = dict(zip("zyx", halo))

    return {"tile": tile_shape, "halo": halo}
552
553
554#
555# Utils for post-processing.
556#
557
558
def apply_size_filter(
    segmentation: np.ndarray,
    min_size: int,
    verbose: bool = False,
    block_shape: Tuple[int, int, int] = (128, 256, 256),
) -> np.ndarray:
    """Remove objects that are smaller than a minimal size from the segmentation.

    The filtered objects are set to zero in place, i.e. the input array is modified.

    Args:
        segmentation: The segmentation.
        min_size: The minimal object size in pixels. If it is zero, no filtering is applied.
        verbose: Whether to print runtimes.
        block_shape: Block shape for parallelizing the operations.

    Returns:
        The size filtered segmentation.
    """
    if min_size == 0:
        return segmentation

    start = time.time()
    # For 2d data only the last two entries of a 3d block shape are used.
    use_2d_blocks = segmentation.ndim == 2 and len(block_shape) == 3
    blocks = block_shape[1:] if use_2d_blocks else block_shape

    ids, sizes = parallel.unique(segmentation, return_counts=True, block_shape=blocks, verbose=verbose)
    too_small = ids[sizes < min_size]
    discard = np.isin(segmentation, too_small)
    segmentation[discard] = 0

    if verbose:
        print("Size filter in", time.time() - start, "s")
    return segmentation
589
590
def _postprocess_seg_3d(seg, area_threshold=1000, iterations=4, iterations_3d=8):
    """Fill holes and close gaps for each object in the segmentation.

    Each object is processed within its bounding box: small holes are removed, then a
    binary closing with the default structuring element and a second closing with a
    (1, 3, 3) structuring element are applied. The segmentation is modified in place.
    """
    # Structuring element that only acts within the last two axes (XY plane).
    plane_structure = np.ones((1, 3, 3))

    for region in regionprops(seg):
        # Get bounding box and mask.
        starts, stops = region.bbox[:3], region.bbox[3:]
        bb = tuple(slice(begin, end) for begin, end in zip(starts, stops))
        crop = seg[bb]  # This is a view, so writing to it updates the segmentation.
        mask = crop == region.label

        # Fill small holes and apply closing.
        # The closing results are combined with the mask so that the object never shrinks.
        mask = remove_small_holes(mask, area_threshold=area_threshold)
        closed = binary_closing(mask, iterations=iterations)
        mask = np.logical_or(closed, mask)
        closed_in_plane = binary_closing(mask, iterations=iterations_3d, structure=plane_structure)
        mask = np.logical_or(closed_in_plane, mask)
        crop[mask] = region.label

    return seg
610
611
612#
613# Utils for torch device.
614#
615
def _get_default_device():
    """Return the default device: 'cuda' if available, then 'mps', otherwise 'cpu'.

    In GitHub Actions the CPU is always used.
    """
    # Check that we're in CI and use the CPU if we are.
    # Otherwise the tests may run out of memory on MAC if MPS is used.
    in_ci = os.getenv("GITHUB_ACTIONS") == "true"
    if in_ci:
        return "cpu"

    # Use cuda enabled gpu if it's available.
    if torch.cuda.is_available():
        return "cuda"

    # As second priority use mps.
    # See https://pytorch.org/docs/stable/notes/mps.html for details
    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
        return "mps"

    # Use the CPU as fallback.
    return "cpu"
632
633
def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
    """Get the torch device.

    If no device is passed the default device for your system is used.
    Else it will be checked if the device you have passed is supported.

    Args:
        device: The input device. Can be None or 'auto' to select the default device,
            a device name, optionally with an index (e.g. 'cuda:0'), or a torch.device.

    Returns:
        The device. A device that was passed explicitly is returned unchanged.

    Raises:
        RuntimeError: If the device type is not supported or its backend is not available.
    """
    if device is None or device == "auto":
        return _get_default_device()

    # Only the device type is relevant for the check, so a device index
    # as in 'cuda:0' is stripped from device names.
    if isinstance(device, str):
        device_type = device.split(":", 1)[0].lower()
    else:
        device_type = device.type.lower()

    if device_type == "cuda":
        if not torch.cuda.is_available():
            raise RuntimeError("PyTorch CUDA backend is not available.")
    elif device_type == "mps":
        if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
            raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
    elif device_type != "cpu":  # cpu is always available
        raise RuntimeError(f"Unsupported device: {device}. Please choose from 'cpu', 'cuda', or 'mps'.")
    return device
def get_prediction( input_volume: Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str]], tiling: Optional[Dict[str, Dict[str, int]]], model_path: Optional[str] = None, model: Optional[torch.nn.modules.module.Module] = None, verbose: bool = True, with_channels: bool = False, channels_to_standardize: Optional[List[int]] = None, mask: Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str], NoneType] = None, prediction: Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str], NoneType] = None, devices: Optional[List[str]] = None, preprocess: Optional[<built-in function callable>] = None) -> Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str]]:
def get_prediction(
    input_volume: ArrayLike,  # [z, y, x]
    tiling: Optional[Dict[str, Dict[str, int]]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    verbose: bool = True,
    with_channels: bool = False,
    channels_to_standardize: Optional[List[int]] = None,
    mask: Optional[ArrayLike] = None,
    prediction: Optional[ArrayLike] = None,
    devices: Optional[List[str]] = None,
    preprocess: Optional[callable] = None,
) -> ArrayLike:
    """Run prediction on a given volume.

    This function will automatically choose the correct prediction implementation,
    depending on the model type.

    Args:
        input_volume: The input volume to predict on.
        tiling: The tiling configuration for the prediction.
            If None, the default tiling from 'get_default_tiling' is used.
        model_path: The path to the model checkpoint if 'model' is not provided.
        model: Pre-loaded model. Either model_path or model is required.
        verbose: Whether to print timing information.
        with_channels: Whether to predict with channels.
        channels_to_standardize: List of channels to standardize. Defaults to None,
            in which case all channels are standardized.
        mask: Optional binary mask. If given, the prediction will only be run in
            the foreground region of the mask.
        prediction: An array like object for writing the prediction.
            If not given, the prediction will be computed in memory.
        devices: The devices for running prediction. If not given will use the GPU
            if available, otherwise the CPU.
        preprocess: Optional preprocessing function that is passed on to the torch-em prediction.

    Returns:
        The predicted volume.

    Raises:
        ValueError: If neither 'model_path' nor 'model' is given.
        NotImplementedError: If 'model_path' points to a bioimage.io model (.zip file).
    """
    # Make sure that either the model path or the model is passed.
    if model is None and model_path is None:
        raise ValueError("Either 'model_path' or 'model' must be provided.")

    # Prediction with the bioimage.io library is not supported. Fail early, before any
    # expensive work (e.g. normalizing the whole volume) is done.
    is_bioimageio = model is None and model_path.endswith(".zip")
    if is_bioimageio:
        raise NotImplementedError(
            f"Prediction with bioimage.io models is not supported, got the model path {model_path}."
        )

    if tiling is None:
        tiling = get_default_tiling()

    # Normalize the whole input volume if it is a numpy array.
    # Otherwise we have a zarr array or similar as input, and can't normalize it en-block.
    # Normalization will be applied later per block in this case.
    if isinstance(input_volume, np.ndarray):
        input_volume = _preprocess(input_volume, with_channels, channels_to_standardize)

    # torch_em expects the root folder of a checkpoint path instead of the checkpoint itself.
    if model is None and model_path.endswith("best.pt"):
        model_path = os.path.split(model_path)[0]

    # Shrink the tile by the halo on both sides and keep the halo unchanged.
    # NOTE(review): presumably 'tile' is the full tile size including the halo, while the
    # torch-em prediction takes the inner block shape - confirm.
    halo = tiling["halo"]
    updated_tiling = {
        "tile": {dim: size - 2 * halo[dim] for dim, size in tiling["tile"].items()},
        "halo": halo,
    }

    # Run prediction with the torch-em library.
    return get_prediction_torch_em(
        input_volume, updated_tiling, model_path, model, verbose, with_channels,
        mask=mask, prediction=prediction, devices=devices, preprocess=preprocess,
    )

Run prediction on a given volume.

This function will automatically choose the correct prediction implementation, depending on the model type.

Arguments:
  • input_volume: The input volume to predict on.
  • model_path: The path to the model checkpoint if 'model' is not provided.
  • model: Pre-loaded model. Either model_path or model is required.
  • tiling: The tiling configuration for the prediction.
  • verbose: Whether to print timing information.
  • with_channels: Whether to predict with channels.
  • channels_to_standardize: List of channels to standardize. Defaults to None.
  • mask: Optional binary mask. If given, the prediction will only be run in the foreground region of the mask.
  • prediction: An array like object for writing the prediction. If not given, the prediction will be computed in memory.
  • devices: The devices for running prediction. If not given will use the GPU if available, otherwise the CPU.
Returns:

The predicted volume.

def get_prediction_torch_em( input_volume: Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str]], tiling: Dict[str, Dict[str, int]], model_path: Optional[str] = None, model: Optional[torch.nn.modules.module.Module] = None, verbose: bool = True, with_channels: bool = False, mask: Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str], NoneType] = None, prediction: Union[Buffer, numpy._typing._array_like._SupportsArray[numpy.dtype[Any]], numpy._typing._nested_sequence._NestedSequence[numpy._typing._array_like._SupportsArray[numpy.dtype[Any]]], complex, bytes, str, numpy._typing._nested_sequence._NestedSequence[complex | bytes | str], NoneType] = None, devices: Optional[List[str]] = None, preprocess: Optional[<built-in function callable>] = None) -> numpy.ndarray:
def get_prediction_torch_em(
    input_volume: ArrayLike,  # [z, y, x]
    tiling: Dict[str, Dict[str, int]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    verbose: bool = True,
    with_channels: bool = False,
    mask: Optional[ArrayLike] = None,
    prediction: Optional[ArrayLike] = None,
    devices: Optional[List[str]] = None,
    preprocess: Optional[callable] = None,
) -> np.ndarray:
    """Run prediction using torch-em on a given volume.

    Args:
        input_volume: The input volume to predict on.
        tiling: The tiling configuration for the prediction. Must contain the
            keys "tile" and "halo", each mapping the axis names "z", "y" and "x"
            to the respective size.
        model_path: The path to the model checkpoint if 'model' is not provided.
            A directory is loaded as a torch_em checkpoint,
            anything else is loaded with 'torch.load'.
        model: Pre-loaded model. Either model_path or model is required.
        verbose: Whether to print timing information.
        with_channels: Whether to predict with channels.
        mask: Optional binary mask. If given, the prediction will only be run in
            the foreground region of the mask.
        prediction: An array like object for writing the prediction.
            If not given, the prediction will be computed in memory.
        devices: The devices for running prediction. If not given will use the GPU
            if available, otherwise the CPU.
        preprocess: Optional preprocessing function that is passed on to 'predict_with_halo'.
            If not given, the inputs are standardized unless the input volume is a numpy array,
            in which case no preprocessing is applied.

    Returns:
        The predicted volume.

    Raises:
        ValueError: If neither 'model' nor 'model_path' is given.
    """
    if model is None and model_path is None:
        raise ValueError("Either 'model_path' or 'model' must be provided.")

    # Block shape and halo must follow the axis order of the input volume: [z, y, x].
    block_shape = [tiling["tile"][axis] for axis in "zyx"]
    halo = [tiling["halo"][axis] for axis in "zyx"]

    t0 = time.time()
    if devices is None:
        devices = ["cuda" if torch.cuda.is_available() else "cpu"]

    # Suppress warning when loading the model.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        if model is None:
            if os.path.isdir(model_path):  # Load the model from a torch_em checkpoint.
                model = torch_em.util.load_model(checkpoint=model_path, device=devices[0])
            else:  # Load the model directly from a serialized pytorch model.
                # NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
                model = torch.load(model_path, weights_only=False)

    # Deal with the 2D segmentation case: drop the z-entry of the tiling.
    if len(input_volume.shape) == 2:
        block_shape = block_shape[1:]
        halo = halo[1:]

    if mask is not None:
        if verbose:
            print("Run prediction with mask.")
        mask = mask.astype("bool")

    # Without a user-provided preprocessing only inputs that are not numpy arrays are standardized.
    # NOTE(review): presumably numpy inputs were already normalized by the caller - confirm.
    if preprocess is None and not isinstance(input_volume, np.ndarray):
        preprocess = torch_em.transform.raw.standardize

    # Run prediction with the model.
    with torch.no_grad():
        prediction = predict_with_halo(
            input_volume, model, gpu_ids=devices,
            block_shape=block_shape, halo=halo,
            preprocess=preprocess, with_channels=with_channels, mask=mask,
            output=prediction,
        )
    if verbose:
        print("Prediction time in", time.time() - t0, "s")
    return prediction

Run prediction using torch-em on a given volume.

Arguments:
  • input_volume: The input volume to predict on.
  • model_path: The path to the model checkpoint if 'model' is not provided.
  • model: Pre-loaded model. Either model_path or model is required.
  • tiling: The tiling configuration for the prediction.
  • verbose: Whether to print timing information.
  • with_channels: Whether to predict with channels.
  • mask: Optional binary mask. If given, the prediction will only be run in the foreground region of the mask.
  • prediction: An array like object for writing the prediction. If not given, the prediction will be computed in memory.
  • devices: The devices for running prediction. If not given will use the GPU if available, otherwise the CPU.
Returns:

The predicted volume.

def inference_helper( input_path: str, output_root: str, segmentation_function: <built-in function callable>, data_ext: str = '.mrc', extra_input_path: Optional[str] = None, extra_input_ext: str = '.tif', mask_input_path: Optional[str] = None, mask_input_ext: str = '.tif', force: bool = False, output_key: Optional[str] = None, model_resolution: Optional[Tuple[float, float, float]] = None, scale: Optional[Tuple[float, float, float]] = None, allocate_output: bool = False) -> None:
def inference_helper(
    input_path: str,
    output_root: str,
    segmentation_function: callable,
    data_ext: str = ".mrc",
    extra_input_path: Optional[str] = None,
    extra_input_ext: str = ".tif",
    mask_input_path: Optional[str] = None,
    mask_input_ext: str = ".tif",
    force: bool = False,
    output_key: Optional[str] = None,
    model_resolution: Optional[Tuple[float, float, float]] = None,
    scale: Optional[Tuple[float, float, float]] = None,
    allocate_output: bool = False,
) -> None:
    """Helper function to run segmentation for mrc files.

    Args:
        input_path: The path to the input data.
            Can either be a folder. In this case all mrc files below the folder will be segmented.
            Or can be a single mrc file. In this case only this mrc file will be segmented.
        output_root: The path to the output directory where the segmentation results will be saved.
        segmentation_function: The function performing the segmentation.
            It is called with the input volume as the only positional argument and with
            the keyword arguments 'mask' and 'scale' (plus 'output' if 'allocate_output' is set).
            It must return only the segmentation.
            If you want to pass additional arguments to this function then use 'functools.partial'.
        data_ext: File extension for the image data. By default '.mrc' is used.
        extra_input_path: Filepath to extra inputs that need to be concatenated to the raw data loaded from mrc.
            This enables cristae segmentation with an extra mito channel.
        extra_input_ext: File extension for the extra inputs (by default .tif).
        mask_input_path: Filepath to mask(s) that will be used to restrict the segmentation.
        mask_input_ext: File extension for the mask inputs (by default .tif).
        force: Whether to rerun segmentation for output files that are already present.
            For hdf5 outputs an existing dataset stored under 'output_key' is replaced.
        output_key: Output key for the prediction. If None the prediction is written to a tif file,
            otherwise it is written to an hdf5 file under this key.
        model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction.
            If given, the scaling factor will automatically be determined based on the voxel_size of the input data.
        scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'.
        allocate_output: Whether to allocate the output for the segmentation function.

    Raises:
        ValueError: If both 'scale' and 'model_resolution' are given.
    """
    if (scale is not None) and (model_resolution is not None):
        raise ValueError("You must not provide both 'scale' and 'model_resolution' arguments.")

    # Get the input files. If input_path is a folder then this will load all
    # the mrc files beneath it. Otherwise we assume this is an mrc file already
    # and just return the path to this mrc file.
    input_files, input_root = _get_file_paths(input_path, data_ext)

    # Load extra inputs if the extra_input_path was specified.
    if extra_input_path is None:
        extra_files = None
    else:
        extra_files, _ = _get_file_paths(extra_input_path, extra_input_ext)
        assert len(input_files) == len(extra_files)

    # Load the masks if they were specified.
    if mask_input_path is None:
        mask_files = None
    else:
        mask_files, _ = _get_file_paths(mask_input_path, mask_input_ext)
        assert len(input_files) == len(mask_files)

    for i, img_path in tqdm(enumerate(input_files), total=len(input_files), desc="Processing files"):
        # Determine the output file name.
        input_folder, input_name = os.path.split(img_path)

        if output_key is None:
            fname = os.path.splitext(input_name)[0] + "_prediction.tif"
        else:
            fname = os.path.splitext(input_name)[0] + "_prediction.h5"

        if input_root is None:
            output_path = os.path.join(output_root, fname)
        else:  # If we have nested input folders then we preserve the folder structure in the output.
            rel_folder = os.path.relpath(input_folder, input_root)
            output_path = os.path.join(output_root, rel_folder, fname)

        # Check if the output path is already present.
        # If it is we skip the prediction, unless force was set to true.
        if os.path.exists(output_path) and not force:
            if output_key is None:
                continue
            else:
                with open_file(output_path, "r") as f:
                    if output_key in f:
                        continue

        # Load the input volume. If we have extra_files then this concatenates the
        # data across a new first axis (= channel axis).
        input_volume = _load_input(img_path, extra_files, i)
        # Load the mask (if given).
        mask = None if mask_files is None else imageio.imread(mask_files[i])

        # Determine the scale factor: a fixed 'scale' takes precedence, otherwise it is
        # derived from 'model_resolution' and the data. If neither was passed it is None.
        if scale is not None:
            this_scale = scale
        elif model_resolution is not None:
            this_scale = _derive_scale(img_path, model_resolution)
        else:
            this_scale = None

        # Run the segmentation.
        if allocate_output:
            segmentation = np.zeros(input_volume.shape, dtype="uint32")
            segmentation_function(input_volume, output=segmentation, mask=mask, scale=this_scale)
        else:
            segmentation = segmentation_function(input_volume, mask=mask, scale=this_scale)

        # Write the result to tif or h5.
        os.makedirs(os.path.split(output_path)[0], exist_ok=True)

        if output_key is None:
            imageio.imwrite(output_path, segmentation, compression="zlib")
        else:
            with open_file(output_path, "a") as f:
                # The key can only be present here if 'force' was set (otherwise the file was skipped).
                # Creating a dataset under an existing name fails, so remove the stale result first.
                if output_key in f:
                    del f[output_key]
                f.create_dataset(output_key, data=segmentation, compression="gzip")

        print(f"Saved segmentation to {output_path}.")

Helper function to run segmentation for mrc files.

Arguments:
  • input_path: The path to the input data. Can either be a folder. In this case all mrc files below the folder will be segmented. Or can be a single mrc file. In this case only this mrc file will be segmented.
  • output_root: The path to the output directory where the segmentation results will be saved.
  • segmentation_function: The function performing the segmentation. This function must take the input_volume as the only argument and must return only the segmentation. If you want to pass additional arguments to this function then use 'functools.partial'.
  • data_ext: File extension for the image data. By default '.mrc' is used.
  • extra_input_path: Filepath to extra inputs that need to be concatenated to the raw data loaded from mrc. This enables cristae segmentation with an extra mito channel.
  • extra_input_ext: File extension for the extra inputs (by default .tif).
  • mask_input_path: Filepath to mask(s) that will be used to restrict the segmentation.
  • mask_input_ext: File extension for the mask inputs (by default .tif).
  • force: Whether to rerun segmentation for output files that are already present.
  • output_key: Output key for the prediction. If None, the prediction is written to a tif file; otherwise it is written to an hdf5 file under this key.
  • model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction. If given, the scaling factor will automatically be determined based on the voxel_size of the input data.
  • scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'.
  • allocate_output: Whether to allocate the output for the segmentation function.
def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]:
def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]:
    """Derive default tile and halo sizes from the available compute device.

    Args:
        is_2d: Whether to return tiling settings for 2d inference.

    Returns:
        A dictionary with the entries "tile" and "halo", each mapping the axes
        "x", "y" and "z" to a size.

    Raises:
        NotImplementedError: If a CUDA device with less than 10 GB of VRAM is found.
    """
    # The 2d settings do not depend on the hardware.
    if is_2d:
        return {
            "tile": {"x": 768, "y": 768, "z": 1},
            "halo": {"x": 128, "y": 128, "z": 0},
        }

    if torch.cuda.is_available():
        # Total memory of the first GPU in GB (decimal).
        vram = torch.cuda.get_device_properties(0).total_memory / 1e9

        # Presets ordered from largest to smallest GPU:
        # (minimal VRAM, tile edge in x and y, tile depth in z, halo depth in z).
        # The smallest preset uses a smaller halo in z.
        presets = (
            (80, 640, 80, 16),
            (40, 512, 64, 16),
            (20, 352, 48, 16),
            (10, 256, 32, 8),
        )
        chosen = next((preset for preset in presets if vram >= preset[0]), None)
        if chosen is None:
            raise NotImplementedError(f"Infererence with a GPU with {vram} GB VRAM is not supported.")

        _, edge, depth, halo_depth = chosen
        tiling = {
            "tile": {"x": edge, "y": edge, "z": depth},
            "halo": {"x": 64, "y": 64, "z": halo_depth},
        }
        print(f"Determined tile size for CUDA: {tiling}")
        return tiling

    if torch.backends.mps.is_available():  # Apple Silicon (MPS)
        tiling = {
            "tile": {"x": 256, "y": 256, "z": 16},
            "halo": {"x": 16, "y": 16, "z": 4},
        }
        print(f"Determined tile size for MPS: {tiling}")
        return tiling

    # CPU fallback: it is unclear what is reasonable here, so a very small tiling is used.
    # (Inference will not work well on a CPU in any case.)
    tiling = {
        "tile": {"x": 96, "y": 96, "z": 16},
        "halo": {"x": 16, "y": 16, "z": 4},
    }
    print(f"Determining default tiling for CPU: {tiling}")
    return tiling

Determine the tile shape and halo depending on the available VRAM.

Arguments:
  • is_2d: Whether to return tiling settings for 2d inference.
Returns:

The default tiling settings for the available computational resources.

def parse_tiling( tile_shape: Tuple[int, int, int], halo: Tuple[int, int, int], is_2d: bool = False) -> Dict[str, Dict[str, int]]:
def parse_tiling(
    tile_shape: Optional[Tuple[int, int, int]],
    halo: Optional[Tuple[int, int, int]],
    is_2d: bool = False,
) -> Dict[str, Dict[str, int]]:
    """Helper function to parse tiling parameter input from the command line.

    Args:
        tile_shape: The tile shape, given in the order (z, y, x).
            If None the default tile shape is used.
        halo: The halo, given in the order (z, y, x). If None the default halo is used.
        is_2d: Whether to return tiling for a 2d model.
            Only used for determining the default settings.

    Returns:
        The tiling specification.
    """
    # Only derive the defaults if they are needed. The default settings depend on the
    # hardware and 'get_default_tiling' can raise for unsupported GPUs, which must not
    # prevent using an explicitly specified tiling.
    default_tiling = None
    if tile_shape is None or halo is None:
        default_tiling = get_default_tiling(is_2d=is_2d)

    if tile_shape is None:
        tile_shape = default_tiling["tile"]
    else:
        assert len(tile_shape) == 3
        tile_shape = dict(zip("zyx", tile_shape))

    if halo is None:
        halo = default_tiling["halo"]
    else:
        assert len(halo) == 3
        halo = dict(zip("zyx", halo))

    tiling = {"tile": tile_shape, "halo": halo}
    return tiling

Helper function to parse tiling parameter input from the command line.

Arguments:
  • tile_shape: The tile shape. If None the default tile shape is used.
  • halo: The halo. If None the default halo is used.
  • is_2d: Whether to return tiling for a 2d model.
Returns:

The tiling specification.

def apply_size_filter( segmentation: numpy.ndarray, min_size: int, verbose: bool = False, block_shape: Tuple[int, int, int] = (128, 256, 256)) -> numpy.ndarray:
def apply_size_filter(
    segmentation: np.ndarray,
    min_size: int,
    verbose: bool = False,
    block_shape: Tuple[int, int, int] = (128, 256, 256),
) -> np.ndarray:
    """Remove all objects smaller than a minimal size from the segmentation.

    The segmentation is modified in-place: the pixels of all ids that occur
    fewer than 'min_size' times are set to zero.

    Args:
        segmentation: The segmentation.
        min_size: The minimal object size in pixels. Passing 0 disables the filter.
        verbose: Whether to print runtimes.
        block_shape: Block shape for parallelizing the operations.

    Returns:
        The size filtered segmentation (the same array that was passed in).
    """
    # Nothing to filter.
    if min_size == 0:
        return segmentation

    start = time.time()

    # For 2d data the leading entry of a 3d block shape is dropped.
    needs_2d_blocks = segmentation.ndim == 2 and len(block_shape) == 3
    blocking = block_shape[1:] if needs_2d_blocks else block_shape

    # Count the pixels per object id and collect the ids below the size threshold.
    object_ids, pixel_counts = parallel.unique(
        segmentation, return_counts=True, block_shape=blocking, verbose=verbose
    )
    small_ids = object_ids[pixel_counts < min_size]

    discard_mask = np.isin(segmentation, small_ids)
    segmentation[discard_mask] = 0

    if verbose:
        print("Size filter in", time.time() - start, "s")
    return segmentation

Apply size filter to the segmentation to remove small objects.

Arguments:
  • segmentation: The segmentation.
  • min_size: The minimal object size in pixels.
  • verbose: Whether to print runtimes.
  • block_shape: Block shape for parallelizing the operations.
Returns:

The size filtered segmentation.

def get_device( device: Union[torch.device, str, NoneType] = None) -> Union[str, torch.device]:
def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
    """Get the torch device.

    If no device is passed the default device for your system is used.
    Else it will be checked if the device you have passed is supported.

    Args:
        device: The input device. Can be None or "auto" to select the default device,
            a device string such as "cpu", "cuda" or "mps" (optionally with a device
            index, e.g. "cuda:0"), or a torch.device.

    Returns:
        The device. A device that was passed in is returned unchanged.

    Raises:
        RuntimeError: If the requested backend is not available or the device type is unsupported.
    """
    if device is None or device == "auto":
        return _get_default_device()

    raw_type = device if isinstance(device, str) else device.type
    # Strip an optional device index ("cuda:0" -> "cuda"), so that indexed device strings
    # are validated in the same way as the corresponding torch.device objects.
    device_type = raw_type.split(":", 1)[0].lower()

    if device_type == "cuda":
        if not torch.cuda.is_available():
            raise RuntimeError("PyTorch CUDA backend is not available.")
    elif device_type == "mps":
        if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
            raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
    elif device_type != "cpu":  # cpu is always available
        raise RuntimeError(f"Unsupported device: {device}. Please choose from 'cpu', 'cuda', or 'mps'.")
    return device

Get the torch device.

If no device is passed the default device for your system is used. Else it will be checked if the device you have passed is supported.

Arguments:
  • device: The input device.
Returns:

The device.