synapse_net.inference.util
import os
import time
import warnings
from glob import glob
from typing import Dict, List, Optional, Tuple, Union

# # Suppress annoying import warnings.
# with warnings.catch_warnings():
#     warnings.simplefilter("ignore")
#     import bioimageio.core

import imageio.v3 as imageio
import elf.parallel as parallel
import mrcfile
import numpy as np
import torch
import torch_em
# import xarray

from elf.io import open_file
from numpy.typing import ArrayLike
from scipy.ndimage import binary_closing
from skimage.measure import regionprops
from skimage.morphology import remove_small_holes
from skimage.transform import rescale, resize
from torch_em.util.prediction import predict_with_halo
from tqdm import tqdm


#
# Utils for prediction.
#


class _Scaler:
    def __init__(self, scale, verbose):
        self.verbose = verbose
        self._original_shape = None

        if scale is None:
            self.scale = None
            return

        # Convert scale to a NumPy array (ensures consistency)
        scale = np.atleast_1d(scale).astype(np.float64)

        # Validate scale values
        if not np.issubdtype(scale.dtype, np.number):
            raise TypeError(f"Scale contains non-numeric values: {scale}")

        # Check if scaling is effectively identity (1.0 in all dimensions)
        if np.allclose(scale, 1.0, atol=1e-3):
            self.scale = None
        else:
            self.scale = scale

    def scale_input(self, input_volume, is_segmentation=False):
        t0 = time.time()
        if self.scale is None:
            return input_volume

        if self._original_shape is None:
            self._original_shape = input_volume.shape
        elif self._original_shape != input_volume.shape:
            raise RuntimeError(
                "Scaler was called with different input shapes. "
                "This is not supported, please create a new instance of the class for it."
            )

        if is_segmentation:
            input_volume = rescale(
                input_volume, self.scale, preserve_range=True, order=0, anti_aliasing=False,
            ).astype(input_volume.dtype)
        else:
            input_volume = rescale(input_volume, self.scale, preserve_range=True).astype(input_volume.dtype)

        if self.verbose:
            print("Rescaled volume from", self._original_shape, "to", input_volume.shape, "in", time.time() - t0, "s")
        return input_volume

    def rescale_output(self, output, is_segmentation):
        t0 = time.time()
        if self.scale is None:
            return output

        assert self._original_shape is not None
        out_shape = self._original_shape
        if output.ndim > len(out_shape):
            assert output.ndim == len(out_shape) + 1
            out_shape = (output.shape[0],) + out_shape

        if is_segmentation:
            output = resize(output, out_shape, preserve_range=True, order=0, anti_aliasing=False).astype(output.dtype)
        else:
            output = resize(output, out_shape, preserve_range=True).astype(output.dtype)

        if self.verbose:
            print("Resized prediction back to original shape", output.shape, "in", time.time() - t0, "s")

        return output


def _preprocess(input_volume, with_channels, channels_to_standardize):
    # We standardize the data for the whole volume beforehand.
    # If we have channels then the standardization is done independently per channel.
    if with_channels:
        input_volume = input_volume.astype(np.float32, copy=False)
        # TODO Check that this is the correct axis.
        if channels_to_standardize is None:  # assume all channels
            channels_to_standardize = range(input_volume.shape[0])
        for ch in channels_to_standardize:
            input_volume[ch] = torch_em.transform.raw.standardize(input_volume[ch])
    else:
        input_volume = torch_em.transform.raw.standardize(input_volume)
    return input_volume


def get_prediction(
    input_volume: ArrayLike,  # [z, y, x]
    tiling: Optional[Dict[str, Dict[str, int]]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    verbose: bool = True,
    with_channels: bool = False,
    channels_to_standardize: Optional[List[int]] = None,
    mask: Optional[ArrayLike] = None,
    prediction: Optional[ArrayLike] = None,
    devices: Optional[List[str]] = None,
) -> ArrayLike:
    """Run prediction on a given volume.

    This function will automatically choose the correct prediction implementation,
    depending on the model type.

    Args:
        input_volume: The input volume to predict on.
        model_path: The path to the model checkpoint if 'model' is not provided.
        model: Pre-loaded model. Either model_path or model is required.
        tiling: The tiling configuration for the prediction.
        verbose: Whether to print timing information.
        with_channels: Whether to predict with channels.
        channels_to_standardize: List of channels to standardize. Defaults to None,
            in which case all channels are standardized.
        mask: Optional binary mask. If given, the prediction will only be run in
            the foreground region of the mask.
        prediction: An array-like object for writing the prediction.
            If not given, the prediction will be computed in memory.
        devices: The devices for running prediction. If not given, the GPU will be
            used if available, otherwise the CPU.

    Returns:
        The predicted volume.
    """
    # make sure either model path or model is passed
    if model is None and model_path is None:
        raise ValueError("Either 'model_path' or 'model' must be provided.")

    if model is not None:
        is_bioimageio = None
    else:
        is_bioimageio = model_path.endswith(".zip")

    if tiling is None:
        tiling = get_default_tiling()

    # Normalize the whole input volume if it is a numpy array.
    # Otherwise we have a zarr array or similar as input, and can't normalize it en bloc.
    # Normalization will be applied later per block in this case.
    if isinstance(input_volume, np.ndarray):
        input_volume = _preprocess(input_volume, with_channels, channels_to_standardize)

    # Run prediction with the bioimage.io library.
    if is_bioimageio:
        if mask is not None:
            raise NotImplementedError
        raise NotImplementedError

    # Run prediction with the torch-em library.
    else:
        if model is None:
            # torch_em expects the root folder of a checkpoint path instead of the checkpoint itself.
            if model_path.endswith("best.pt"):
                model_path = os.path.split(model_path)[0]
        # print(f"tiling {tiling}")
        # Create updated_tiling with the same structure
        updated_tiling = {
            "tile": {},
            "halo": tiling["halo"]  # Keep the halo part unchanged
        }
        # Update tile dimensions
        for dim in tiling["tile"]:
            updated_tiling["tile"][dim] = tiling["tile"][dim] - 2 * tiling["halo"][dim]
        # print(f"updated_tiling {updated_tiling}")
        prediction = get_prediction_torch_em(
            input_volume, updated_tiling, model_path, model, verbose, with_channels,
            mask=mask, prediction=prediction, devices=devices,
        )

    return prediction


def get_prediction_torch_em(
    input_volume: ArrayLike,  # [z, y, x]
    tiling: Dict[str, Dict[str, int]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    verbose: bool = True,
    with_channels: bool = False,
    mask: Optional[ArrayLike] = None,
    prediction: Optional[ArrayLike] = None,
    devices: Optional[List[str]] = None,
) -> np.ndarray:
    """Run prediction using torch-em on a given volume.

    Args:
        input_volume: The input volume to predict on.
        model_path: The path to the model checkpoint if 'model' is not provided.
        model: Pre-loaded model. Either model_path or model is required.
        tiling: The tiling configuration for the prediction.
        verbose: Whether to print timing information.
        with_channels: Whether to predict with channels.
        mask: Optional binary mask. If given, the prediction will only be run in
            the foreground region of the mask.
        prediction: An array-like object for writing the prediction.
            If not given, the prediction will be computed in memory.
        devices: The devices for running prediction. If not given, the GPU will be
            used if available, otherwise the CPU.

    Returns:
        The predicted volume.
    """
    # get block_shape and halo
    block_shape = [tiling["tile"]["z"], tiling["tile"]["x"], tiling["tile"]["y"]]
    halo = [tiling["halo"]["z"], tiling["halo"]["x"], tiling["halo"]["y"]]

    t0 = time.time()
    if devices is None:
        devices = ["cuda" if torch.cuda.is_available() else "cpu"]

    # Suppress warning when loading the model.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        if model is None:
            if os.path.isdir(model_path):  # Load the model from a torch_em checkpoint.
                model = torch_em.util.load_model(checkpoint=model_path, device=devices[0])
            else:  # Load the model directly from a serialized pytorch model.
                model = torch.load(model_path, weights_only=False)

    # Run prediction with the model.
    with torch.no_grad():

        # Deal with 2D segmentation case
        if len(input_volume.shape) == 2:
            block_shape = [block_shape[1], block_shape[2]]
            halo = [halo[1], halo[2]]

        if mask is not None:
            if verbose:
                print("Run prediction with mask.")
            mask = mask.astype("bool")

        preprocess = None if isinstance(input_volume, np.ndarray) else torch_em.transform.raw.standardize
        prediction = predict_with_halo(
            input_volume, model, gpu_ids=devices,
            block_shape=block_shape, halo=halo,
            preprocess=preprocess, with_channels=with_channels, mask=mask,
            output=prediction,
        )
    if verbose:
        print("Prediction time in", time.time() - t0, "s")
    return prediction


def _get_file_paths(input_path, ext=".mrc"):
    if not os.path.exists(input_path):
        raise Exception(f"Input path not found {input_path}")

    if os.path.isfile(input_path):
        input_files = [input_path]
        input_root = None
    else:
        input_files = sorted(glob(os.path.join(input_path, "**", f"*{ext}"), recursive=True))
        input_root = input_path

    return input_files, input_root


def _load_input(img_path, extra_files, i):
    # Load the input data.
    if os.path.splitext(img_path)[-1] == ".tif":
        input_volume = imageio.imread(img_path)

    else:
        with open_file(img_path, "r") as f:
            # Try to automatically derive the key with the raw data.
            keys = list(f.keys())
            if len(keys) == 1:
                key = keys[0]
            elif "data" in keys:
                key = "data"
            elif "raw" in keys:
                key = "raw"
            input_volume = f[key][:]

    assert input_volume.ndim in (2, 3)
    # For now we assume this is always tif.
    if extra_files is not None:
        extra_input = imageio.imread(extra_files[i])
        assert extra_input.shape == input_volume.shape
        input_volume = np.stack([input_volume, extra_input], axis=0)

    return input_volume


def _derive_scale(img_path, model_resolution):
    try:
        with mrcfile.open(img_path, "r") as f:
            voxel_size = f.voxel_size
            if len(model_resolution) == 2:
                voxel_size = [voxel_size.y, voxel_size.x]
            else:
                voxel_size = [voxel_size.z, voxel_size.y, voxel_size.x]

        assert len(voxel_size) == len(model_resolution)
        # The voxel size is given in Angstrom and we need to translate it to nanometer.
        voxel_size = [vsize / 10 for vsize in voxel_size]

        # Compute the correct scale factor.
        scale = tuple(vsize / res for vsize, res in zip(voxel_size, model_resolution))
        print("Rescaling the data at", img_path, "by", scale, "to match the training voxel size", model_resolution)

    except Exception:
        warnings.warn(
            f"The voxel size could not be read from the data for {img_path}. "
            "This data will not be scaled for prediction."
        )
        scale = None

    return scale


def inference_helper(
    input_path: str,
    output_root: str,
    segmentation_function: callable,
    data_ext: str = ".mrc",
    extra_input_path: Optional[str] = None,
    extra_input_ext: str = ".tif",
    mask_input_path: Optional[str] = None,
    mask_input_ext: str = ".tif",
    force: bool = False,
    output_key: Optional[str] = None,
    model_resolution: Optional[Tuple[float, float, float]] = None,
    scale: Optional[Tuple[float, float, float]] = None,
    allocate_output: bool = False,
) -> None:
    """Helper function to run segmentation for mrc files.

    Args:
        input_path: The path to the input data.
            Can either be a folder, in which case all mrc files below the folder will be segmented,
            or a single mrc file, in which case only this file will be segmented.
        output_root: The path to the output directory where the segmentation results will be saved.
        segmentation_function: The function performing the segmentation.
            This function must take the input_volume as the only argument and must return only the segmentation.
            If you want to pass additional arguments to this function, use 'functools.partial'.
        data_ext: File extension for the image data. By default '.mrc' is used.
        extra_input_path: Filepath to extra inputs that need to be concatenated to the raw data loaded from mrc.
            This enables cristae segmentation with an extra mito channel.
        extra_input_ext: File extension for the extra inputs (by default .tif).
        mask_input_path: Filepath to mask(s) that will be used to restrict the segmentation.
        mask_input_ext: File extension for the mask inputs (by default .tif).
        force: Whether to rerun segmentation for output files that are already present.
        output_key: Output key for the prediction. If given, the result will be written to an hdf5 file
            with this key; otherwise it will be written to a tif file.
        model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction.
            If given, the scaling factor will automatically be determined based on the voxel_size of the input data.
        scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'.
        allocate_output: Whether to allocate the output for the segmentation function.
    """
    if (scale is not None) and (model_resolution is not None):
        raise ValueError("You must not provide both 'scale' and 'model_resolution' arguments.")

    # Get the input files. If input_path is a folder then this will load all
    # the mrc files beneath it. Otherwise we assume this is an mrc file already
    # and just return the path to this mrc file.
    input_files, input_root = _get_file_paths(input_path, data_ext)

    # Load extra inputs if the extra_input_path was specified.
    if extra_input_path is None:
        extra_files = None
    else:
        extra_files, _ = _get_file_paths(extra_input_path, extra_input_ext)
        assert len(input_files) == len(extra_files)

    # Load the masks if they were specified.
    if mask_input_path is None:
        mask_files = None
    else:
        mask_files, _ = _get_file_paths(mask_input_path, mask_input_ext)
        assert len(input_files) == len(mask_files)

    for i, img_path in tqdm(enumerate(input_files), total=len(input_files), desc="Processing files"):
        # Determine the output file name.
        input_folder, input_name = os.path.split(img_path)

        if output_key is None:
            fname = os.path.splitext(input_name)[0] + "_prediction.tif"
        else:
            fname = os.path.splitext(input_name)[0] + "_prediction.h5"

        if input_root is None:
            output_path = os.path.join(output_root, fname)
        else:  # If we have nested input folders then we preserve the folder structure in the output.
            rel_folder = os.path.relpath(input_folder, input_root)
            output_path = os.path.join(output_root, rel_folder, fname)

        # Check if the output path is already present.
        # If it is we skip the prediction, unless force was set to true.
        if os.path.exists(output_path) and not force:
            if output_key is None:
                continue
            else:
                with open_file(output_path, "r") as f:
                    if output_key in f:
                        continue

        # Load the input volume. If we have extra_files then this concatenates the
        # data across a new first axis (= channel axis).
        input_volume = _load_input(img_path, extra_files, i)
        # Load the mask (if given).
        mask = None if mask_files is None else imageio.imread(mask_files[i])

        # Determine the scale factor:
        # If neither the 'scale' nor 'model_resolution' arguments were passed then set it to None.
        if scale is None and model_resolution is None:
            this_scale = None
        elif scale is not None:  # If 'scale' was passed then use it.
            this_scale = scale
        else:  # Otherwise 'model_resolution' was passed, use it to derive the scaling from the data.
            assert model_resolution is not None
            this_scale = _derive_scale(img_path, model_resolution)

        # Run the segmentation.
        if allocate_output:
            segmentation = np.zeros(input_volume.shape, dtype="uint32")
            segmentation_function(input_volume, output=segmentation, mask=mask, scale=this_scale)
        else:
            segmentation = segmentation_function(input_volume, mask=mask, scale=this_scale)

        # Write the result to tif or h5.
        os.makedirs(os.path.split(output_path)[0], exist_ok=True)

        if output_key is None:
            imageio.imwrite(output_path, segmentation, compression="zlib")
        else:
            with open_file(output_path, "a") as f:
                f.create_dataset(output_key, data=segmentation, compression="gzip")

        print(f"Saved segmentation to {output_path}.")


def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]:
    """Determine the tile shape and halo depending on the available VRAM.

    Args:
        is_2d: Whether to return tiling settings for 2d inference.

    Returns:
        The default tiling settings for the available computational resources.
    """
    if is_2d:
        tile = {"x": 768, "y": 768, "z": 1}
        halo = {"x": 128, "y": 128, "z": 0}
        return {"tile": tile, "halo": halo}

    if torch.cuda.is_available():
        # The default halo size.
        halo = {"x": 64, "y": 64, "z": 16}

        # Determine the GPU RAM and derive a suitable tiling.
        vram = torch.cuda.get_device_properties(0).total_memory / 1e9

        if vram >= 80:
            tile = {"x": 640, "y": 640, "z": 80}
        elif vram >= 40:
            tile = {"x": 512, "y": 512, "z": 64}
        elif vram >= 20:
            tile = {"x": 352, "y": 352, "z": 48}
        elif vram >= 10:
            tile = {"x": 256, "y": 256, "z": 32}
            halo = {"x": 64, "y": 64, "z": 8}  # Choose a smaller halo in z.
        else:
            raise NotImplementedError(f"Inference with a GPU with {vram} GB VRAM is not supported.")

        tiling = {"tile": tile, "halo": halo}
        print(f"Determined tile size for CUDA: {tiling}")

    elif torch.backends.mps.is_available():  # Check for Apple Silicon (MPS)
        tile = {"x": 256, "y": 256, "z": 16}
        halo = {"x": 16, "y": 16, "z": 4}
        tiling = {"tile": tile, "halo": halo}
        print(f"Determined tile size for MPS: {tiling}")

    # I am not sure what is reasonable on a cpu. For now choosing very small tiling.
    # (This will not work well on a CPU in any case.)
    else:
        tiling = {
            "tile": {"x": 96, "y": 96, "z": 16},
            "halo": {"x": 16, "y": 16, "z": 4},
        }
        print(f"Determined default tiling for CPU: {tiling}")

    return tiling


def parse_tiling(
    tile_shape: Tuple[int, int, int],
    halo: Tuple[int, int, int],
    is_2d: bool = False,
) -> Dict[str, Dict[str, int]]:
    """Helper function to parse tiling parameter input from the command line.

    Args:
        tile_shape: The tile shape. If None the default tile shape is used.
        halo: The halo. If None the default halo is used.
        is_2d: Whether to return tiling for a 2d model.

    Returns:
        The tiling specification.
    """

    default_tiling = get_default_tiling(is_2d=is_2d)

    if tile_shape is None:
        tile_shape = default_tiling["tile"]
    else:
        assert len(tile_shape) == 3
        tile_shape = dict(zip("zyx", tile_shape))

    if halo is None:
        halo = default_tiling["halo"]
    else:
        assert len(halo) == 3
        halo = dict(zip("zyx", halo))

    tiling = {"tile": tile_shape, "halo": halo}
    return tiling


#
# Utils for post-processing.
#


def apply_size_filter(
    segmentation: np.ndarray,
    min_size: int,
    verbose: bool = False,
    block_shape: Tuple[int, int, int] = (128, 256, 256),
) -> np.ndarray:
    """Apply size filter to the segmentation to remove small objects.

    Args:
        segmentation: The segmentation.
        min_size: The minimal object size in pixels.
        verbose: Whether to print runtimes.
        block_shape: Block shape for parallelizing the operations.

    Returns:
        The size filtered segmentation.
    """
    if min_size == 0:
        return segmentation
    t0 = time.time()
    if segmentation.ndim == 2 and len(block_shape) == 3:
        block_shape_ = block_shape[1:]
    else:
        block_shape_ = block_shape
    ids, sizes = parallel.unique(segmentation, return_counts=True, block_shape=block_shape_, verbose=verbose)
    filter_ids = ids[sizes < min_size]
    segmentation[np.isin(segmentation, filter_ids)] = 0
    if verbose:
        print("Size filter in", time.time() - t0, "s")
    return segmentation


def _postprocess_seg_3d(seg, area_threshold=1000, iterations=4, iterations_3d=8):
    # Structuring element for 2d dilation in 3d.
    structure_element = np.ones((3, 3))  # 3x3 structure for XY plane
    structure_3d = np.zeros((1, 3, 3))  # Only applied in the XY plane
    structure_3d[0] = structure_element

    props = regionprops(seg)
    for prop in props:
        # Get bounding box and mask.
        bb = tuple(slice(start, stop) for start, stop in zip(prop.bbox[:3], prop.bbox[3:]))
        mask = seg[bb] == prop.label

        # Fill small holes and apply closing.
        mask = remove_small_holes(mask, area_threshold=area_threshold)
        mask = np.logical_or(binary_closing(mask, iterations=iterations), mask)
        mask = np.logical_or(binary_closing(mask, iterations=iterations_3d, structure=structure_3d), mask)
        seg[bb][mask] = prop.label

    return seg


#
# Utils for torch device.
#

def _get_default_device():
    # Check whether we're in CI and use the CPU if we are.
    # Otherwise the tests may run out of memory on Mac if MPS is used.
    if os.getenv("GITHUB_ACTIONS") == "true":
        return "cpu"
    # Use cuda enabled gpu if it's available.
    if torch.cuda.is_available():
        device = "cuda"
    # As second priority use mps.
    # See https://pytorch.org/docs/stable/notes/mps.html for details
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        device = "mps"
    # Use the CPU as fallback.
    else:
        device = "cpu"
    return device


def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
    """Get the torch device.

    If no device is passed, the default device for your system is used.
    Otherwise, it is checked whether the device you have passed is supported.

    Args:
        device: The input device.

    Returns:
        The device.
    """
    if device is None or device == "auto":
        device = _get_default_device()
    else:
        device_type = device if isinstance(device, str) else device.type
        if device_type.lower() == "cuda":
            if not torch.cuda.is_available():
                raise RuntimeError("PyTorch CUDA backend is not available.")
        elif device_type.lower() == "mps":
            if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
                raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
        elif device_type.lower() == "cpu":
            pass  # cpu is always available
        else:
            raise RuntimeError(f"Unsupported device: {device}. Please choose from 'cpu', 'cuda', or 'mps'.")
    return device
def get_prediction(input_volume: ArrayLike, tiling: Optional[Dict[str, Dict[str, int]]], model_path: Optional[str] = None, model: Optional[torch.nn.Module] = None, verbose: bool = True, with_channels: bool = False, channels_to_standardize: Optional[List[int]] = None, mask: Optional[ArrayLike] = None, prediction: Optional[ArrayLike] = None, devices: Optional[List[str]] = None) -> ArrayLike
Run prediction on a given volume.
This function will automatically choose the correct prediction implementation, depending on the model type.
Arguments:
- input_volume: The input volume to predict on.
- model_path: The path to the model checkpoint if 'model' is not provided.
- model: Pre-loaded model. Either model_path or model is required.
- tiling: The tiling configuration for the prediction.
- verbose: Whether to print timing information.
- with_channels: Whether to predict with channels.
- channels_to_standardize: List of channels to standardize. Defaults to None, in which case all channels are standardized.
- mask: Optional binary mask. If given, the prediction will only be run in the foreground region of the mask.
- prediction: An array-like object for writing the prediction. If not given, the prediction will be computed in memory.
- devices: The devices for running prediction. If not given, the GPU will be used if available, otherwise the CPU.
Returns:
The predicted volume.
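As a minimal usage sketch (the checkpoint path and the random input volume below are placeholders; passing tiling=None falls back to get_default_tiling()):

import numpy as np
from synapse_net.inference.util import get_prediction

# Placeholder input: in practice this is a tomogram loaded from an mrc file.
input_volume = np.random.rand(64, 256, 256).astype("float32")

# Placeholder checkpoint path; the parent folder is used if the path ends in "best.pt".
model_path = "/path/to/checkpoints/my_model/best.pt"

prediction = get_prediction(input_volume, tiling=None, model_path=model_path)
print(prediction.shape)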
def get_prediction_torch_em(input_volume: ArrayLike, tiling: Dict[str, Dict[str, int]], model_path: Optional[str] = None, model: Optional[torch.nn.Module] = None, verbose: bool = True, with_channels: bool = False, mask: Optional[ArrayLike] = None, prediction: Optional[ArrayLike] = None, devices: Optional[List[str]] = None) -> np.ndarray
Run prediction using torch-em on a given volume.
Arguments:
- input_volume: The input volume to predict on.
- model_path: The path to the model checkpoint if 'model' is not provided.
- model: Pre-loaded model. Either model_path or model is required.
- tiling: The tiling configuration for the prediction.
- verbose: Whether to print timing information.
- with_channels: Whether to predict with channels.
- mask: Optional binary mask. If given, the prediction will only be run in the foreground region of the mask.
- prediction: An array-like object for writing the prediction. If not given, the prediction will be computed in memory.
- devices: The devices for running prediction. If not given, the GPU will be used if available, otherwise the CPU.
Returns:
The predicted volume.
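A minimal sketch with a pre-loaded model and an explicit tiling; the single-layer Conv3d only stands in for a trained segmentation network and the tile sizes are arbitrary example values:

import numpy as np
import torch
from synapse_net.inference.util import get_prediction_torch_em

# Stand-in model for illustration; in practice a trained network is passed or loaded from a checkpoint.
model = torch.nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, padding=1)

tiling = {"tile": {"x": 256, "y": 256, "z": 32}, "halo": {"x": 32, "y": 32, "z": 8}}
input_volume = np.random.rand(48, 384, 384).astype("float32")

prediction = get_prediction_torch_em(input_volume, tiling, model=model)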
def inference_helper(input_path: str, output_root: str, segmentation_function: callable, data_ext: str = ".mrc", extra_input_path: Optional[str] = None, extra_input_ext: str = ".tif", mask_input_path: Optional[str] = None, mask_input_ext: str = ".tif", force: bool = False, output_key: Optional[str] = None, model_resolution: Optional[Tuple[float, float, float]] = None, scale: Optional[Tuple[float, float, float]] = None, allocate_output: bool = False) -> None
Helper function to run segmentation for mrc files.
Arguments:
- input_path: The path to the input data. Can either be a folder, in which case all mrc files below the folder will be segmented, or a single mrc file, in which case only this file will be segmented.
- output_root: The path to the output directory where the segmentation results will be saved.
- segmentation_function: The function performing the segmentation. This function must take the input_volume as the only argument and must return only the segmentation. If you want to pass additional arguments to this function, use 'functools.partial' (see the sketch after this list).
- data_ext: File extension for the image data. By default '.mrc' is used.
- extra_input_path: Filepath to extra inputs that need to be concatenated to the raw data loaded from mrc. This enables cristae segmentation with an extra mito channel.
- extra_input_ext: File extension for the extra inputs (by default .tif).
- mask_input_path: Filepath to mask(s) that will be used to restrict the segmentation.
- mask_input_ext: File extension for the mask inputs (by default .tif).
- force: Whether to rerun segmentation for output files that are already present.
- output_key: Output key for the prediction. If given, the result will be written to an hdf5 file with this key; otherwise it will be written to a tif file.
- model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction. If given, the scaling factor will automatically be determined based on the voxel_size of the input data.
- scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'.
- allocate_output: Whether to allocate the output for the segmentation function.
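As an illustration of the 'functools.partial' pattern mentioned above, here is a sketch with a hypothetical thresholding function standing in for a real segmentation function (the paths are placeholders):

from functools import partial

from synapse_net.inference.util import inference_helper

# Hypothetical segmentation function; a real one would run model prediction and post-processing.
def threshold_segmentation(input_volume, mask=None, scale=None, threshold=0.5):
    return (input_volume > threshold).astype("uint32")

# Bind the extra 'threshold' argument so only input_volume (plus mask and scale) remains.
segment = partial(threshold_segmentation, threshold=0.75)

inference_helper(
    input_path="/data/tomograms",       # placeholder folder containing .mrc files
    output_root="/data/segmentations",  # outputs mirror the input folder structure
    segmentation_function=segment,
)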
def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]
Determine the tile shape and halo depending on the available VRAM.
Arguments:
- is_2d: Whether to return tiling settings for 2d inference.
Returns:
The default tiling settings for the available computational resources.
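For illustration, the returned dictionary always has this structure; the concrete values depend on the detected hardware (the ones shown in the comment are the defaults for a GPU with 40-80 GB of VRAM):

from synapse_net.inference.util import get_default_tiling

tiling = get_default_tiling()
# e.g. {"tile": {"x": 512, "y": 512, "z": 64}, "halo": {"x": 64, "y": 64, "z": 16}}

tiling_2d = get_default_tiling(is_2d=True)
# {"tile": {"x": 768, "y": 768, "z": 1}, "halo": {"x": 128, "y": 128, "z": 0}}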
def parse_tiling(tile_shape: Tuple[int, int, int], halo: Tuple[int, int, int], is_2d: bool = False) -> Dict[str, Dict[str, int]]
Helper function to parse tiling parameter input from the command line.
Arguments:
- tile_shape: The tile shape. If None the default tile shape is used.
- halo: The halo. If None the default halo is used.
- is_2d: Whether to return tiling for a 2d model.
Returns:
The tiling specification.
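A short sketch of how (z, y, x) tuples, e.g. parsed from command-line arguments, map to the tiling dictionary:

from synapse_net.inference.util import parse_tiling

tiling = parse_tiling(tile_shape=(32, 384, 384), halo=(8, 48, 48))
# -> {"tile": {"z": 32, "y": 384, "x": 384}, "halo": {"z": 8, "y": 48, "x": 48}}

# Passing None for both falls back to get_default_tiling().
default_tiling = parse_tiling(tile_shape=None, halo=None)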
def apply_size_filter(segmentation: np.ndarray, min_size: int, verbose: bool = False, block_shape: Tuple[int, int, int] = (128, 256, 256)) -> np.ndarray
Apply size filter to the segmentation to remove small objects.
Arguments:
- segmentation: The segmentation.
- min_size: The minimal object size in pixels.
- verbose: Whether to print runtimes.
- block_shape: Block shape for parallelizing the operations.
Returns:
The size filtered segmentation.
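A minimal sketch on a toy label volume (sizes chosen so that one object falls below min_size):

import numpy as np
from synapse_net.inference.util import apply_size_filter

segmentation = np.zeros((4, 8, 8), dtype="uint32")
segmentation[0, :2, :4] = 1   # object 1: 8 voxels
segmentation[2, 3, 3] = 2     # object 2: 1 voxel

# Removes object 2 (below min_size) and keeps object 1; the filter also works in place.
filtered = apply_size_filter(segmentation, min_size=5)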
def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]
Get the torch device.
If no device is passed, the default device for your system is used. Otherwise, it is checked whether the device you have passed is supported.
Arguments:
- device: The input device.
Returns:
The device.
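For example:

from synapse_net.inference.util import get_device

device = get_device()        # picks "cuda", "mps", or "cpu" depending on the system
device = get_device("auto")  # equivalent to passing None
device = get_device("cpu")   # explicitly requested devices are validated and returned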