synapse_net.inference.util

  1import os
  2import time
  3import warnings
  4from glob import glob
  5from typing import Dict, Optional, Tuple, Union
  6
  7# # Suppress annoying import warnings.
  8# with warnings.catch_warnings():
  9#     warnings.simplefilter("ignore")
 10#     import bioimageio.core
 11
 12import imageio.v3 as imageio
 13import elf.parallel as parallel
 14import mrcfile
 15import numpy as np
 16import torch
 17import torch_em
 18# import xarray
 19
 20from elf.io import open_file
 21from scipy.ndimage import binary_closing
 22from skimage.measure import regionprops
 23from skimage.morphology import remove_small_holes
 24from skimage.transform import rescale, resize
 25from torch_em.util.prediction import predict_with_halo
 26from tqdm import tqdm
 27
 28
 29#
 30# Utils for prediction.
 31#
 32
 33
 34class _Scaler:
 35    def __init__(self, scale, verbose):
 36        self.scale = scale
 37        self.verbose = verbose
 38        self._original_shape = None
 39
 40    def scale_input(self, input_volume, is_segmentation=False):
 41        if self.scale is None:
 42            return input_volume
 43
 44        if self._original_shape is None:
 45            self._original_shape = input_volume.shape
 46        elif self._original_shape != input_volume.shape:
 47            raise RuntimeError(
 48                "Scaler was called with different input shapes. "
 49                "This is not supported, please create a new instance of the class for it."
 50            )
 51
 52        if is_segmentation:
 53            input_volume = rescale(
 54                input_volume, self.scale, preserve_range=True, order=0, anti_aliasing=False,
 55            ).astype(input_volume.dtype)
 56        else:
 57            input_volume = rescale(input_volume, self.scale, preserve_range=True).astype(input_volume.dtype)
 58
 59        if self.verbose:
 60            print("Rescaled volume from", self._original_shape, "to", input_volume.shape)
 61        return input_volume
 62
 63    def rescale_output(self, output, is_segmentation):
 64        if self.scale is None:
 65            return output
 66
 67        assert self._original_shape is not None
 68        out_shape = self._original_shape
 69        if output.ndim > len(out_shape):
 70            assert output.ndim == len(out_shape) + 1
 71            out_shape = (output.shape[0],) + out_shape
 72
 73        if is_segmentation:
 74            output = resize(output, out_shape, preserve_range=True, order=0, anti_aliasing=False).astype(output.dtype)
 75        else:
 76            output = resize(output, out_shape, preserve_range=True).astype(output.dtype)
 77
 78        return output
 79
 80
 81def get_prediction(
 82    input_volume: np.ndarray,  # [z, y, x]
 83    tiling: Optional[Dict[str, Dict[str, int]]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
 84    model_path: Optional[str] = None,
 85    model: Optional[torch.nn.Module] = None,
 86    verbose: bool = True,
 87    with_channels: bool = False,
 88    mask: Optional[np.ndarray] = None,
 89) -> np.ndarray:
 90    """Run prediction on a given volume.
 91
 92    This function will automatically choose the correct prediction implementation,
 93    depending on the model type.
 94
 95    Args:
 96        input_volume: The input volume to predict on.
 97        model_path: The path to the model checkpoint if 'model' is not provided.
 98        model: Pre-loaded model. Either model_path or model is required.
 99        tiling: The tiling configuration for the prediction.
100        verbose: Whether to print timing information.
101        with_channels: Whether to predict with channels.
102        mask: Optional binary mask. If given, the prediction will only be run in
103            the foreground region of the mask.
104
105    Returns:
106        The predicted volume.
107    """
108    # make sure either model path or model is passed
109    if model is None and model_path is None:
110        raise ValueError("Either 'model_path' or 'model' must be provided.")
111
112    if model is not None:
113        is_bioimageio = None
114    else:
115        is_bioimageio = model_path.endswith(".zip")
116
117    if tiling is None:
118        tiling = get_default_tiling()
119
120    # We standardize the data for the whole volume beforehand.
121    # If we have channels then the standardization is done independently per channel.
122    if with_channels:
123        # TODO Check that this is the correct axis.
124        input_volume = torch_em.transform.raw.standardize(input_volume, axis=(1, 2, 3))
125    else:
126        input_volume = torch_em.transform.raw.standardize(input_volume)
127
128    # Run prediction with the bioimage.io library.
129    if is_bioimageio:
130        if mask is not None:
131            raise NotImplementedError
132        raise NotImplementedError
133
134    # Run prediction with the torch-em library.
135    else:
136        if model is None:
137            # torch_em expects the root folder of a checkpoint path instead of the checkpoint itself.
138            if model_path.endswith("best.pt"):
139                model_path = os.path.split(model_path)[0]
140        # print(f"tiling {tiling}")
141        # Create updated_tiling with the same structure
142        updated_tiling = {
143            "tile": {},
144            "halo": tiling["halo"]  # Keep the halo part unchanged
145        }
146        # Update tile dimensions
147        for dim in tiling["tile"]:
148            updated_tiling["tile"][dim] = tiling["tile"][dim] - 2 * tiling["halo"][dim]
149        # print(f"updated_tiling {updated_tiling}")
150        pred = get_prediction_torch_em(
151            input_volume, updated_tiling, model_path, model, verbose, with_channels, mask=mask
152        )
153
154    return pred
155
156
157def get_prediction_torch_em(
158    input_volume: np.ndarray,  # [z, y, x]
159    tiling: Dict[str, Dict[str, int]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
160    model_path: Optional[str] = None,
161    model: Optional[torch.nn.Module] = None,
162    verbose: bool = True,
163    with_channels: bool = False,
164    mask: Optional[np.ndarray] = None,
165) -> np.ndarray:
166    """Run prediction using torch-em on a given volume.
167
168    Args:
169        input_volume: The input volume to predict on.
170        model_path: The path to the model checkpoint if 'model' is not provided.
171        model: Pre-loaded model. Either model_path or model is required.
172        tiling: The tiling configuration for the prediction.
173        verbose: Whether to print timing information.
174        with_channels: Whether to predict with channels.
175        mask: Optional binary mask. If given, the prediction will only be run in
176            the foreground region of the mask.
177
178    Returns:
179        The predicted volume.
180    """
181    # Get the block shape and halo in the same axis order as the input volume [z, y, x].
182    block_shape = [tiling["tile"]["z"], tiling["tile"]["y"], tiling["tile"]["x"]]
183    halo = [tiling["halo"]["z"], tiling["halo"]["y"], tiling["halo"]["x"]]
184
185    t0 = time.time()
186    device = "cuda" if torch.cuda.is_available() else "cpu"
187
188    # Suppress warning when loading the model.
189    with warnings.catch_warnings():
190        warnings.simplefilter("ignore")
191        if model is None:
192            if os.path.isdir(model_path):  # Load the model from a torch_em checkpoint.
193                model = torch_em.util.load_model(checkpoint=model_path, device=device)
194            else:  # Load the model directly from a serialized pytorch model.
195                model = torch.load(model_path)
196
197    # Run prediction with the model.
198    with torch.no_grad():
199
200        # Deal with 2D segmentation case
201        if len(input_volume.shape) == 2:
202            block_shape = [block_shape[1], block_shape[2]]
203            halo = [halo[1], halo[2]]
204
205        if mask is not None:
206            if verbose:
207                print("Run prediction with mask.")
208            mask = mask.astype("bool")
209
210        pred = predict_with_halo(
211            input_volume, model, gpu_ids=[device],
212            block_shape=block_shape, halo=halo,
213            preprocess=None, with_channels=with_channels, mask=mask,
214        )
215    if verbose:
216        print("Prediction time in", time.time() - t0, "s")
217    return pred
218
219
220def _get_file_paths(input_path, ext=".mrc"):
221    if not os.path.exists(input_path):
222        raise Exception(f"Input path not found {input_path}")
223
224    if os.path.isfile(input_path):
225        input_files = [input_path]
226        input_root = None
227    else:
228        input_files = sorted(glob(os.path.join(input_path, "**", f"*{ext}"), recursive=True))
229        input_root = input_path
230
231    return input_files, input_root
232
233
234def _load_input(img_path, extra_files, i):
235    # Load the input data.
236    if os.path.splitext(img_path)[-1] == ".tif":
237        input_volume = imageio.imread(img_path)
238
239    else:
240        with open_file(img_path, "r") as f:
241            # Try to automatically derive the key with the raw data.
242            keys = list(f.keys())
243            if len(keys) == 1:
244                key = keys[0]
245            elif "data" in keys:
246                key = "data"
247            else:  # Fall back to the "raw" key; reading below fails with a KeyError if it is missing.
248                key = "raw"
249            input_volume = f[key][:]
250
251    assert input_volume.ndim in (2, 3)
252    # For now we assume this is always tif.
253    if extra_files is not None:
254        extra_input = imageio.imread(extra_files[i])
255        assert extra_input.shape == input_volume.shape
256        input_volume = np.stack([input_volume, extra_input], axis=0)
257
258    return input_volume
259
260
261def _derive_scale(img_path, model_resolution):
262    try:
263        with mrcfile.open(img_path, "r") as f:
264            voxel_size = f.voxel_size
265            if len(model_resolution) == 2:
266                voxel_size = [voxel_size.y, voxel_size.x]
267            else:
268                voxel_size = [voxel_size.z, voxel_size.y, voxel_size.x]
269
270        assert len(voxel_size) == len(model_resolution)
271        # The voxel size is given in Angstrom and we need to translate it to nanometer.
272        voxel_size = [vsize / 10 for vsize in voxel_size]
273
274        # Compute the correct scale factor.
275        scale = tuple(vsize / res for vsize, res in zip(voxel_size, model_resolution))
276        print("Rescaling the data at", img_path, "by", scale, "to match the training voxel size", model_resolution)
277
278    except Exception:
279        warnings.warn(
280            f"The voxel size could not be read from the data for {img_path}. "
281            "This data will not be scaled for prediction."
282        )
283        scale = None
284
285    return scale
286
287
288def inference_helper(
289    input_path: str,
290    output_root: str,
291    segmentation_function: callable,
292    data_ext: str = ".mrc",
293    extra_input_path: Optional[str] = None,
294    extra_input_ext: str = ".tif",
295    mask_input_path: Optional[str] = None,
296    mask_input_ext: str = ".tif",
297    force: bool = False,
298    output_key: Optional[str] = None,
299    model_resolution: Optional[Tuple[float, float, float]] = None,
300    scale: Optional[Tuple[float, float, float]] = None,
301) -> None:
302    """Helper function to run segmentation for mrc files.
303
304    Args:
305        input_path: The path to the input data.
306            This can either be a folder, in which case all mrc files below it will be segmented,
307            or a single mrc file, in which case only that file will be segmented.
308        output_root: The path to the output directory where the segmentation results will be saved.
309        segmentation_function: The function performing the segmentation.
310            This function must take the input_volume as its first argument, plus 'mask' and 'scale' keyword
311            arguments, and must return only the segmentation. To pass additional arguments, use 'functools.partial'.
312        data_ext: File extension for the image data. By default '.mrc' is used.
313        extra_input_path: Filepath to extra inputs that need to be concatenated to the raw data loaded from mrc.
314            This enables cristae segmentation with an extra mito channel.
315        extra_input_ext: File extension for the extra inputs (by default .tif).
316        mask_input_path: Filepath to mask(s) that will be used to restrict the segmentation.
317        mask_input_ext: File extension for the mask inputs (by default .tif).
318        force: Whether to rerun segmentation for output files that are already present.
319        output_key: Output key for the prediction. If not given, the result is written as tif, otherwise as hdf5.
320        model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction.
321            If given, the scaling factor will automatically be determined based on the voxel_size of the input data.
322        scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'.
323    """
324    if (scale is not None) and (model_resolution is not None):
325        raise ValueError("You must not provide both 'scale' and 'model_resolution' arguments.")
326
327    # Get the input files. If input_path is a folder then this will load all
328    # the mrc files beneath it. Otherwise we assume this is an mrc file already
329    # and just return the path to this mrc file.
330    input_files, input_root = _get_file_paths(input_path, data_ext)
331
332    # Load extra inputs if the extra_input_path was specified.
333    if extra_input_path is None:
334        extra_files = None
335    else:
336        extra_files, _ = _get_file_paths(extra_input_path, extra_input_ext)
337        assert len(input_files) == len(extra_files)
338
339    # Load the masks if they were specified.
340    if mask_input_path is None:
341        mask_files = None
342    else:
343        mask_files, _ = _get_file_paths(mask_input_path, mask_input_ext)
344        assert len(input_files) == len(mask_files)
345
346    for i, img_path in tqdm(enumerate(input_files), total=len(input_files), desc="Processing files"):
347        # Determine the output file name.
348        input_folder, input_name = os.path.split(img_path)
349
350        if output_key is None:
351            fname = os.path.splitext(input_name)[0] + "_prediction.tif"
352        else:
353            fname = os.path.splitext(input_name)[0] + "_prediction.h5"
354
355        if input_root is None:
356            output_path = os.path.join(output_root, fname)
357        else:  # If we have nested input folders then we preserve the folder structure in the output.
358            rel_folder = os.path.relpath(input_folder, input_root)
359            output_path = os.path.join(output_root, rel_folder, fname)
360
361        # Check if the output path is already present.
362        # If it is we skip the prediction, unless force was set to true.
363        if os.path.exists(output_path) and not force:
364            if output_key is None:
365                continue
366            else:
367                with open_file(output_path, "r") as f:
368                    if output_key in f:
369                        continue
370
371        # Load the input volume. If we have extra_files then this concatenates the
372        # data across a new first axis (= channel axis).
373        input_volume = _load_input(img_path, extra_files, i)
374        # Load the mask (if given).
375        mask = None if mask_files is None else imageio.imread(mask_files[i])
376
377        # Determine the scale factor:
378        # If neither the 'scale' nor the 'model_resolution' argument was passed then set it to None.
379        if scale is None and model_resolution is None:
380            this_scale = None
381        elif scale is not None:   # If 'scale' was passed then use it.
382            this_scale = scale
383        else:   # Otherwise 'model_resolution' was passed, use it to derive the scaling from the data
384            assert model_resolution is not None
385            this_scale = _derive_scale(img_path, model_resolution)
386
387        # Run the segmentation.
388        segmentation = segmentation_function(input_volume, mask=mask, scale=this_scale)
389
390        # Write the result to tif or h5.
391        os.makedirs(os.path.split(output_path)[0], exist_ok=True)
392
393        if output_key is None:
394            imageio.imwrite(output_path, segmentation, compression="zlib")
395        else:
396            with open_file(output_path, "a") as f:
397                f.create_dataset(output_key, data=segmentation, compression="gzip")
398
399        print(f"Saved segmentation to {output_path}.")
400
401
402def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]:
403    """Determine the tile shape and halo depending on the available VRAM.
404
405    Args:
406        is_2d: Whether to return tiling settings for 2d inference.
407
408    Returns:
409        The default tiling settings for the available computational resources.
410    """
411    if is_2d:
412        tile = {"x": 768, "y": 768, "z": 1}
413        halo = {"x": 128, "y": 128, "z": 0}
414        return {"tile": tile, "halo": halo}
415
416    if torch.cuda.is_available():
417        # The default halo size.
418        halo = {"x": 64, "y": 64, "z": 16}
419
420        # Determine the GPU RAM and derive a suitable tiling.
421        vram = torch.cuda.get_device_properties(0).total_memory / 1e9
422
423        if vram >= 80:
424            tile = {"x": 640, "y": 640, "z": 80}
425        elif vram >= 40:
426            tile = {"x": 512, "y": 512, "z": 64}
427        elif vram >= 20:
428            tile = {"x": 352, "y": 352, "z": 48}
429        elif vram >= 10:
430            tile = {"x": 256, "y": 256, "z": 32}
431            halo = {"x": 64, "y": 64, "z": 8}  # Choose a smaller halo in z.
432        else:
433            raise NotImplementedError(f"Inference with a GPU with {vram} GB VRAM is not supported.")
434
435        print(f"Determined tile size: {tile}")
436        tiling = {"tile": tile, "halo": halo}
437
438    # I am not sure what is reasonable on a cpu. For now choosing very small tiling.
439    # (This will not work well on a CPU in any case.)
440    else:
441        print("Determining default tiling")
442        tiling = {
443            "tile": {"x": 96, "y": 96, "z": 16},
444            "halo": {"x": 16, "y": 16, "z": 4},
445        }
446
447    return tiling
448
449
450def parse_tiling(
451    tile_shape: Tuple[int, int, int],
452    halo: Tuple[int, int, int],
453    is_2d: bool = False,
454) -> Dict[str, Dict[str, int]]:
455    """Helper function to parse tiling parameter input from the command line.
456
457    Args:
458        tile_shape: The tile shape. If None the default tile shape is used.
459        halo: The halo. If None the default halo is used.
460        is_2d: Whether to return tiling for a 2d model.
461
462    Returns:
463        The tiling specification.
464    """
465
466    default_tiling = get_default_tiling(is_2d=is_2d)
467
468    if tile_shape is None:
469        tile_shape = default_tiling["tile"]
470    else:
471        assert len(tile_shape) == 3
472        tile_shape = dict(zip("zyx", tile_shape))
473
474    if halo is None:
475        halo = default_tiling["halo"]
476    else:
477        assert len(halo) == 3
478        halo = dict(zip("zyx", halo))
479
480    tiling = {"tile": tile_shape, "halo": halo}
481    return tiling
482
483
484#
485# Utils for post-processing.
486#
487
488
489def apply_size_filter(
490    segmentation: np.ndarray,
491    min_size: int,
492    verbose: bool = False,
493    block_shape: Tuple[int, int, int] = (128, 256, 256),
494) -> np.ndarray:
495    """Apply size filter to the segmentation to remove small objects.
496
497    Args:
498        segmentation: The segmentation.
499        min_size: The minimal object size in pixels.
500        verbose: Whether to print runtimes.
501        block_shape: Block shape for parallelizing the operations.
502
503    Returns:
504        The size filtered segmentation.
505    """
506    if min_size == 0:
507        return segmentation
508    t0 = time.time()
509    if segmentation.ndim == 2 and len(block_shape) == 3:
510        block_shape_ = block_shape[1:]
511    else:
512        block_shape_ = block_shape
513    ids, sizes = parallel.unique(segmentation, return_counts=True, block_shape=block_shape_, verbose=verbose)
514    filter_ids = ids[sizes < min_size]
515    segmentation[np.isin(segmentation, filter_ids)] = 0
516    if verbose:
517        print("Size filter in", time.time() - t0, "s")
518    return segmentation
519
520
521def _postprocess_seg_3d(seg, area_threshold=1000, iterations=4, iterations_3d=8):
522    # Structuring element for 2D closing applied slice-wise within the 3D volume.
523    structure_element = np.ones((3, 3))  # 3x3 structure for XY plane
524    structure_3d = np.zeros((1, 3, 3))  # Only applied in the XY plane
525    structure_3d[0] = structure_element
526
527    props = regionprops(seg)
528    for prop in props:
529        # Get bounding box and mask.
530        bb = tuple(slice(start, stop) for start, stop in zip(prop.bbox[:3], prop.bbox[3:]))
531        mask = seg[bb] == prop.label
532
533        # Fill small holes and apply closing.
534        mask = remove_small_holes(mask, area_threshold=area_threshold)
535        mask = np.logical_or(binary_closing(mask, iterations=iterations), mask)
536        mask = np.logical_or(binary_closing(mask, iterations=iterations_3d, structure=structure_3d), mask)
537        seg[bb][mask] = prop.label
538
539    return seg
540
541
542#
543# Utils for torch device.
544#
545
546def _get_default_device():
547    # Check that we're in CI and use the CPU if we are.
548    # Otherwise the tests may run out of memory on MAC if MPS is used.
549    if os.getenv("GITHUB_ACTIONS") == "true":
550        return "cpu"
551    # Use cuda enabled gpu if it's available.
552    if torch.cuda.is_available():
553        device = "cuda"
554    # As second priority use mps.
555    # See https://pytorch.org/docs/stable/notes/mps.html for details
556    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
557        device = "mps"
558    # Use the CPU as fallback.
559    else:
560        device = "cpu"
561    return device
562
563
564def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
565    """Get the torch device.
566
567    If no device is passed the default device for your system is used.
568    Else it will be checked if the device you have passed is supported.
569
570    Args:
571        device: The input device.
572
573    Returns:
574        The device.
575    """
576    if device is None or device == "auto":
577        device = _get_default_device()
578    else:
579        device_type = device if isinstance(device, str) else device.type
580        if device_type.lower() == "cuda":
581            if not torch.cuda.is_available():
582                raise RuntimeError("PyTorch CUDA backend is not available.")
583        elif device_type.lower() == "mps":
584            if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
585                raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
586        elif device_type.lower() == "cpu":
587            pass  # cpu is always available
588        else:
589            raise RuntimeError(f"Unsupported device: {device}. Please choose from 'cpu', 'cuda', or 'mps'.")
590    return device
def get_prediction( input_volume: numpy.ndarray, tiling: Optional[Dict[str, Dict[str, int]]], model_path: Optional[str] = None, model: Optional[torch.nn.modules.module.Module] = None, verbose: bool = True, with_channels: bool = False, mask: Optional[numpy.ndarray] = None) -> numpy.ndarray:

Run prediction on a given volume.

This function will automatically choose the correct prediction implementation, depending on the model type.

Arguments:
  • input_volume: The input volume to predict on.
  • model_path: The path to the model checkpoint if 'model' is not provided.
  • model: Pre-loaded model. Either model_path or model is required.
  • tiling: The tiling configuration for the prediction.
  • verbose: Whether to print timing information.
  • with_channels: Whether to predict with channels.
  • mask: Optional binary mask. If given, the prediction will only be run in the foreground region of the mask.
Returns:

The predicted volume.
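
A minimal usage sketch; the checkpoint path is a placeholder and the random array stands in for a real tomogram:

    import numpy as np
    from synapse_net.inference.util import get_prediction

    # Stand-in volume with axis order [z, y, x].
    volume = np.random.rand(64, 512, 512).astype("float32")

    # Passing tiling=None falls back to get_default_tiling().
    pred = get_prediction(volume, tiling=None, model_path="/path/to/checkpoint/best.pt")
    print(pred.shape)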

def get_prediction_torch_em( input_volume: numpy.ndarray, tiling: Dict[str, Dict[str, int]], model_path: Optional[str] = None, model: Optional[torch.nn.modules.module.Module] = None, verbose: bool = True, with_channels: bool = False, mask: Optional[numpy.ndarray] = None) -> numpy.ndarray:

Run prediction using torch-em on a given volume.

Arguments:
  • input_volume: The input volume to predict on.
  • model_path: The path to the model checkpoint if 'model' is not provided.
  • model: Pre-loaded model. Either model_path or model is required.
  • tiling: The tiling configuration for the prediction.
  • verbose: Whether to print timing information.
  • with_channels: Whether to predict with channels.
  • mask: Optional binary mask. If given, the prediction will only be run in the foreground region of the mask.
Returns:

The predicted volume.
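
The tiling dictionary follows the structure shown in the type comment. Note that get_prediction subtracts twice the halo from each tile dimension before calling this function, whereas the values given here are passed to predict_with_halo unchanged. A sketch with made-up sizes:

    tiling = {
        "tile": {"x": 512, "y": 512, "z": 64},
        "halo": {"x": 64, "y": 64, "z": 16},
    }
    # pred = get_prediction_torch_em(volume, tiling, model_path="/path/to/checkpoint_dir")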

def inference_helper( input_path: str, output_root: str, segmentation_function: callable, data_ext: str = '.mrc', extra_input_path: Optional[str] = None, extra_input_ext: str = '.tif', mask_input_path: Optional[str] = None, mask_input_ext: str = '.tif', force: bool = False, output_key: Optional[str] = None, model_resolution: Optional[Tuple[float, float, float]] = None, scale: Optional[Tuple[float, float, float]] = None) -> None:

Helper function to run segmentation for mrc files.

Arguments:
  • input_path: The path to the input data. This can either be a folder, in which case all mrc files below it will be segmented, or a single mrc file, in which case only that file will be segmented.
  • output_root: The path to the output directory where the segmentation results will be saved.
  • segmentation_function: The function performing the segmentation. This function must take the input_volume as its first argument, plus 'mask' and 'scale' keyword arguments, and must return only the segmentation. To pass additional arguments, use 'functools.partial' (see the sketch after this list).
  • data_ext: File extension for the image data. By default '.mrc' is used.
  • extra_input_path: Filepath to extra inputs that need to be concatenated to the raw data loaded from mrc. This enables cristae segmentation with an extra mito channel.
  • extra_input_ext: File extension for the extra inputs (by default .tif).
  • mask_input_path: Filepath to mask(s) that will be used to restrict the segmentation.
  • mask_input_ext: File extension for the mask inputs (by default .tif).
  • force: Whether to rerun segmentation for output files that are already present.
  • output_key: Output key for the prediction. If not given, the result is written as tif, otherwise as hdf5.
  • model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction. If given, the scaling factor will automatically be determined based on the voxel_size of the input data.
  • scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'.
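
A sketch of a typical call, wrapping the segmentation function with functools.partial as suggested above; 'my_segmentation' and all paths and values are hypothetical:

    from functools import partial
    from synapse_net.inference.util import inference_helper

    def my_segmentation(input_volume, mask=None, scale=None, threshold=0.5):
        # Hypothetical segmentation function; it must return only the segmentation.
        return (input_volume > threshold).astype("uint32")

    inference_helper(
        input_path="/path/to/tomograms",    # folder with .mrc files
        output_root="/path/to/output",      # segmentations are written here as tif
        segmentation_function=partial(my_segmentation, threshold=0.7),
        model_resolution=(1.0, 1.0, 1.0),   # placeholder voxel size in nm, order (z, y, x)
    )
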
def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]:

Determine the tile shape and halo depending on the available VRAM.

Arguments:
  • is_2d: Whether to return tiling settings for 2d inference.
Returns:

The default tiling settings for the available computational resources.
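
For example, on a machine without a GPU the fallback branch yields:

    tiling = get_default_tiling()
    # {"tile": {"x": 96, "y": 96, "z": 16}, "halo": {"x": 16, "y": 16, "z": 4}}

When this tiling is passed to get_prediction, each tile dimension is reduced by twice the halo before prediction, e.g. 96 - 2 * 16 = 64 in x and y.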

def parse_tiling( tile_shape: Tuple[int, int, int], halo: Tuple[int, int, int], is_2d: bool = False) -> Dict[str, Dict[str, int]]:

Helper function to parse tiling parameter input from the command line.

Arguments:
  • tile_shape: The tile shape. If None the default tile shape is used.
  • halo: The halo. If None the default halo is used.
  • is_2d: Whether to return tiling for a 2d model.
Returns:

The tiling specification.
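
A sketch of how the parameters could be collected from the command line; the argument names are illustrative, not an existing CLI:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--tile_shape", type=int, nargs=3, default=None)  # order: z y x
    parser.add_argument("--halo", type=int, nargs=3, default=None)        # order: z y x
    args = parser.parse_args()

    tiling = parse_tiling(args.tile_shape, args.halo)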

def apply_size_filter( segmentation: numpy.ndarray, min_size: int, verbose: bool = False, block_shape: Tuple[int, int, int] = (128, 256, 256)) -> numpy.ndarray:

Apply size filter to the segmentation to remove small objects.

Arguments:
  • segmentation: The segmentation.
  • min_size: The minimal object size in pixels.
  • verbose: Whether to print runtimes.
  • block_shape: Block shape for parallelizing the operations.
Returns:

The size filtered segmentation.
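
A small worked example with synthetic labels, assuming apply_size_filter has been imported from this module:

    import numpy as np

    seg = np.zeros((128, 256, 256), dtype="uint32")
    seg[2:4, 2:4, 2:4] = 1          # 8 voxels, removed for min_size=100
    seg[10:60, 20:120, 20:120] = 2  # 500000 voxels, kept
    seg = apply_size_filter(seg, min_size=100)
    assert 1 not in seg and 2 in seg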

def get_device( device: Union[str, torch.device, NoneType] = None) -> Union[str, torch.device]:

Get the torch device.

If no device is passed the default device for your system is used. Else it will be checked if the device you have passed is supported.

Arguments:
  • device: The input device.
Returns:

The device.
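
Typical usage:

    device = get_device()        # picks "cuda", "mps", or "cpu" depending on the system
    device = get_device("cuda")  # raises a RuntimeError if CUDA is not available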