synapse_net.inference.util
import os
import time
import warnings
from glob import glob
from typing import Dict, List, Optional, Tuple, Union

import imageio.v3 as imageio
import elf.parallel as parallel
import mrcfile
import numpy as np
import torch
import torch_em

from elf.io import open_file
from numpy.typing import ArrayLike
from scipy.ndimage import binary_closing
from skimage.measure import regionprops
from skimage.morphology import remove_small_holes
from skimage.transform import rescale, resize
from torch_em.util.prediction import predict_with_halo
from tqdm import tqdm


#
# Utils for prediction.
#


class _Scaler:
    """Rescale inputs for prediction and resize predictions back to the original shape.

    Args:
        scale: The scale factor, either a single value or one value per axis.
            If None, or close to 1 for all axes, no rescaling is applied.
        verbose: Whether to print timing information.
    """
    def __init__(self, scale, verbose):
        self.verbose = verbose
        # Shape of the input passed to 'scale_input'; needed to restore the output shape.
        self._original_shape = None

        if scale is None:
            self.scale = None
            return

        # Convert the scale to a float array. Non-numeric values make 'astype' raise,
        # so the scale is guaranteed to be numeric afterwards.
        scale = np.atleast_1d(scale).astype(np.float64)

        # Skip scaling if it is effectively the identity (1.0 in all dimensions).
        if np.allclose(scale, 1.0, atol=1e-3):
            self.scale = None
        else:
            self.scale = scale

    def scale_input(self, input_volume, is_segmentation=False):
        """Rescale the input volume by the scale factor.

        Args:
            input_volume: The volume to rescale.
            is_segmentation: Whether the volume is a segmentation. If True, nearest-neighbor
                interpolation without anti-aliasing is used so that label ids are preserved.

        Returns:
            The rescaled volume, with the same dtype as the input.

        Raises:
            RuntimeError: If this instance was already used for an input with a different shape.
        """
        t0 = time.time()
        if self.scale is None:
            return input_volume

        if self._original_shape is None:
            self._original_shape = input_volume.shape
        elif self._original_shape != input_volume.shape:
            raise RuntimeError(
                "Scaler was called with different input shapes. "
                "This is not supported, please create a new instance of the class for it."
            )

        if is_segmentation:
            input_volume = rescale(
                input_volume, self.scale, preserve_range=True, order=0, anti_aliasing=False,
            ).astype(input_volume.dtype)
        else:
            input_volume = rescale(input_volume, self.scale, preserve_range=True).astype(input_volume.dtype)

        if self.verbose:
            print("Rescaled volume from", self._original_shape, "to", input_volume.shape, "in", time.time() - t0, "s")
        return input_volume

    def rescale_output(self, output, is_segmentation):
        """Resize an output back to the shape of the input given to 'scale_input'.

        Args:
            output: The output to resize. It may have one extra leading (channel) axis
                compared to the original input; this axis is not resized.
            is_segmentation: Whether the output is a segmentation (resized with nearest-neighbor interpolation).

        Returns:
            The resized output, with the same dtype as the given output.
        """
        t0 = time.time()
        if self.scale is None:
            return output

        assert self._original_shape is not None, "'scale_input' must be called before 'rescale_output'."
        out_shape = self._original_shape
        if output.ndim > len(out_shape):
            # Keep the size of the additional leading channel axis.
            assert output.ndim == len(out_shape) + 1
            out_shape = (output.shape[0],) + out_shape

        if is_segmentation:
            output = resize(output, out_shape, preserve_range=True, order=0, anti_aliasing=False).astype(output.dtype)
        else:
            output = resize(output, out_shape, preserve_range=True).astype(output.dtype)

        if self.verbose:
            print("Resized prediction back to original shape", output.shape, "in", time.time() - t0, "s")

        return output


def _preprocess(input_volume, with_channels, channels_to_standardize):
    # Standardize the data for the whole volume beforehand.
    # With channels, each channel is standardized independently. Channels are expected along
    # the first axis, matching how '_load_input' stacks extra inputs (axis=0).
    if with_channels:
        # Always copy: with copy=False, astype returns the input itself if it is already float32.
        # The in-place channel updates below would then silently modify the caller's array.
        input_volume = input_volume.astype(np.float32)
        if channels_to_standardize is None:  # assume all channels
            channels_to_standardize = range(input_volume.shape[0])
        for ch in channels_to_standardize:
            input_volume[ch] = torch_em.transform.raw.standardize(input_volume[ch])
    else:
        input_volume = torch_em.transform.raw.standardize(input_volume)
    return input_volume


def get_prediction(
    input_volume: ArrayLike,  # [z, y, x]
    tiling: Optional[Dict[str, Dict[str, int]]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    verbose: bool = True,
    with_channels: bool = False,
    channels_to_standardize: Optional[List[int]] = None,
    mask: Optional[ArrayLike] = None,
    prediction: Optional[ArrayLike] = None,
    devices: Optional[List[str]] = None,
    preprocess: Optional[callable] = None,
) -> ArrayLike:
    """Run prediction on a given volume.

    This function will automatically choose the correct prediction implementation,
    depending on the model type.

    Args:
        input_volume: The input volume to predict on.
        tiling: The tiling configuration for the prediction. The tile shape here includes the halo.
            If None, the result of 'get_default_tiling()' is used.
        model_path: The path to the model checkpoint if 'model' is not provided.
        model: Pre-loaded model. Either model_path or model is required.
        verbose: Whether to print timing information.
        with_channels: Whether to predict with channels.
        channels_to_standardize: List of channels to standardize. Defaults to None,
            in which case all channels are standardized.
        mask: Optional binary mask. If given, the prediction will only be run in
            the foreground region of the mask.
        prediction: An array-like object for writing the prediction.
            If not given, the prediction will be computed in memory.
        devices: The devices for running prediction. If not given will use the GPU
            if available, otherwise the CPU.
        preprocess: Optional function applied to each block before prediction. If not given,
            inputs that are not numpy arrays are standardized per block.
            Numpy inputs are always standardized as a whole beforehand.

    Returns:
        The predicted volume.

    Raises:
        ValueError: If neither 'model_path' nor 'model' is given.
        NotImplementedError: If 'model_path' refers to a bioimage.io model (.zip file).
    """
    # Make sure either the model path or the model is passed.
    if model is None and model_path is None:
        raise ValueError("Either 'model_path' or 'model' must be provided.")

    # Only a model path can refer to a bioimage.io model (zip file); this is not supported yet.
    # Check it before the (potentially expensive) normalization of the input.
    if model is None and model_path.endswith(".zip"):
        raise NotImplementedError("Prediction with bioimage.io models is not supported yet.")

    if tiling is None:
        tiling = get_default_tiling()

    # Normalize the whole input volume if it is a numpy array.
    # Otherwise we have a zarr array or similar as input, and can't normalize it en-block.
    # Normalization will be applied later per block in this case.
    if isinstance(input_volume, np.ndarray):
        input_volume = _preprocess(input_volume, with_channels, channels_to_standardize)

    # Run prediction with the torch-em library.
    if model is None and model_path.endswith("best.pt"):
        # torch_em expects the root folder of a checkpoint path instead of the checkpoint itself.
        model_path = os.path.split(model_path)[0]

    # 'get_prediction_torch_em' expects the tile shape without the halo, so subtract it on both sides.
    updated_tiling = {
        "tile": {dim: tiling["tile"][dim] - 2 * tiling["halo"][dim] for dim in tiling["tile"]},
        "halo": tiling["halo"],  # Keep the halo part unchanged
    }
    prediction = get_prediction_torch_em(
        input_volume, updated_tiling, model_path, model, verbose, with_channels,
        mask=mask, prediction=prediction, devices=devices, preprocess=preprocess,
    )
    return prediction


def get_prediction_torch_em(
    input_volume: ArrayLike,  # [z, y, x]
    tiling: Dict[str, Dict[str, int]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    verbose: bool = True,
    with_channels: bool = False,
    mask: Optional[ArrayLike] = None,
    prediction: Optional[ArrayLike] = None,
    devices: Optional[List[str]] = None,
    preprocess: Optional[callable] = None,
) -> np.ndarray:
    """Run prediction using torch-em on a given volume.

    Args:
        input_volume: The input volume to predict on.
        tiling: The tiling configuration for the prediction. The tile shape here does not include the halo.
        model_path: The path to the model checkpoint if 'model' is not provided: either a torch_em
            checkpoint folder or a file containing a serialized pytorch model.
        model: Pre-loaded model. Either model_path or model is required.
        verbose: Whether to print timing information.
        with_channels: Whether to predict with channels.
        mask: Optional binary mask. If given, the prediction will only be run in
            the foreground region of the mask.
        prediction: An array-like object for writing the prediction.
            If not given, the prediction will be computed in memory.
        devices: The devices for running prediction. If not given will use the GPU
            if available, otherwise the CPU.
        preprocess: Optional function applied to each block before prediction. If not given, inputs that
            are not numpy arrays are standardized per block and numpy inputs are not preprocessed.

    Returns:
        The predicted volume.
    """
    # Get block_shape and halo in the axis order of the input volume, i.e. [z, y, x].
    block_shape = [tiling["tile"][axis] for axis in "zyx"]
    halo = [tiling["halo"][axis] for axis in "zyx"]

    t0 = time.time()
    if devices is None:
        devices = ["cuda" if torch.cuda.is_available() else "cpu"]

    # Suppress warning when loading the model.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        if model is None:
            if os.path.isdir(model_path):  # Load the model from a torch_em checkpoint.
                model = torch_em.util.load_model(checkpoint=model_path, device=devices[0])
            else:  # Load the model directly from a serialized pytorch model.
                # NOTE: weights_only=False unpickles arbitrary objects; only load trusted model files.
                model = torch.load(model_path, weights_only=False)

    # Run prediction with the model.
    with torch.no_grad():

        # Deal with the 2D segmentation case: drop the z-axis from the block shape and halo.
        if len(input_volume.shape) == 2:
            block_shape, halo = block_shape[1:], halo[1:]

        if mask is not None:
            if verbose:
                print("Run prediction with mask.")
            mask = mask.astype("bool")

        if preprocess is None:
            # Numpy inputs are expected to be standardized already ('get_prediction' does this for the
            # whole volume); other array-likes (e.g. zarr) are standardized per block.
            preprocess = None if isinstance(input_volume, np.ndarray) else torch_em.transform.raw.standardize

        prediction = predict_with_halo(
            input_volume, model, gpu_ids=devices,
            block_shape=block_shape, halo=halo,
            preprocess=preprocess, with_channels=with_channels, mask=mask,
            output=prediction,
        )
    if verbose:
        print("Prediction time in", time.time() - t0, "s")
    return prediction


def _get_file_paths(input_path, ext=".mrc"):
    # Return (input_files, input_root). For a single file this is ([input_path], None);
    # for a folder, all files ending with 'ext' below it (recursively, sorted) and the folder itself.
    if not os.path.exists(input_path):
        raise FileNotFoundError(f"Input path not found {input_path}")

    if os.path.isfile(input_path):
        return [input_path], None

    input_files = sorted(glob(os.path.join(input_path, "**", f"*{ext}"), recursive=True))
    return input_files, input_path


def _load_input(img_path, extra_files, i):
    # Load the input data from a tif file or a container format supported by 'elf.io.open_file'.
    if os.path.splitext(img_path)[-1] in (".tif", ".tiff"):
        input_volume = imageio.imread(img_path)

    else:
        with open_file(img_path, "r") as f:
            # Try to automatically derive the key with the raw data.
            keys = list(f.keys())
            if len(keys) == 1:
                key = keys[0]
            elif "data" in keys:
                key = "data"
            elif "raw" in keys:
                key = "raw"
            else:
                raise ValueError(
                    f"Could not determine the dataset with the raw data in {img_path}. Available keys: {keys}."
                )
            input_volume = f[key][:]

    assert input_volume.ndim in (2, 3)
    # For now we assume the extra inputs are always tif files.
    # They are stacked with the raw data along a new first (channel) axis.
    if extra_files is not None:
        extra_input = imageio.imread(extra_files[i])
        assert extra_input.shape == input_volume.shape
        input_volume = np.stack([input_volume, extra_input], axis=0)

    return input_volume


def _derive_scale(img_path, model_resolution):
    # Derive the scale factor from the voxel size in the mrc header and the model resolution (in nm).
    # If the voxel size cannot be read, warn and return None (= no scaling).
    try:
        with mrcfile.open(img_path, "r") as f:
            voxel_size = f.voxel_size
            if len(model_resolution) == 2:
                voxel_size = [voxel_size.y, voxel_size.x]
            else:
                voxel_size = [voxel_size.z, voxel_size.y, voxel_size.x]

        assert len(voxel_size) == len(model_resolution)
        # The voxel size is given in Angstrom and we need to translate it to nanometer.
        voxel_size = [vsize / 10 for vsize in voxel_size]

        # Compute the correct scale factor.
        scale = tuple(vsize / res for vsize, res in zip(voxel_size, model_resolution))
        print("Rescaling the data at", img_path, "by", scale, "to match the training voxel size", model_resolution)

    except Exception:
        # Deliberate best effort: fall back to unscaled prediction instead of failing.
        warnings.warn(
            f"The voxel size could not be read from the data for {img_path}. "
            "This data will not be scaled for prediction."
        )
        scale = None

    return scale


def inference_helper(
    input_path: str,
    output_root: str,
    segmentation_function: callable,
    data_ext: str = ".mrc",
    extra_input_path: Optional[str] = None,
    extra_input_ext: str = ".tif",
    mask_input_path: Optional[str] = None,
    mask_input_ext: str = ".tif",
    force: bool = False,
    output_key: Optional[str] = None,
    model_resolution: Optional[Tuple[float, float, float]] = None,
    scale: Optional[Tuple[float, float, float]] = None,
    allocate_output: bool = False,
) -> None:
    """Helper function to run segmentation for mrc files.

    Args:
        input_path: The path to the input data.
            Can either be a folder. In this case all mrc files below the folder will be segmented.
            Or can be a single mrc file. In this case only this mrc file will be segmented.
        output_root: The path to the output directory where the segmentation results will be saved.
            For an input folder, the nested folder structure is preserved below output_root.
        segmentation_function: The function performing the segmentation. It is called with the
            input volume as the first argument and the keyword arguments 'mask' and 'scale'.
            If 'allocate_output' is True it is additionally passed 'output' and must write into it;
            otherwise it must return the segmentation.
            If you want to pass additional arguments to this function use 'functools.partial'.
        data_ext: File extension for the image data. By default '.mrc' is used.
        extra_input_path: Filepath to extra inputs that need to be concatenated to the raw data loaded from mrc.
            This enables cristae segmentation with an extra mito channel.
        extra_input_ext: File extension for the extra inputs (by default .tif).
        mask_input_path: Filepath to mask(s) that will be used to restrict the segmentation.
        mask_input_ext: File extension for the mask inputs (by default .tif).
        force: Whether to rerun segmentation for output files that are already present.
        output_key: Output key for the prediction. If None, the segmentation is written to a tif file;
            otherwise it is written to an hdf5 file under this key.
        model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction.
            If given, the scaling factor will automatically be determined based on the voxel_size of the input data.
        scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'.
        allocate_output: Whether to allocate the output (uint32, same shape as the input volume)
            and pass it to the segmentation function.

    Raises:
        ValueError: If both 'scale' and 'model_resolution' are given.
    """
    if (scale is not None) and (model_resolution is not None):
        raise ValueError("You must not provide both 'scale' and 'model_resolution' arguments.")

    # Get the input files. If input_path is a folder then this will load all
    # the mrc files beneath it. Otherwise we assume this is an mrc file already
    # and just return the path to this mrc file.
    input_files, input_root = _get_file_paths(input_path, data_ext)

    # Load extra inputs if the extra_input_path was specified.
    if extra_input_path is None:
        extra_files = None
    else:
        extra_files, _ = _get_file_paths(extra_input_path, extra_input_ext)
        assert len(input_files) == len(extra_files)

    # Load the masks if they were specified.
    if mask_input_path is None:
        mask_files = None
    else:
        mask_files, _ = _get_file_paths(mask_input_path, mask_input_ext)
        assert len(input_files) == len(mask_files)

    for i, img_path in tqdm(enumerate(input_files), total=len(input_files), desc="Processing files"):
        # Determine the output file name.
        input_folder, input_name = os.path.split(img_path)
        suffix = "_prediction.tif" if output_key is None else "_prediction.h5"
        fname = os.path.splitext(input_name)[0] + suffix

        if input_root is None:
            output_path = os.path.join(output_root, fname)
        else:  # If we have nested input folders then we preserve the folder structure in the output.
            rel_folder = os.path.relpath(input_folder, input_root)
            output_path = os.path.join(output_root, rel_folder, fname)

        # Check if the output path is already present.
        # If it is we skip the prediction, unless force was set to true.
        if os.path.exists(output_path) and not force:
            if output_key is None:
                continue
            with open_file(output_path, "r") as f:
                if output_key in f:
                    continue

        # Load the input volume. If we have extra_files then this concatenates the
        # data across a new first axis (= channel axis).
        input_volume = _load_input(img_path, extra_files, i)
        # Load the mask (if given).
        mask = None if mask_files is None else imageio.imread(mask_files[i])

        # Determine the scale factor: a fixed 'scale' takes precedence, otherwise derive it
        # from 'model_resolution' (if given); without either, no scaling is applied.
        if scale is not None:
            this_scale = scale
        elif model_resolution is not None:
            this_scale = _derive_scale(img_path, model_resolution)
        else:
            this_scale = None

        # Run the segmentation.
        if allocate_output:
            segmentation = np.zeros(input_volume.shape, dtype="uint32")
            segmentation_function(input_volume, output=segmentation, mask=mask, scale=this_scale)
        else:
            segmentation = segmentation_function(input_volume, mask=mask, scale=this_scale)

        # Write the result to tif or h5.
        output_folder = os.path.dirname(output_path)
        if output_folder:  # Empty when writing to the current directory; os.makedirs("") would fail.
            os.makedirs(output_folder, exist_ok=True)

        if output_key is None:
            imageio.imwrite(output_path, segmentation, compression="zlib")
        else:
            with open_file(output_path, "a") as f:
                f.create_dataset(output_key, data=segmentation, compression="gzip")

        print(f"Saved segmentation to {output_path}.")


def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]:
    """Determine the tile shape and halo depending on the available VRAM.

    Args:
        is_2d: Whether to return tiling settings for 2d inference.

    Returns:
        The default tiling settings for the available computational resources.

    Raises:
        NotImplementedError: If a CUDA GPU with less than 10 GB VRAM is found.
    """
    if is_2d:
        tile = {"x": 768, "y": 768, "z": 1}
        halo = {"x": 128, "y": 128, "z": 0}
        return {"tile": tile, "halo": halo}

    if torch.cuda.is_available():
        # The default halo size.
        halo = {"x": 64, "y": 64, "z": 16}

        # Determine the GPU RAM and derive a suitable tiling.
        vram = torch.cuda.get_device_properties(0).total_memory / 1e9

        if vram >= 80:
            tile = {"x": 640, "y": 640, "z": 80}
        elif vram >= 40:
            tile = {"x": 512, "y": 512, "z": 64}
        elif vram >= 20:
            tile = {"x": 352, "y": 352, "z": 48}
        elif vram >= 10:
            tile = {"x": 256, "y": 256, "z": 32}
            halo = {"x": 64, "y": 64, "z": 8}  # Choose a smaller halo in z.
        else:
            raise NotImplementedError(f"Infererence with a GPU with {vram} GB VRAM is not supported.")

        tiling = {"tile": tile, "halo": halo}
        print(f"Determined tile size for CUDA: {tiling}")

    elif torch.backends.mps.is_available():  # Check for Apple Silicon (MPS)
        tile = {"x": 256, "y": 256, "z": 16}
        halo = {"x": 16, "y": 16, "z": 4}
        tiling = {"tile": tile, "halo": halo}
        print(f"Determined tile size for MPS: {tiling}")

    # I am not sure what is reasonable on a cpu. For now choosing very small tiling.
    # (This will not work well on a CPU in any case.)
    else:
        tiling = {
            "tile": {"x": 96, "y": 96, "z": 16},
            "halo": {"x": 16, "y": 16, "z": 4},
        }
        print(f"Determining default tiling for CPU: {tiling}")

    return tiling


def parse_tiling(
    tile_shape: Optional[Tuple[int, int, int]],
    halo: Optional[Tuple[int, int, int]],
    is_2d: bool = False,
) -> Dict[str, Dict[str, int]]:
    """Helper function to parse tiling parameter input from the command line.

    Args:
        tile_shape: The tile shape in (z, y, x) order. If None the default tile shape is used.
        halo: The halo in (z, y, x) order. If None the default halo is used.
        is_2d: Whether to return tiling for a 2d model.

    Returns:
        The tiling specification.
    """
    default_tiling = get_default_tiling(is_2d=is_2d)

    if tile_shape is None:
        tile_shape = default_tiling["tile"]
    else:
        assert len(tile_shape) == 3
        tile_shape = dict(zip("zyx", tile_shape))

    if halo is None:
        halo = default_tiling["halo"]
    else:
        assert len(halo) == 3
        halo = dict(zip("zyx", halo))

    return {"tile": tile_shape, "halo": halo}


#
# Utils for post-processing.
#


def apply_size_filter(
    segmentation: np.ndarray,
    min_size: int,
    verbose: bool = False,
    block_shape: Tuple[int, int, int] = (128, 256, 256),
) -> np.ndarray:
    """Apply size filter to the segmentation to remove small objects.

    Note that the segmentation is modified in place (and also returned).

    Args:
        segmentation: The segmentation.
        min_size: The minimal object size in pixels.
        verbose: Whether to print runtimes.
        block_shape: Block shape for parallelizing the operations.

    Returns:
        The size filtered segmentation.
    """
    if min_size == 0:
        return segmentation
    t0 = time.time()
    # Drop the z-entry of a 3d block shape for 2d segmentations.
    if segmentation.ndim == 2 and len(block_shape) == 3:
        block_shape_ = block_shape[1:]
    else:
        block_shape_ = block_shape
    ids, sizes = parallel.unique(segmentation, return_counts=True, block_shape=block_shape_, verbose=verbose)
    filter_ids = ids[sizes < min_size]
    segmentation[np.isin(segmentation, filter_ids)] = 0
    if verbose:
        print("Size filter in", time.time() - t0, "s")
    return segmentation


def _postprocess_seg_3d(seg, area_threshold=1000, iterations=4, iterations_3d=8):
    # Post-process a 3d instance segmentation object by object (modifies 'seg' in place and returns it):
    # fill small holes, then apply a closing with the default structure and a closing restricted to
    # the plane of the last two axes. Closed pixels are written into the bounding box of the object,
    # so they may overwrite neighboring labels.
    # Structure element for a 2d closing in 3d: only applied in the plane of the last two axes.
    structure_3d = np.ones((1, 3, 3))

    props = regionprops(seg)
    for prop in props:
        # Get bounding box and mask.
        bb = tuple(slice(start, stop) for start, stop in zip(prop.bbox[:3], prop.bbox[3:]))
        mask = seg[bb] == prop.label

        # Fill small holes and apply closing (the closings can only add pixels to the mask).
        mask = remove_small_holes(mask, area_threshold=area_threshold)
        mask = np.logical_or(binary_closing(mask, iterations=iterations), mask)
        mask = np.logical_or(binary_closing(mask, iterations=iterations_3d, structure=structure_3d), mask)
        seg[bb][mask] = prop.label

    return seg


#
# Utils for torch device.
#

def _get_default_device():
    # Check that we're in CI and use the CPU if we are.
    # Otherwise the tests may run out of memory on MAC if MPS is used.
    if os.getenv("GITHUB_ACTIONS") == "true":
        return "cpu"
    # Use cuda enabled gpu if it's available.
    if torch.cuda.is_available():
        device = "cuda"
    # As second priority use mps.
    # See https://pytorch.org/docs/stable/notes/mps.html for details
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        device = "mps"
    # Use the CPU as fallback.
    else:
        device = "cpu"
    return device


def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
    """Get the torch device.

    If no device is passed the default device for your system is used.
    Else it will be checked if the device you have passed is supported.

    Args:
        device: The input device.

    Returns:
        The device.

    Raises:
        RuntimeError: If the requested device type is unsupported or its backend is not available.
    """
    if device is None or device == "auto":
        device = _get_default_device()
    else:
        device_type = device if isinstance(device, str) else device.type
        if device_type.lower() == "cuda":
            if not torch.cuda.is_available():
                raise RuntimeError("PyTorch CUDA backend is not available.")
        elif device_type.lower() == "mps":
            if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
                raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
        elif device_type.lower() == "cpu":
            pass  # cpu is always available
        else:
            raise RuntimeError(f"Unsupported device: {device}. Please choose from 'cpu', 'cuda', or 'mps'.")
    return device
def get_prediction(
    input_volume: ArrayLike,  # [z, y, x]
    tiling: Optional[Dict[str, Dict[str, int]]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    verbose: bool = True,
    with_channels: bool = False,
    channels_to_standardize: Optional[List[int]] = None,
    mask: Optional[ArrayLike] = None,
    prediction: Optional[ArrayLike] = None,
    devices: Optional[List[str]] = None,
    preprocess: Optional[callable] = None,
) -> ArrayLike:
    """Run prediction on a given volume.

    This function will automatically choose the correct prediction implementation,
    depending on the model type.

    Args:
        input_volume: The input volume to predict on.
        tiling: The tiling configuration for the prediction. The tile shape here includes the halo.
            If None, the result of 'get_default_tiling()' is used.
        model_path: The path to the model checkpoint if 'model' is not provided.
        model: Pre-loaded model. Either model_path or model is required.
        verbose: Whether to print timing information.
        with_channels: Whether to predict with channels.
        channels_to_standardize: List of channels to standardize. Defaults to None,
            in which case all channels are standardized.
        mask: Optional binary mask. If given, the prediction will only be run in
            the foreground region of the mask.
        prediction: An array-like object for writing the prediction.
            If not given, the prediction will be computed in memory.
        devices: The devices for running prediction. If not given will use the GPU
            if available, otherwise the CPU.
        preprocess: Optional function applied to each block before prediction. If not given,
            inputs that are not numpy arrays are standardized per block.
            Numpy inputs are always standardized as a whole beforehand.

    Returns:
        The predicted volume.

    Raises:
        ValueError: If neither 'model_path' nor 'model' is given.
        NotImplementedError: If 'model_path' refers to a bioimage.io model (.zip file).
    """
    # Make sure either the model path or the model is passed.
    if model is None and model_path is None:
        raise ValueError("Either 'model_path' or 'model' must be provided.")

    # Only a model path can refer to a bioimage.io model (zip file); this is not supported yet.
    # Check it before the (potentially expensive) normalization of the input.
    if model is None and model_path.endswith(".zip"):
        raise NotImplementedError("Prediction with bioimage.io models is not supported yet.")

    if tiling is None:
        tiling = get_default_tiling()

    # Normalize the whole input volume if it is a numpy array.
    # Otherwise we have a zarr array or similar as input, and can't normalize it en-block.
    # Normalization will be applied later per block in this case.
    if isinstance(input_volume, np.ndarray):
        input_volume = _preprocess(input_volume, with_channels, channels_to_standardize)

    # Run prediction with the torch-em library.
    if model is None and model_path.endswith("best.pt"):
        # torch_em expects the root folder of a checkpoint path instead of the checkpoint itself.
        model_path = os.path.split(model_path)[0]

    # 'get_prediction_torch_em' expects the tile shape without the halo, so subtract it on both sides.
    updated_tiling = {
        "tile": {dim: tiling["tile"][dim] - 2 * tiling["halo"][dim] for dim in tiling["tile"]},
        "halo": tiling["halo"],  # Keep the halo part unchanged
    }
    prediction = get_prediction_torch_em(
        input_volume, updated_tiling, model_path, model, verbose, with_channels,
        mask=mask, prediction=prediction, devices=devices, preprocess=preprocess,
    )
    return prediction
Run prediction on a given volume.
This function will automatically choose the correct prediction implementation, depending on the model type.
Arguments:
- input_volume: The input volume to predict on.
- model_path: The path to the model checkpoint if 'model' is not provided.
- model: Pre-loaded model. Either model_path or model is required.
- tiling: The tiling configuration for the prediction.
- verbose: Whether to print timing information.
- with_channels: Whether to predict with channels.
- channels_to_standardize: List of channels to standardize. Defaults to None.
- mask: Optional binary mask. If given, the prediction will only be run in the foreground region of the mask.
- prediction: An array-like object for writing the prediction. If not given, the prediction will be computed in memory.
- devices: The devices for running prediction. If not given will use the GPU if available, otherwise the CPU.
Returns:
The predicted volume.
def get_prediction_torch_em(
    input_volume: ArrayLike,  # [z, y, x]
    tiling: Dict[str, Dict[str, int]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    verbose: bool = True,
    with_channels: bool = False,
    mask: Optional[ArrayLike] = None,
    prediction: Optional[ArrayLike] = None,
    devices: Optional[List[str]] = None,
    preprocess: Optional[callable] = None,
) -> np.ndarray:
    """Run prediction using torch-em on a given volume.

    Args:
        input_volume: The input volume to predict on.
        tiling: The tiling configuration for the prediction. The tile shape here does not include the halo.
        model_path: The path to the model checkpoint if 'model' is not provided: either a torch_em
            checkpoint folder or a file containing a serialized pytorch model.
        model: Pre-loaded model. Either model_path or model is required.
        verbose: Whether to print timing information.
        with_channels: Whether to predict with channels.
        mask: Optional binary mask. If given, the prediction will only be run in
            the foreground region of the mask.
        prediction: An array-like object for writing the prediction.
            If not given, the prediction will be computed in memory.
        devices: The devices for running prediction. If not given will use the GPU
            if available, otherwise the CPU.
        preprocess: Optional function applied to each block before prediction. If not given, inputs that
            are not numpy arrays are standardized per block and numpy inputs are not preprocessed.

    Returns:
        The predicted volume.
    """
    # Get block_shape and halo in the axis order of the input volume, i.e. [z, y, x].
    block_shape = [tiling["tile"][axis] for axis in "zyx"]
    halo = [tiling["halo"][axis] for axis in "zyx"]

    t0 = time.time()
    if devices is None:
        devices = ["cuda" if torch.cuda.is_available() else "cpu"]

    # Suppress warning when loading the model.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        if model is None:
            if os.path.isdir(model_path):  # Load the model from a torch_em checkpoint.
                model = torch_em.util.load_model(checkpoint=model_path, device=devices[0])
            else:  # Load the model directly from a serialized pytorch model.
                # NOTE: weights_only=False unpickles arbitrary objects; only load trusted model files.
                model = torch.load(model_path, weights_only=False)

    # Run prediction with the model.
    with torch.no_grad():

        # Deal with the 2D segmentation case: drop the z-axis from the block shape and halo.
        if len(input_volume.shape) == 2:
            block_shape, halo = block_shape[1:], halo[1:]

        if mask is not None:
            if verbose:
                print("Run prediction with mask.")
            mask = mask.astype("bool")

        if preprocess is None:
            # Numpy inputs are expected to be standardized already ('get_prediction' does this for the
            # whole volume); other array-likes (e.g. zarr) are standardized per block.
            preprocess = None if isinstance(input_volume, np.ndarray) else torch_em.transform.raw.standardize

        prediction = predict_with_halo(
            input_volume, model, gpu_ids=devices,
            block_shape=block_shape, halo=halo,
            preprocess=preprocess, with_channels=with_channels, mask=mask,
            output=prediction,
        )
    if verbose:
        print("Prediction time in", time.time() - t0, "s")
    return prediction
Run prediction using torch-em on a given volume.
Arguments:
- input_volume: The input volume to predict on.
- model_path: The path to the model checkpoint if 'model' is not provided.
- model: Pre-loaded model. Either model_path or model is required.
- tiling: The tiling configuration for the prediction.
- verbose: Whether to print timing information.
- with_channels: Whether to predict with channels.
- mask: Optional binary mask. If given, the prediction will only be run in the foreground region of the mask.
- prediction: An array-like object for writing the prediction. If not given, the prediction will be computed in memory.
- devices: The devices for running prediction. If not given will use the GPU if available, otherwise the CPU.
Returns:
The predicted volume.
def inference_helper(
    input_path: str,
    output_root: str,
    segmentation_function: callable,
    data_ext: str = ".mrc",
    extra_input_path: Optional[str] = None,
    extra_input_ext: str = ".tif",
    mask_input_path: Optional[str] = None,
    mask_input_ext: str = ".tif",
    force: bool = False,
    output_key: Optional[str] = None,
    model_resolution: Optional[Tuple[float, float, float]] = None,
    scale: Optional[Tuple[float, float, float]] = None,
    allocate_output: bool = False,
) -> None:
    """Helper function to run segmentation for mrc files.

    Args:
        input_path: The path to the input data.
            Can either be a folder, in which case all files with the extension `data_ext` below the folder
            are segmented, or a single file, in which case only this file is segmented.
        output_root: The path to the output directory where the segmentation results will be saved.
            Nested input folders are mirrored below this directory.
        segmentation_function: The function performing the segmentation.
            It is called with the input volume as its first positional argument and the keyword arguments
            `mask` and `scale`. If `allocate_output` is set it additionally receives the keyword argument
            `output` and must write the segmentation into it; otherwise it must return the segmentation.
            If you want to pass additional arguments to this function then use 'functools.partial'.
        data_ext: File extension for the image data. By default '.mrc' is used.
        extra_input_path: Filepath to extra inputs that need to be concatenated to the raw data loaded from mrc.
            This enables cristae segmentation with an extra mito channel.
        extra_input_ext: File extension for the extra inputs (by default .tif).
        mask_input_path: Filepath to mask(s) that will be used to restrict the segmentation.
        mask_input_ext: File extension for the mask inputs (by default .tif).
        force: Whether to rerun segmentation for output files that are already present.
            For hdf5 outputs an existing dataset at `output_key` is replaced.
        output_key: Output key for the prediction. If None, the segmentation is written to a tif file;
            otherwise it is written to this key in an hdf5 file.
        model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction.
            If given, the scaling factor will automatically be determined based on the voxel_size of the input data.
        scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'.
        allocate_output: Whether to allocate the output for the segmentation function.

    Raises:
        ValueError: If both 'scale' and 'model_resolution' are passed.
    """
    if (scale is not None) and (model_resolution is not None):
        raise ValueError("You must not provide both 'scale' and 'model_resolution' arguments.")

    # Get the input files. If input_path is a folder then this loads all matching files beneath it,
    # otherwise input_path is treated as a single file. input_root is None for a single file.
    input_files, input_root = _get_file_paths(input_path, data_ext)

    # Collect the extra inputs (if given); they are matched to the inputs by position.
    if extra_input_path is None:
        extra_files = None
    else:
        extra_files, _ = _get_file_paths(extra_input_path, extra_input_ext)
        assert len(input_files) == len(extra_files), \
            f"Found {len(input_files)} input files but {len(extra_files)} extra input files."

    # Collect the masks (if given); they are matched to the inputs by position.
    if mask_input_path is None:
        mask_files = None
    else:
        mask_files, _ = _get_file_paths(mask_input_path, mask_input_ext)
        assert len(input_files) == len(mask_files), \
            f"Found {len(input_files)} input files but {len(mask_files)} mask files."

    for i, img_path in tqdm(enumerate(input_files), total=len(input_files), desc="Processing files"):
        # Determine the output file name.
        input_folder, input_name = os.path.split(img_path)
        stem = os.path.splitext(input_name)[0]
        fname = stem + ("_prediction.tif" if output_key is None else "_prediction.h5")

        if input_root is None:
            output_path = os.path.join(output_root, fname)
        else:  # If we have nested input folders then we preserve the folder structure in the output.
            rel_folder = os.path.relpath(input_folder, input_root)
            output_path = os.path.join(output_root, rel_folder, fname)

        # Skip inputs whose result is already present, unless force was set.
        if os.path.exists(output_path) and not force:
            if output_key is None:
                continue
            with open_file(output_path, "r") as f:
                if output_key in f:
                    continue

        # Load the input volume. If we have extra_files then this concatenates the
        # data across a new first axis (= channel axis).
        input_volume = _load_input(img_path, extra_files, i)
        # Load the mask (if given).
        mask = None if mask_files is None else imageio.imread(mask_files[i])

        # Determine the scale factor: a fixed 'scale' wins, otherwise derive it from
        # 'model_resolution' and the data; without either, no rescaling is applied.
        if scale is not None:
            this_scale = scale
        elif model_resolution is not None:
            this_scale = _derive_scale(img_path, model_resolution)
        else:
            this_scale = None

        # Run the segmentation.
        if allocate_output:
            # NOTE(review): with extra inputs the volume has a leading channel axis, which this
            # allocation also gets — confirm the segmentation function expects that shape.
            segmentation = np.zeros(input_volume.shape, dtype="uint32")
            segmentation_function(input_volume, output=segmentation, mask=mask, scale=this_scale)
        else:
            segmentation = segmentation_function(input_volume, mask=mask, scale=this_scale)

        # Create the output folder; os.makedirs("") would raise when the path has no folder part.
        output_folder = os.path.dirname(output_path)
        if output_folder:
            os.makedirs(output_folder, exist_ok=True)

        # Write the result to tif or h5.
        if output_key is None:
            imageio.imwrite(output_path, segmentation, compression="zlib")
        else:
            with open_file(output_path, "a") as f:
                # With force=True the dataset may already exist and create_dataset would fail on it.
                if output_key in f:
                    del f[output_key]
                f.create_dataset(output_key, data=segmentation, compression="gzip")

        print(f"Saved segmentation to {output_path}.")
Helper function to run segmentation for mrc files.
Arguments:
- input_path: The path to the input data. Can either be a folder, in which case all mrc files below the folder will be segmented, or a single mrc file, in which case only this file will be segmented.
- output_root: The path to the output directory where the segmentation results will be saved.
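- segmentation_function: The function performing the segmentation. It is called with the input volume as its first argument and the keyword arguments 'mask' and 'scale' (plus 'output' if 'allocate_output' is set), and must return the segmentation (or, if 'allocate_output' is set, write it into 'output'). If you want to pass additional arguments to this function then use 'functools.partial'.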
- data_ext: File extension for the image data. By default '.mrc' is used.
- extra_input_path: Filepath to extra inputs that need to be concatenated to the raw data loaded from mrc. This enables cristae segmentation with an extra mito channel.
- extra_input_ext: File extension for the extra inputs (by default .tif).
- mask_input_path: Filepath to mask(s) that will be used to restrict the segmentation.
- mask_input_ext: File extension for the mask inputs (by default .tif).
- force: Whether to rerun segmentation for output files that are already present.
- output_key: Output key for the prediction. If None, the segmentation is written to a tif file; otherwise it is written to this key in an hdf5 file.
- model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction. If given, the scaling factor will automatically be determined based on the voxel_size of the input data.
- scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'.
- allocate_output: Whether to allocate the output for the segmentation function.
def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]:
    """Determine the tile shape and halo depending on the available hardware.

    For 2d inference a fixed tiling is returned. For 3d inference the tiling depends on the
    compute backend: with CUDA it is chosen from the VRAM of the first GPU, while for Apple
    Silicon (MPS) and for the CPU small fixed tilings are used.

    Args:
        is_2d: Whether to return tiling settings for 2d inference.

    Returns:
        The default tiling settings for the available computational resources: a dict with the
        keys "tile" and "halo", each mapping the axes "x", "y" and "z" to a size.

    Raises:
        NotImplementedError: If a CUDA GPU with less than 10 GB VRAM is detected.
    """
    if is_2d:
        return {
            "tile": {"x": 768, "y": 768, "z": 1},
            "halo": {"x": 128, "y": 128, "z": 0},
        }

    if torch.cuda.is_available():
        vram = torch.cuda.get_device_properties(0).total_memory / 1e9

        # Tiers from the largest to the smallest GPU:
        # (minimal VRAM in GB, tile size in xy, tile size in z, halo size in z).
        # The smallest tier uses a reduced halo in z.
        vram_tiers = (
            (80, 640, 80, 16),
            (40, 512, 64, 16),
            (20, 352, 48, 16),
            (10, 256, 32, 8),
        )
        for min_vram, tile_xy, tile_z, halo_z in vram_tiers:
            if vram >= min_vram:
                tiling = {
                    "tile": {"x": tile_xy, "y": tile_xy, "z": tile_z},
                    "halo": {"x": 64, "y": 64, "z": halo_z},
                }
                break
        else:
            raise NotImplementedError(f"Infererence with a GPU with {vram} GB VRAM is not supported.")
        print(f"Determined tile size for CUDA: {tiling}")

    elif torch.backends.mps.is_available():  # Apple Silicon
        tiling = {
            "tile": {"x": 256, "y": 256, "z": 16},
            "halo": {"x": 16, "y": 16, "z": 4},
        }
        print(f"Determined tile size for MPS: {tiling}")

    else:
        # A very small tiling as a conservative choice; CPU inference will be slow in any case.
        tiling = {
            "tile": {"x": 96, "y": 96, "z": 16},
            "halo": {"x": 16, "y": 16, "z": 4},
        }
        print(f"Determining default tiling for CPU: {tiling}")

    return tiling
Determine the tile shape and halo depending on the available VRAM.
Arguments:
- is_2d: Whether to return tiling settings for 2d inference.
Returns:
The default tiling settings for the available computational resources.
def parse_tiling(
    tile_shape: Optional[Tuple[int, int, int]],
    halo: Optional[Tuple[int, int, int]],
    is_2d: bool = False,
) -> Dict[str, Dict[str, int]]:
    """Helper function to parse tiling parameter input from the command line.

    Args:
        tile_shape: The tile shape in zyx order. If None the default tile shape is used.
        halo: The halo in zyx order. If None the default halo is used.
        is_2d: Whether to return tiling for a 2d model. Only affects the defaults.

    Returns:
        The tiling specification: a dict with the keys "tile" and "halo", each mapping
        the axes to a size.
    """
    # Only query the defaults when at least one value is missing. Determining them prints
    # the detected settings and raises for GPUs with little VRAM, which must not prevent
    # running with an explicitly passed tiling.
    default_tiling = None
    if tile_shape is None or halo is None:
        default_tiling = get_default_tiling(is_2d=is_2d)

    if tile_shape is None:
        tile_shape = default_tiling["tile"]
    else:
        assert len(tile_shape) == 3, f"Expected 3 values (z, y, x) for the tile shape, got {tile_shape}."
        tile_shape = dict(zip("zyx", tile_shape))

    if halo is None:
        halo = default_tiling["halo"]
    else:
        assert len(halo) == 3, f"Expected 3 values (z, y, x) for the halo, got {halo}."
        halo = dict(zip("zyx", halo))

    return {"tile": tile_shape, "halo": halo}
Helper function to parse tiling parameter input from the command line.
Arguments:
- tile_shape: The tile shape, given in zyx order. If None the default tile shape is used.
- halo: The halo, given in zyx order. If None the default halo is used.
- is_2d: Whether to return tiling for a 2d model.
Returns:
The tiling specification.
def apply_size_filter(
    segmentation: np.ndarray,
    min_size: int,
    verbose: bool = False,
    block_shape: Tuple[int, int, int] = (128, 256, 256),
) -> np.ndarray:
    """Remove objects smaller than a minimal size from a segmentation.

    All labels with fewer than `min_size` pixels are set to 0. The input array is
    modified in place and returned.

    Args:
        segmentation: The segmentation.
        min_size: The minimal object size in pixels. A value of 0 disables the filter.
        verbose: Whether to print runtimes.
        block_shape: Block shape for parallelizing the operations. For a 2d segmentation
            a 3d block shape is reduced to its last two entries.

    Returns:
        The size filtered segmentation.
    """
    # A zero threshold cannot remove anything, so skip the work entirely.
    if min_size == 0:
        return segmentation

    start = time.time()

    # Drop the leading block dimension when a 3d block shape is used for a 2d image.
    use_2d_blocks = segmentation.ndim == 2 and len(block_shape) == 3
    blocks = block_shape[1:] if use_2d_blocks else block_shape

    label_ids, counts = parallel.unique(segmentation, return_counts=True, block_shape=blocks, verbose=verbose)
    too_small = label_ids[counts < min_size]
    segmentation[np.isin(segmentation, too_small)] = 0

    if verbose:
        print("Size filter in", time.time() - start, "s")
    return segmentation
Apply size filter to the segmentation to remove small objects.
Arguments:
- segmentation: The segmentation.
- min_size: The minimal object size in pixels. Objects with fewer pixels are set to 0; a value of 0 disables the filter.
- verbose: Whether to print runtimes.
- block_shape: Block shape for parallelizing the operations.
Returns:
The size filtered segmentation.
def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
    """Get the torch device.

    If no device (or "auto") is passed the default device for your system is used.
    Otherwise it is checked whether the device you have passed is supported and available,
    and it is returned unchanged. Device strings with an index, e.g. "cuda:0", are accepted.

    Args:
        device: The input device.

    Returns:
        The device.

    Raises:
        RuntimeError: If the requested backend is not available or the device type is not supported.
    """
    if device is None or device == "auto":
        return _get_default_device()

    if isinstance(device, str):
        # Strip an optional device index ("cuda:1" -> "cuda") before checking the backend.
        device_type = device.split(":", 1)[0].lower()
    else:
        device_type = device.type.lower()

    if device_type == "cuda":
        if not torch.cuda.is_available():
            raise RuntimeError("PyTorch CUDA backend is not available.")
    elif device_type == "mps":
        if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
            raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
    elif device_type != "cpu":  # The cpu is always available.
        raise RuntimeError(f"Unsupported device: {device}. Please choose from 'cpu', 'cuda', or 'mps'.")
    return device
Get the torch device.
If no device (or 'auto') is passed, the default device for your system is used. Otherwise, it is checked whether the device you have passed is supported and available.
Arguments:
- device: The input device.
Returns:
The device.