synapse_net.inference.util

  1import os
  2import time
  3import warnings
  4from glob import glob
  5from typing import Dict, List, Optional, Tuple, Union
  6
  7# # Suppress annoying import warnings.
  8# with warnings.catch_warnings():
  9#     warnings.simplefilter("ignore")
 10#     import bioimageio.core
 11
 12import imageio.v3 as imageio
 13import elf.parallel as parallel
 14import mrcfile
 15import numpy as np
 16import torch
 17import torch_em
 18# import xarray
 19
 20from elf.io import open_file
 21from scipy.ndimage import binary_closing
 22from skimage.measure import regionprops
 23from skimage.morphology import remove_small_holes
 24from skimage.transform import rescale, resize
 25from torch_em.util.prediction import predict_with_halo
 26from tqdm import tqdm
 27
 28
 29#
 30# Utils for prediction.
 31#
 32
 33
 34class _Scaler:
 35    def __init__(self, scale, verbose):
 36        self.verbose = verbose
 37        self._original_shape = None
 38
 39        if scale is None:
 40            self.scale = None
 41            return
 42
 43        # Convert scale to a NumPy array (ensures consistency)
 44        scale = np.atleast_1d(scale).astype(np.float64)
 45
 46        # Validate scale values
 47        if not np.issubdtype(scale.dtype, np.number):
 48            raise TypeError(f"Scale contains non-numeric values: {scale}")
 49
 50        # Check if scaling is effectively identity (1.0 in all dimensions)
 51        if np.allclose(scale, 1.0, atol=1e-3):
 52            self.scale = None
 53        else:
 54            self.scale = scale
 55
 56    def scale_input(self, input_volume, is_segmentation=False):
 57        t0 = time.time()
 58        if self.scale is None:
 59            return input_volume
 60
 61        if self._original_shape is None:
 62            self._original_shape = input_volume.shape
 63        elif self._original_shape != input_volume.shape:
 64            raise RuntimeError(
 65                "Scaler was called with different input shapes. "
 66                "This is not supported, please create a new instance of the class for it."
 67            )
 68
 69        if is_segmentation:
 70            input_volume = rescale(
 71                input_volume, self.scale, preserve_range=True, order=0, anti_aliasing=False,
 72            ).astype(input_volume.dtype)
 73        else:
 74            input_volume = rescale(input_volume, self.scale, preserve_range=True).astype(input_volume.dtype)
 75
 76        if self.verbose:
 77            print("Rescaled volume from", self._original_shape, "to", input_volume.shape, "in", time.time() - t0, "s")
 78        return input_volume
 79
 80    def rescale_output(self, output, is_segmentation):
 81        t0 = time.time()
 82        if self.scale is None:
 83            return output
 84
 85        assert self._original_shape is not None
 86        out_shape = self._original_shape
 87        if output.ndim > len(out_shape):
 88            assert output.ndim == len(out_shape) + 1
 89            out_shape = (output.shape[0],) + out_shape
 90
 91        if is_segmentation:
 92            output = resize(output, out_shape, preserve_range=True, order=0, anti_aliasing=False).astype(output.dtype)
 93        else:
 94            output = resize(output, out_shape, preserve_range=True).astype(output.dtype)
 95
 96        if self.verbose:
 97            print("Resized prediction back to original shape", output.shape, "in", time.time() - t0, "s")
 98
 99        return output
100
101
102def get_prediction(
103    input_volume: np.ndarray,  # [z, y, x]
104    tiling: Optional[Dict[str, Dict[str, int]]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
105    model_path: Optional[str] = None,
106    model: Optional[torch.nn.Module] = None,
107    verbose: bool = True,
108    with_channels: bool = False,
109    channels_to_standardize: Optional[List[int]] = None,
110    mask: Optional[np.ndarray] = None,
111) -> np.ndarray:
112    """Run prediction on a given volume.
113
114    This function will automatically choose the correct prediction implementation,
115    depending on the model type.
116
117    Args:
118        input_volume: The input volume to predict on.
119        model_path: The path to the model checkpoint if 'model' is not provided.
120        model: Pre-loaded model. Either model_path or model is required.
121        tiling: The tiling configuration for the prediction.
122        verbose: Whether to print timing information.
123        with_channels: Whether to predict with channels.
124        channels_to_standardize: List of channels to standardize. If None, all channels are standardized.
125        mask: Optional binary mask. If given, the prediction will only be run in
126            the foreground region of the mask.
127
128    Returns:
129        The predicted volume.
130    """
131    # make sure either model path or model is passed
132    if model is None and model_path is None:
133        raise ValueError("Either 'model_path' or 'model' must be provided.")
134
135    if model is not None:
136        is_bioimageio = None
137    else:
138        is_bioimageio = model_path.endswith(".zip")
139
140    if tiling is None:
141        tiling = get_default_tiling()
142
143    # We standardize the data for the whole volume beforehand.
144    # If we have channels then the standardization is done independently per channel.
145    if with_channels:
146        input_volume = input_volume.astype(np.float32, copy=False)
147        # TODO Check that this is the correct axis.
148        if channels_to_standardize is None:  # assume all channels
149            channels_to_standardize = range(input_volume.shape[0])
150        for ch in channels_to_standardize:
151            input_volume[ch] = torch_em.transform.raw.standardize(input_volume[ch])
152    else:
153        input_volume = torch_em.transform.raw.standardize(input_volume)
154
155    # Run prediction with the bioimage.io library.
156    if is_bioimageio:
157        if mask is not None:
158            raise NotImplementedError
159        raise NotImplementedError
160
161    # Run prediction with the torch-em library.
162    else:
163        if model is None:
164            # torch_em expects the root folder of a checkpoint path instead of the checkpoint itself.
165            if model_path.endswith("best.pt"):
166                model_path = os.path.split(model_path)[0]
167        # print(f"tiling {tiling}")
168        # Create updated_tiling with the same structure
169        updated_tiling = {
170            "tile": {},
171            "halo": tiling["halo"]  # Keep the halo part unchanged
172        }
173        # Update tile dimensions
174        for dim in tiling["tile"]:
175            updated_tiling["tile"][dim] = tiling["tile"][dim] - 2 * tiling["halo"][dim]
176        # print(f"updated_tiling {updated_tiling}")
177        pred = get_prediction_torch_em(
178            input_volume, updated_tiling, model_path, model, verbose, with_channels, mask=mask
179        )
180
181    return pred
182
183
184def get_prediction_torch_em(
185    input_volume: np.ndarray,  # [z, y, x]
186    tiling: Dict[str, Dict[str, int]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
187    model_path: Optional[str] = None,
188    model: Optional[torch.nn.Module] = None,
189    verbose: bool = True,
190    with_channels: bool = False,
191    mask: Optional[np.ndarray] = None,
192) -> np.ndarray:
193    """Run prediction using torch-em on a given volume.
194
195    Args:
196        input_volume: The input volume to predict on.
197        model_path: The path to the model checkpoint if 'model' is not provided.
198        model: Pre-loaded model. Either model_path or model is required.
199        tiling: The tiling configuration for the prediction.
200        verbose: Whether to print timing information.
201        with_channels: Whether to predict with channels.
202        mask: Optional binary mask. If given, the prediction will only be run in
203            the foreground region of the mask.
204
205    Returns:
206        The predicted volume.
207    """
208    # get block_shape and halo
209    block_shape = [tiling["tile"]["z"], tiling["tile"]["x"], tiling["tile"]["y"]]
210    halo = [tiling["halo"]["z"], tiling["halo"]["x"], tiling["halo"]["y"]]
211
212    t0 = time.time()
213    device = "cuda" if torch.cuda.is_available() else "cpu"
214
215    # Suppress warning when loading the model.
216    with warnings.catch_warnings():
217        warnings.simplefilter("ignore")
218        if model is None:
219            if os.path.isdir(model_path):  # Load the model from a torch_em checkpoint.
220                model = torch_em.util.load_model(checkpoint=model_path, device=device)
221            else:  # Load the model directly from a serialized pytorch model.
222                model = torch.load(model_path)
223
224    # Run prediction with the model.
225    with torch.no_grad():
226
227        # Deal with 2D segmentation case
228        if len(input_volume.shape) == 2:
229            block_shape = [block_shape[1], block_shape[2]]
230            halo = [halo[1], halo[2]]
231
232        if mask is not None:
233            if verbose:
234                print("Run prediction with mask.")
235            mask = mask.astype("bool")
236
237        pred = predict_with_halo(
238            input_volume, model, gpu_ids=[device],
239            block_shape=block_shape, halo=halo,
240            preprocess=None, with_channels=with_channels, mask=mask,
241        )
242    if verbose:
243        print("Prediction time in", time.time() - t0, "s")
244    return pred
245
246
247def _get_file_paths(input_path, ext=".mrc"):
248    if not os.path.exists(input_path):
249        raise Exception(f"Input path not found {input_path}")
250
251    if os.path.isfile(input_path):
252        input_files = [input_path]
253        input_root = None
254    else:
255        input_files = sorted(glob(os.path.join(input_path, "**", f"*{ext}"), recursive=True))
256        input_root = input_path
257
258    return input_files, input_root
259
260
261def _load_input(img_path, extra_files, i):
262    # Load the input data.
263    if os.path.splitext(img_path)[-1] == ".tif":
264        input_volume = imageio.imread(img_path)
265
266    else:
267        with open_file(img_path, "r") as f:
268            # Try to automatically derive the key with the raw data.
269            keys = list(f.keys())
270            if len(keys) == 1:
271                key = keys[0]
272            elif "data" in keys:
273                key = "data"
274            elif "raw" in keys:
275                key = "raw"
276            input_volume = f[key][:]
277
278    assert input_volume.ndim in (2, 3)
279    # For now we assume this is always tif.
280    if extra_files is not None:
281        extra_input = imageio.imread(extra_files[i])
282        assert extra_input.shape == input_volume.shape
283        input_volume = np.stack([input_volume, extra_input], axis=0)
284
285    return input_volume
286
287
288def _derive_scale(img_path, model_resolution):
289    try:
290        with mrcfile.open(img_path, "r") as f:
291            voxel_size = f.voxel_size
292            if len(model_resolution) == 2:
293                voxel_size = [voxel_size.y, voxel_size.x]
294            else:
295                voxel_size = [voxel_size.z, voxel_size.y, voxel_size.x]
296
297        assert len(voxel_size) == len(model_resolution)
298        # The voxel size is given in Angstrom and we need to translate it to nanometer.
299        voxel_size = [vsize / 10 for vsize in voxel_size]
300
301        # Compute the correct scale factor.
302        scale = tuple(vsize / res for vsize, res in zip(voxel_size, model_resolution))
303        print("Rescaling the data at", img_path, "by", scale, "to match the training voxel size", model_resolution)
304
305    except Exception:
306        warnings.warn(
307            f"The voxel size could not be read from the data for {img_path}. "
308            "This data will not be scaled for prediction."
309        )
310        scale = None
311
312    return scale
313
314
315def inference_helper(
316    input_path: str,
317    output_root: str,
318    segmentation_function: callable,
319    data_ext: str = ".mrc",
320    extra_input_path: Optional[str] = None,
321    extra_input_ext: str = ".tif",
322    mask_input_path: Optional[str] = None,
323    mask_input_ext: str = ".tif",
324    force: bool = False,
325    output_key: Optional[str] = None,
326    model_resolution: Optional[Tuple[float, float, float]] = None,
327    scale: Optional[Tuple[float, float, float]] = None,
328) -> None:
329    """Helper function to run segmentation for mrc files.
330
331    Args:
332        input_path: The path to the input data.
333            Can either be a folder, in which case all mrc files below it will be segmented,
334            or a single mrc file, in which case only this file will be segmented.
335        output_root: The path to the output directory where the segmentation results will be saved.
336        segmentation_function: The function performing the segmentation.
337            This function must take the input_volume as its only positional argument, accept the keyword arguments
338            'mask' and 'scale', and return only the segmentation. To pass additional arguments use 'functools.partial'.
339        data_ext: File extension for the image data. By default '.mrc' is used.
340        extra_input_path: Filepath to extra inputs that need to be concatenated to the raw data loaded from mrc.
341            This enables cristae segmentation with an extra mito channel.
342        extra_input_ext: File extension for the extra inputs (by default .tif).
343        mask_input_path: Filepath to mask(s) that will be used to restrict the segmentation.
344        mask_input_ext: File extension for the mask inputs (by default .tif).
345        force: Whether to rerun segmentation for output files that are already present.
346        output_key: Output key for the prediction. If given, the result is saved as hdf5 under this key; otherwise as tif.
347        model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction.
348            If given, the scaling factor will automatically be determined based on the voxel_size of the input data.
349        scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'.
350    """
351    if (scale is not None) and (model_resolution is not None):
352        raise ValueError("You must not provide both 'scale' and 'model_resolution' arguments.")
353
354    # Get the input files. If input_path is a folder then this will load all
355    # the mrc files beneath it. Otherwise we assume this is an mrc file already
356    # and just return the path to this mrc file.
357    input_files, input_root = _get_file_paths(input_path, data_ext)
358
359    # Load extra inputs if the extra_input_path was specified.
360    if extra_input_path is None:
361        extra_files = None
362    else:
363        extra_files, _ = _get_file_paths(extra_input_path, extra_input_ext)
364        assert len(input_files) == len(extra_files)
365
366    # Load the masks if they were specified.
367    if mask_input_path is None:
368        mask_files = None
369    else:
370        mask_files, _ = _get_file_paths(mask_input_path, mask_input_ext)
371        assert len(input_files) == len(mask_files)
372
373    for i, img_path in tqdm(enumerate(input_files), total=len(input_files), desc="Processing files"):
374        # Determine the output file name.
375        input_folder, input_name = os.path.split(img_path)
376
377        if output_key is None:
378            fname = os.path.splitext(input_name)[0] + "_prediction.tif"
379        else:
380            fname = os.path.splitext(input_name)[0] + "_prediction.h5"
381
382        if input_root is None:
383            output_path = os.path.join(output_root, fname)
384        else:  # If we have nested input folders then we preserve the folder structure in the output.
385            rel_folder = os.path.relpath(input_folder, input_root)
386            output_path = os.path.join(output_root, rel_folder, fname)
387
388        # Check if the output path is already present.
389        # If it is we skip the prediction, unless force was set to true.
390        if os.path.exists(output_path) and not force:
391            if output_key is None:
392                continue
393            else:
394                with open_file(output_path, "r") as f:
395                    if output_key in f:
396                        continue
397
398        # Load the input volume. If we have extra_files then this concatenates the
399        # data across a new first axis (= channel axis).
400        input_volume = _load_input(img_path, extra_files, i)
401        # Load the mask (if given).
402        mask = None if mask_files is None else imageio.imread(mask_files[i])
403
404        # Determine the scale factor:
405        # If neither the 'scale' nor the 'model_resolution' argument was passed then set it to None.
406        if scale is None and model_resolution is None:
407            this_scale = None
408        elif scale is not None:   # If 'scale' was passed then use it.
409            this_scale = scale
410        else:   # Otherwise 'model_resolution' was passed, use it to derive the scaling from the data
411            assert model_resolution is not None
412            this_scale = _derive_scale(img_path, model_resolution)
413
414        # Run the segmentation.
415        segmentation = segmentation_function(input_volume, mask=mask, scale=this_scale)
416
417        # Write the result to tif or h5.
418        os.makedirs(os.path.split(output_path)[0], exist_ok=True)
419
420        if output_key is None:
421            imageio.imwrite(output_path, segmentation, compression="zlib")
422        else:
423            with open_file(output_path, "a") as f:
424                f.create_dataset(output_key, data=segmentation, compression="gzip")
425
426        print(f"Saved segmentation to {output_path}.")
427
428
429def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]:
430    """Determine the tile shape and halo depending on the available VRAM.
431
432    Args:
433        is_2d: Whether to return tiling settings for 2d inference.
434
435    Returns:
436        The default tiling settings for the available computational resources.
437    """
438    if is_2d:
439        tile = {"x": 768, "y": 768, "z": 1}
440        halo = {"x": 128, "y": 128, "z": 0}
441        return {"tile": tile, "halo": halo}
442
443    if torch.cuda.is_available():
444        # The default halo size.
445        halo = {"x": 64, "y": 64, "z": 16}
446
447        # Determine the GPU RAM and derive a suitable tiling.
448        vram = torch.cuda.get_device_properties(0).total_memory / 1e9
449
450        if vram >= 80:
451            tile = {"x": 640, "y": 640, "z": 80}
452        elif vram >= 40:
453            tile = {"x": 512, "y": 512, "z": 64}
454        elif vram >= 20:
455            tile = {"x": 352, "y": 352, "z": 48}
456        elif vram >= 10:
457            tile = {"x": 256, "y": 256, "z": 32}
458            halo = {"x": 64, "y": 64, "z": 8}  # Choose a smaller halo in z.
459        else:
460            raise NotImplementedError(f"Infererence with a GPU with {vram} GB VRAM is not supported.")
461
462        tiling = {"tile": tile, "halo": halo}
463        print(f"Determined tile size for CUDA: {tiling}")
464
465    elif torch.backends.mps.is_available():  # Check for Apple Silicon (MPS)
466        tile = {"x": 256, "y": 256, "z": 16}
467        halo = {"x": 16, "y": 16, "z": 4}
468        tiling = {"tile": tile, "halo": halo}
469        print(f"Determined tile size for MPS: {tiling}")
470
471    # It is not clear what tiling is reasonable on a CPU; for now we choose a very small one.
472    # (This will not work well on a CPU in any case.)
473    else:
474        tiling = {
475            "tile": {"x": 96, "y": 96, "z": 16},
476            "halo": {"x": 16, "y": 16, "z": 4},
477        }
478        print(f"Determining default tiling for CPU: {tiling}")
479
480    return tiling
481
482
483def parse_tiling(
484    tile_shape: Tuple[int, int, int],
485    halo: Tuple[int, int, int],
486    is_2d: bool = False,
487) -> Dict[str, Dict[str, int]]:
488    """Helper function to parse tiling parameter input from the command line.
489
490    Args:
491        tile_shape: The tile shape. If None the default tile shape is used.
492        halo: The halo. If None the default halo is used.
493        is_2d: Whether to return tiling for a 2d model.
494
495    Returns:
496        The tiling specification.
497    """
498
499    default_tiling = get_default_tiling(is_2d=is_2d)
500
501    if tile_shape is None:
502        tile_shape = default_tiling["tile"]
503    else:
504        assert len(tile_shape) == 3
505        tile_shape = dict(zip("zyx", tile_shape))
506
507    if halo is None:
508        halo = default_tiling["halo"]
509    else:
510        assert len(halo) == 3
511        halo = dict(zip("zyx", halo))
512
513    tiling = {"tile": tile_shape, "halo": halo}
514    return tiling
515
516
517#
518# Utils for post-processing.
519#
520
521
522def apply_size_filter(
523    segmentation: np.ndarray,
524    min_size: int,
525    verbose: bool = False,
526    block_shape: Tuple[int, int, int] = (128, 256, 256),
527) -> np.ndarray:
528    """Apply size filter to the segmentation to remove small objects.
529
530    Args:
531        segmentation: The segmentation.
532        min_size: The minimal object size in pixels.
533        verbose: Whether to print runtimes.
534        block_shape: Block shape for parallelizing the operations.
535
536    Returns:
537        The size filtered segmentation.
538    """
539    if min_size == 0:
540        return segmentation
541    t0 = time.time()
542    if segmentation.ndim == 2 and len(block_shape) == 3:
543        block_shape_ = block_shape[1:]
544    else:
545        block_shape_ = block_shape
546    ids, sizes = parallel.unique(segmentation, return_counts=True, block_shape=block_shape_, verbose=verbose)
547    filter_ids = ids[sizes < min_size]
548    segmentation[np.isin(segmentation, filter_ids)] = 0
549    if verbose:
550        print("Size filter in", time.time() - t0, "s")
551    return segmentation
552
553
554def _postprocess_seg_3d(seg, area_threshold=1000, iterations=4, iterations_3d=8):
555    # Structure element for 2d closing in 3d.
556    structure_element = np.ones((3, 3))  # 3x3 structure for XY plane
557    structure_3d = np.zeros((1, 3, 3))  # Only applied in the XY plane
558    structure_3d[0] = structure_element
559
560    props = regionprops(seg)
561    for prop in props:
562        # Get bounding box and mask.
563        bb = tuple(slice(start, stop) for start, stop in zip(prop.bbox[:3], prop.bbox[3:]))
564        mask = seg[bb] == prop.label
565
566        # Fill small holes and apply closing.
567        mask = remove_small_holes(mask, area_threshold=area_threshold)
568        mask = np.logical_or(binary_closing(mask, iterations=iterations), mask)
569        mask = np.logical_or(binary_closing(mask, iterations=iterations_3d, structure=structure_3d), mask)
570        seg[bb][mask] = prop.label
571
572    return seg
573
574
575#
576# Utils for torch device.
577#
578
579def _get_default_device():
580    # Check that we're in CI and use the CPU if we are.
581    # Otherwise the tests may run out of memory on MAC if MPS is used.
582    if os.getenv("GITHUB_ACTIONS") == "true":
583        return "cpu"
584    # Use cuda enabled gpu if it's available.
585    if torch.cuda.is_available():
586        device = "cuda"
587    # As second priority use mps.
588    # See https://pytorch.org/docs/stable/notes/mps.html for details
589    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
590        device = "mps"
591    # Use the CPU as fallback.
592    else:
593        device = "cpu"
594    return device
595
596
597def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
598    """Get the torch device.
599
600    If no device is passed, the default device for your system is used.
601    Otherwise, it is checked whether the device you have passed is supported.
602
603    Args:
604        device: The input device.
605
606    Returns:
607        The device.
608    """
609    if device is None or device == "auto":
610        device = _get_default_device()
611    else:
612        device_type = device if isinstance(device, str) else device.type
613        if device_type.lower() == "cuda":
614            if not torch.cuda.is_available():
615                raise RuntimeError("PyTorch CUDA backend is not available.")
616        elif device_type.lower() == "mps":
617            if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
618                raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
619        elif device_type.lower() == "cpu":
620            pass  # cpu is always available
621        else:
622            raise RuntimeError(f"Unsupported device: {device}. Please choose from 'cpu', 'cuda', or 'mps'.")
623    return device
def get_prediction( input_volume: numpy.ndarray, tiling: Optional[Dict[str, Dict[str, int]]], model_path: Optional[str] = None, model: Optional[torch.nn.modules.module.Module] = None, verbose: bool = True, with_channels: bool = False, channels_to_standardize: Optional[List[int]] = None, mask: Optional[numpy.ndarray] = None) -> numpy.ndarray:

Run prediction on a given volume.

This function will automatically choose the correct prediction implementation, depending on the model type.

Arguments:
  • input_volume: The input volume to predict on.
  • model_path: The path to the model checkpoint if 'model' is not provided.
  • model: Pre-loaded model. Either model_path or model is required.
  • tiling: The tiling configuration for the prediction.
  • verbose: Whether to print timing information.
  • with_channels: Whether to predict with channels.
  • channels_to_standardize: List of channels to standardize. If None, all channels are standardized.
  • mask: Optional binary mask. If given, the prediction will only be run in the foreground region of the mask.
Returns:

The predicted volume.
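
A minimal usage sketch for this function; the checkpoint path and the array shape are placeholders, not part of this module:

    import numpy as np
    from synapse_net.inference.util import get_prediction, get_default_tiling

    volume = np.random.rand(48, 512, 512).astype("float32")  # stand-in for a real tomogram [z, y, x]
    tiling = get_default_tiling()                             # tile/halo derived from the available VRAM
    pred = get_prediction(volume, tiling=tiling, model_path="/path/to/checkpoint/best.pt")
    print(pred.shape)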

def get_prediction_torch_em( input_volume: numpy.ndarray, tiling: Dict[str, Dict[str, int]], model_path: Optional[str] = None, model: Optional[torch.nn.modules.module.Module] = None, verbose: bool = True, with_channels: bool = False, mask: Optional[numpy.ndarray] = None) -> numpy.ndarray:

Run prediction using torch-em on a given volume.

Arguments:
  • input_volume: The input volume to predict on.
  • model_path: The path to the model checkpoint if 'model' is not provided.
  • model: Pre-loaded model. Either model_path or model is required.
  • tiling: The tiling configuration for the prediction.
  • verbose: Whether to print timing information.
  • with_channels: Whether to predict with channels.
  • mask: Optional binary mask. If given, the prediction will only be run in the foreground region of the mask.
Returns:

The predicted volume.
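
A sketch with a pre-loaded model and an explicit tiling dictionary; the model file and the shapes are hypothetical. Note that, unlike get_prediction, this function does not standardize the input:

    import numpy as np
    import torch
    from synapse_net.inference.util import get_prediction_torch_em

    model = torch.load("/path/to/model.pt")  # a serialized pytorch model (placeholder path)
    tiling = {"tile": {"x": 256, "y": 256, "z": 32}, "halo": {"x": 64, "y": 64, "z": 8}}
    volume = np.zeros((48, 384, 384), dtype="float32")  # standardized input volume
    pred = get_prediction_torch_em(volume, tiling, model=model)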

def inference_helper( input_path: str, output_root: str, segmentation_function: callable, data_ext: str = '.mrc', extra_input_path: Optional[str] = None, extra_input_ext: str = '.tif', mask_input_path: Optional[str] = None, mask_input_ext: str = '.tif', force: bool = False, output_key: Optional[str] = None, model_resolution: Optional[Tuple[float, float, float]] = None, scale: Optional[Tuple[float, float, float]] = None) -> None:

Helper function to run segmentation for mrc files.

Arguments:
  • input_path: The path to the input data. Can either be a folder, in which case all mrc files below it will be segmented, or a single mrc file, in which case only this file will be segmented.
  • output_root: The path to the output directory where the segmentation results will be saved.
  • segmentation_function: The function performing the segmentation. This function must take the input_volume as its only positional argument, accept the keyword arguments 'mask' and 'scale', and return only the segmentation. To pass additional arguments use 'functools.partial'.
  • data_ext: File extension for the image data. By default '.mrc' is used.
  • extra_input_path: Filepath to extra inputs that need to be concatenated to the raw data loaded from mrc. This enables cristae segmentation with an extra mito channel.
  • extra_input_ext: File extension for the extra inputs (by default .tif).
  • mask_input_path: Filepath to mask(s) that will be used to restrict the segmentation.
  • mask_input_ext: File extension for the mask inputs (by default .tif).
  • force: Whether to rerun segmentation for output files that are already present.
  • output_key: Output key for the prediction. If given, the result is saved as hdf5 under this key; otherwise as tif.
  • model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction. If given, the scaling factor will automatically be determined based on the voxel_size of the input data.
  • scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'.
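
A sketch of batch processing a folder of mrc files. 'segment_volume' is a hypothetical function with the (input_volume, mask=None, scale=None) signature that inference_helper expects; the paths and threshold are placeholders:

    from functools import partial
    from synapse_net.inference.util import inference_helper

    def segment_volume(input_volume, mask=None, scale=None, threshold=0.5):
        ...  # e.g. run get_prediction and post-process the probabilities into a label volume

    inference_helper(
        input_path="/data/tomograms",        # all .mrc files below this folder are processed
        output_root="/data/segmentations",   # the input folder structure is mirrored here
        segmentation_function=partial(segment_volume, threshold=0.6),
        model_resolution=(1.0, 1.0, 1.0),    # hypothetical training voxel size in nm
    )
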
def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]:

Determine the tile shape and halo depending on the available VRAM.

Arguments:
  • is_2d: Whether to return tiling settings for 2d inference.
Returns:

The default tiling settings for the available computational resources.
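
The returned value is a nested dictionary. For example, following the thresholds above, a GPU with 40-80 GB of VRAM yields:

    from synapse_net.inference.util import get_default_tiling

    tiling = get_default_tiling()
    # {"tile": {"x": 512, "y": 512, "z": 64}, "halo": {"x": 64, "y": 64, "z": 16}}
    tiling_2d = get_default_tiling(is_2d=True)
    # {"tile": {"x": 768, "y": 768, "z": 1}, "halo": {"x": 128, "y": 128, "z": 0}}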

def parse_tiling( tile_shape: Tuple[int, int, int], halo: Tuple[int, int, int], is_2d: bool = False) -> Dict[str, Dict[str, int]]:

Helper function to parse tiling parameter input from the command line.

Arguments:
  • tile_shape: The tile shape. If None the default tile shape is used.
  • halo: The halo. If None the default halo is used.
  • is_2d: Whether to return tiling for a 2d model.
Returns:

The tiling specification.
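
A short sketch of how command-line tuples (given in zyx order) map onto the tiling dictionary:

    from synapse_net.inference.util import parse_tiling

    tiling = parse_tiling(tile_shape=(32, 384, 384), halo=(8, 64, 64))
    # {"tile": {"z": 32, "y": 384, "x": 384}, "halo": {"z": 8, "y": 64, "x": 64}}
    tiling = parse_tiling(tile_shape=None, halo=None)  # falls back to get_default_tiling()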

def apply_size_filter( segmentation: numpy.ndarray, min_size: int, verbose: bool = False, block_shape: Tuple[int, int, int] = (128, 256, 256)) -> numpy.ndarray:

Apply size filter to the segmentation to remove small objects.

Arguments:
  • segmentation: The segmentation.
  • min_size: The minimal object size in pixels.
  • verbose: Whether to print runtimes.
  • block_shape: Block shape for parallelizing the operations.
Returns:

The size filtered segmentation.
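
A minimal sketch with a placeholder label volume; note that the filter modifies the input array in place and also returns it:

    import numpy as np
    from synapse_net.inference.util import apply_size_filter

    segmentation = np.zeros((64, 256, 256), dtype="uint32")  # stand-in label volume
    segmentation = apply_size_filter(segmentation, min_size=500, verbose=True)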

def get_device( device: Union[torch.device, str, NoneType] = None) -> Union[str, torch.device]:

Get the torch device.

If no device is passed, the default device for your system is used. Otherwise, it is checked whether the device you have passed is supported.

Arguments:
  • device: The input device.
Returns:

The device.
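
Usage sketch:

    from synapse_net.inference.util import get_device

    device = get_device()        # or get_device("auto"): picks cuda, then mps, then cpu (cpu is forced on GitHub Actions CI)
    device = get_device("cuda")  # raises RuntimeError if the CUDA backend is not available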