synapse_net.inference.util
import os
import time
import warnings
from glob import glob
from typing import Dict, List, Optional, Tuple, Union

# # Suppress annoying import warnings.
# with warnings.catch_warnings():
#     warnings.simplefilter("ignore")
#     import bioimageio.core

import imageio.v3 as imageio
import elf.parallel as parallel
import mrcfile
import numpy as np
import torch
import torch_em
# import xarray

from elf.io import open_file
from scipy.ndimage import binary_closing
from skimage.measure import regionprops
from skimage.morphology import remove_small_holes
from skimage.transform import rescale, resize
from torch_em.util.prediction import predict_with_halo
from tqdm import tqdm


#
# Utils for prediction.
#


class _Scaler:
    def __init__(self, scale, verbose):
        self.verbose = verbose
        self._original_shape = None

        if scale is None:
            self.scale = None
            return

        # Convert scale to a NumPy array (ensures consistency).
        scale = np.atleast_1d(scale).astype(np.float64)

        # Validate scale values.
        if not np.issubdtype(scale.dtype, np.number):
            raise TypeError(f"Scale contains non-numeric values: {scale}")

        # Check if scaling is effectively identity (1.0 in all dimensions).
        if np.allclose(scale, 1.0, atol=1e-3):
            self.scale = None
        else:
            self.scale = scale

    def scale_input(self, input_volume, is_segmentation=False):
        t0 = time.time()
        if self.scale is None:
            return input_volume

        if self._original_shape is None:
            self._original_shape = input_volume.shape
        elif self._original_shape != input_volume.shape:
            raise RuntimeError(
                "Scaler was called with different input shapes. "
                "This is not supported, please create a new instance of the class for it."
            )

        if is_segmentation:
            input_volume = rescale(
                input_volume, self.scale, preserve_range=True, order=0, anti_aliasing=False,
            ).astype(input_volume.dtype)
        else:
            input_volume = rescale(input_volume, self.scale, preserve_range=True).astype(input_volume.dtype)

        if self.verbose:
            print("Rescaled volume from", self._original_shape, "to", input_volume.shape, "in", time.time() - t0, "s")
        return input_volume

    def rescale_output(self, output, is_segmentation):
        t0 = time.time()
        if self.scale is None:
            return output

        assert self._original_shape is not None
        out_shape = self._original_shape
        if output.ndim > len(out_shape):
            assert output.ndim == len(out_shape) + 1
            out_shape = (output.shape[0],) + out_shape

        if is_segmentation:
            output = resize(output, out_shape, preserve_range=True, order=0, anti_aliasing=False).astype(output.dtype)
        else:
            output = resize(output, out_shape, preserve_range=True).astype(output.dtype)

        if self.verbose:
            print("Resized prediction back to original shape", output.shape, "in", time.time() - t0, "s")

        return output


def get_prediction(
    input_volume: np.ndarray,  # [z, y, x]
    tiling: Optional[Dict[str, Dict[str, int]]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    verbose: bool = True,
    with_channels: bool = False,
    channels_to_standardize: Optional[List[int]] = None,
    mask: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Run prediction on a given volume.

    This function will automatically choose the correct prediction implementation,
    depending on the model type.

    Args:
        input_volume: The input volume to predict on.
        model_path: The path to the model checkpoint if 'model' is not provided.
        model: Pre-loaded model. Either model_path or model is required.
        tiling: The tiling configuration for the prediction.
        verbose: Whether to print timing information.
        with_channels: Whether to predict with channels.
        channels_to_standardize: List of channels to standardize. If None, all channels are standardized.
        mask: Optional binary mask. If given, the prediction will only be run in
            the foreground region of the mask.

    Returns:
        The predicted volume.
    """
    # Make sure that either the model path or the model is passed.
    if model is None and model_path is None:
        raise ValueError("Either 'model_path' or 'model' must be provided.")

    if model is not None:
        is_bioimageio = None
    else:
        is_bioimageio = model_path.endswith(".zip")

    if tiling is None:
        tiling = get_default_tiling()

    # We standardize the data for the whole volume beforehand.
    # If we have channels then the standardization is done independently per channel.
    if with_channels:
        input_volume = input_volume.astype(np.float32, copy=False)
        # TODO Check that this is the correct axis.
        if channels_to_standardize is None:  # Assume all channels.
            channels_to_standardize = range(input_volume.shape[0])
        for ch in channels_to_standardize:
            input_volume[ch] = torch_em.transform.raw.standardize(input_volume[ch])
    else:
        input_volume = torch_em.transform.raw.standardize(input_volume)

    # Run prediction with the bioimage.io library.
    if is_bioimageio:
        if mask is not None:
            raise NotImplementedError
        raise NotImplementedError

    # Run prediction with the torch-em library.
    else:
        if model is None:
            # torch_em expects the root folder of a checkpoint path instead of the checkpoint itself.
            if model_path.endswith("best.pt"):
                model_path = os.path.split(model_path)[0]
        # Create updated_tiling with the same structure.
        updated_tiling = {
            "tile": {},
            "halo": tiling["halo"]  # Keep the halo part unchanged.
        }
        # Update the tile dimensions: the inner tile is the outer tile minus the halo on both sides.
        for dim in tiling["tile"]:
            updated_tiling["tile"][dim] = tiling["tile"][dim] - 2 * tiling["halo"][dim]
        pred = get_prediction_torch_em(
            input_volume, updated_tiling, model_path, model, verbose, with_channels, mask=mask
        )

    return pred


def get_prediction_torch_em(
    input_volume: np.ndarray,  # [z, y, x]
    tiling: Dict[str, Dict[str, int]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    verbose: bool = True,
    with_channels: bool = False,
    mask: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Run prediction using torch-em on a given volume.

    Args:
        input_volume: The input volume to predict on.
        model_path: The path to the model checkpoint if 'model' is not provided.
        model: Pre-loaded model. Either model_path or model is required.
        tiling: The tiling configuration for the prediction.
        verbose: Whether to print timing information.
        with_channels: Whether to predict with channels.
        mask: Optional binary mask. If given, the prediction will only be run in
            the foreground region of the mask.

    Returns:
        The predicted volume.
    """
    # Get the block shape and halo.
    block_shape = [tiling["tile"]["z"], tiling["tile"]["x"], tiling["tile"]["y"]]
    halo = [tiling["halo"]["z"], tiling["halo"]["x"], tiling["halo"]["y"]]

    t0 = time.time()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Suppress warnings when loading the model.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        if model is None:
            if os.path.isdir(model_path):  # Load the model from a torch_em checkpoint.
                model = torch_em.util.load_model(checkpoint=model_path, device=device)
            else:  # Load the model directly from a serialized pytorch model.
                model = torch.load(model_path)

    # Run prediction with the model.
    with torch.no_grad():

        # Deal with the 2D segmentation case.
        if len(input_volume.shape) == 2:
            block_shape = [block_shape[1], block_shape[2]]
            halo = [halo[1], halo[2]]

        if mask is not None:
            if verbose:
                print("Run prediction with mask.")
            mask = mask.astype("bool")

        pred = predict_with_halo(
            input_volume, model, gpu_ids=[device],
            block_shape=block_shape, halo=halo,
            preprocess=None, with_channels=with_channels, mask=mask,
        )
    if verbose:
        print("Prediction time in", time.time() - t0, "s")
    return pred


def _get_file_paths(input_path, ext=".mrc"):
    if not os.path.exists(input_path):
        raise Exception(f"Input path not found {input_path}")

    if os.path.isfile(input_path):
        input_files = [input_path]
        input_root = None
    else:
        input_files = sorted(glob(os.path.join(input_path, "**", f"*{ext}"), recursive=True))
        input_root = input_path

    return input_files, input_root


def _load_input(img_path, extra_files, i):
    # Load the input data.
    if os.path.splitext(img_path)[-1] == ".tif":
        input_volume = imageio.imread(img_path)

    else:
        with open_file(img_path, "r") as f:
            # Try to automatically derive the key with the raw data.
            keys = list(f.keys())
            if len(keys) == 1:
                key = keys[0]
            elif "data" in keys:
                key = "data"
            elif "raw" in keys:
                key = "raw"
            input_volume = f[key][:]

    assert input_volume.ndim in (2, 3)
    # For now we assume this is always tif.
    if extra_files is not None:
        extra_input = imageio.imread(extra_files[i])
        assert extra_input.shape == input_volume.shape
        input_volume = np.stack([input_volume, extra_input], axis=0)

    return input_volume


def _derive_scale(img_path, model_resolution):
    try:
        with mrcfile.open(img_path, "r") as f:
            voxel_size = f.voxel_size
            if len(model_resolution) == 2:
                voxel_size = [voxel_size.y, voxel_size.x]
            else:
                voxel_size = [voxel_size.z, voxel_size.y, voxel_size.x]

        assert len(voxel_size) == len(model_resolution)
        # The voxel size is given in Angstrom and we need to translate it to nanometer.
        voxel_size = [vsize / 10 for vsize in voxel_size]

        # Compute the correct scale factor.
        scale = tuple(vsize / res for vsize, res in zip(voxel_size, model_resolution))
        print("Rescaling the data at", img_path, "by", scale, "to match the training voxel size", model_resolution)

    except Exception:
        warnings.warn(
            f"The voxel size could not be read from the data for {img_path}. "
            "This data will not be scaled for prediction."
        )
        scale = None

    return scale


def inference_helper(
    input_path: str,
    output_root: str,
    segmentation_function: callable,
    data_ext: str = ".mrc",
    extra_input_path: Optional[str] = None,
    extra_input_ext: str = ".tif",
    mask_input_path: Optional[str] = None,
    mask_input_ext: str = ".tif",
    force: bool = False,
    output_key: Optional[str] = None,
    model_resolution: Optional[Tuple[float, float, float]] = None,
    scale: Optional[Tuple[float, float, float]] = None,
) -> None:
    """Helper function to run segmentation for mrc files.

    Args:
        input_path: The path to the input data.
            Can either be a folder, in which case all mrc files below the folder will be segmented,
            or a single mrc file, in which case only this file will be segmented.
        output_root: The path to the output directory where the segmentation results will be saved.
        segmentation_function: The function performing the segmentation.
            This function must take the input_volume as its first argument and must return only the segmentation.
            It is also called with the keyword arguments 'mask' and 'scale'.
            If you want to pass additional arguments to it then use 'functools.partial'.
        data_ext: File extension for the image data. By default '.mrc' is used.
        extra_input_path: Filepath to extra inputs that need to be concatenated to the raw data loaded from mrc.
            This enables cristae segmentation with an extra mito channel.
        extra_input_ext: File extension for the extra inputs (by default .tif).
        mask_input_path: Filepath to mask(s) that will be used to restrict the segmentation.
        mask_input_ext: File extension for the mask inputs (by default .tif).
        force: Whether to rerun segmentation for output files that are already present.
        output_key: Output key for the prediction. If None, the result is written as a tif file;
            otherwise it is written to an hdf5 file under this key.
        model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction.
            If given, the scaling factor will automatically be determined based on the voxel_size of the input data.
        scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'.
    """
    if (scale is not None) and (model_resolution is not None):
        raise ValueError("You must not provide both 'scale' and 'model_resolution' arguments.")

    # Get the input files. If input_path is a folder then this will load all
    # the mrc files beneath it. Otherwise we assume this is an mrc file already
    # and just return the path to this mrc file.
    input_files, input_root = _get_file_paths(input_path, data_ext)

    # Load extra inputs if the extra_input_path was specified.
    if extra_input_path is None:
        extra_files = None
    else:
        extra_files, _ = _get_file_paths(extra_input_path, extra_input_ext)
        assert len(input_files) == len(extra_files)

    # Load the masks if they were specified.
    if mask_input_path is None:
        mask_files = None
    else:
        mask_files, _ = _get_file_paths(mask_input_path, mask_input_ext)
        assert len(input_files) == len(mask_files)

    for i, img_path in tqdm(enumerate(input_files), total=len(input_files), desc="Processing files"):
        # Determine the output file name.
        input_folder, input_name = os.path.split(img_path)

        if output_key is None:
            fname = os.path.splitext(input_name)[0] + "_prediction.tif"
        else:
            fname = os.path.splitext(input_name)[0] + "_prediction.h5"

        if input_root is None:
            output_path = os.path.join(output_root, fname)
        else:  # If we have nested input folders then we preserve the folder structure in the output.
            rel_folder = os.path.relpath(input_folder, input_root)
            output_path = os.path.join(output_root, rel_folder, fname)

        # Check if the output path is already present.
        # If it is we skip the prediction, unless force was set to true.
        if os.path.exists(output_path) and not force:
            if output_key is None:
                continue
            else:
                with open_file(output_path, "r") as f:
                    if output_key in f:
                        continue

        # Load the input volume. If we have extra_files then this concatenates the
        # data across a new first axis (= channel axis).
        input_volume = _load_input(img_path, extra_files, i)
        # Load the mask (if given).
        mask = None if mask_files is None else imageio.imread(mask_files[i])

        # Determine the scale factor:
        # If neither the 'scale' nor the 'model_resolution' argument was passed then set it to None.
        if scale is None and model_resolution is None:
            this_scale = None
        elif scale is not None:  # If 'scale' was passed then use it.
            this_scale = scale
        else:  # Otherwise 'model_resolution' was passed, use it to derive the scaling from the data.
            assert model_resolution is not None
            this_scale = _derive_scale(img_path, model_resolution)

        # Run the segmentation.
        segmentation = segmentation_function(input_volume, mask=mask, scale=this_scale)

        # Write the result to tif or h5.
        os.makedirs(os.path.split(output_path)[0], exist_ok=True)

        if output_key is None:
            imageio.imwrite(output_path, segmentation, compression="zlib")
        else:
            with open_file(output_path, "a") as f:
                f.create_dataset(output_key, data=segmentation, compression="gzip")

        print(f"Saved segmentation to {output_path}.")


def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]:
    """Determine the tile shape and halo depending on the available VRAM.

    Args:
        is_2d: Whether to return tiling settings for 2d inference.

    Returns:
        The default tiling settings for the available computational resources.
    """
    if is_2d:
        tile = {"x": 768, "y": 768, "z": 1}
        halo = {"x": 128, "y": 128, "z": 0}
        return {"tile": tile, "halo": halo}

    if torch.cuda.is_available():
        # The default halo size.
        halo = {"x": 64, "y": 64, "z": 16}

        # Determine the GPU RAM and derive a suitable tiling.
        vram = torch.cuda.get_device_properties(0).total_memory / 1e9

        if vram >= 80:
            tile = {"x": 640, "y": 640, "z": 80}
        elif vram >= 40:
            tile = {"x": 512, "y": 512, "z": 64}
        elif vram >= 20:
            tile = {"x": 352, "y": 352, "z": 48}
        elif vram >= 10:
            tile = {"x": 256, "y": 256, "z": 32}
            halo = {"x": 64, "y": 64, "z": 8}  # Choose a smaller halo in z.
        else:
            raise NotImplementedError(f"Inference with a GPU with {vram} GB VRAM is not supported.")

        tiling = {"tile": tile, "halo": halo}
        print(f"Determined tile size for CUDA: {tiling}")

    elif torch.backends.mps.is_available():  # Check for Apple Silicon (MPS).
        tile = {"x": 256, "y": 256, "z": 16}
        halo = {"x": 16, "y": 16, "z": 4}
        tiling = {"tile": tile, "halo": halo}
        print(f"Determined tile size for MPS: {tiling}")

    # I am not sure what is reasonable on a cpu. For now choosing very small tiling.
    # (This will not work well on a CPU in any case.)
    else:
        tiling = {
            "tile": {"x": 96, "y": 96, "z": 16},
            "halo": {"x": 16, "y": 16, "z": 4},
        }
        print(f"Determined default tiling for CPU: {tiling}")

    return tiling


def parse_tiling(
    tile_shape: Optional[Tuple[int, int, int]],
    halo: Optional[Tuple[int, int, int]],
    is_2d: bool = False,
) -> Dict[str, Dict[str, int]]:
    """Helper function to parse tiling parameter input from the command line.

    Args:
        tile_shape: The tile shape as (z, y, x). If None the default tile shape is used.
        halo: The halo as (z, y, x). If None the default halo is used.
        is_2d: Whether to return tiling for a 2d model.

    Returns:
        The tiling specification.
    """
    default_tiling = get_default_tiling(is_2d=is_2d)

    if tile_shape is None:
        tile_shape = default_tiling["tile"]
    else:
        assert len(tile_shape) == 3
        tile_shape = dict(zip("zyx", tile_shape))

    if halo is None:
        halo = default_tiling["halo"]
    else:
        assert len(halo) == 3
        halo = dict(zip("zyx", halo))

    tiling = {"tile": tile_shape, "halo": halo}
    return tiling


#
# Utils for post-processing.
#


def apply_size_filter(
    segmentation: np.ndarray,
    min_size: int,
    verbose: bool = False,
    block_shape: Tuple[int, int, int] = (128, 256, 256),
) -> np.ndarray:
    """Apply size filter to the segmentation to remove small objects.

    Args:
        segmentation: The segmentation.
        min_size: The minimal object size in pixels.
        verbose: Whether to print runtimes.
        block_shape: Block shape for parallelizing the operations.

    Returns:
        The size filtered segmentation.
    """
    if min_size == 0:
        return segmentation
    t0 = time.time()
    if segmentation.ndim == 2 and len(block_shape) == 3:
        block_shape_ = block_shape[1:]
    else:
        block_shape_ = block_shape
    ids, sizes = parallel.unique(segmentation, return_counts=True, block_shape=block_shape_, verbose=verbose)
    filter_ids = ids[sizes < min_size]
    segmentation[np.isin(segmentation, filter_ids)] = 0
    if verbose:
        print("Size filter in", time.time() - t0, "s")
    return segmentation


def _postprocess_seg_3d(seg, area_threshold=1000, iterations=4, iterations_3d=8):
    # Structuring element for 2d closing in 3d.
    structure_element = np.ones((3, 3))  # 3x3 structure for the XY plane.
    structure_3d = np.zeros((1, 3, 3))  # Only applied in the XY plane.
    structure_3d[0] = structure_element

    props = regionprops(seg)
    for prop in props:
        # Get the bounding box and mask.
        bb = tuple(slice(start, stop) for start, stop in zip(prop.bbox[:3], prop.bbox[3:]))
        mask = seg[bb] == prop.label

        # Fill small holes and apply closing.
        mask = remove_small_holes(mask, area_threshold=area_threshold)
        mask = np.logical_or(binary_closing(mask, iterations=iterations), mask)
        mask = np.logical_or(binary_closing(mask, iterations=iterations_3d, structure=structure_3d), mask)
        seg[bb][mask] = prop.label

    return seg


#
# Utils for torch device.
#

def _get_default_device():
    # Check if we're in CI and use the CPU if we are.
    # Otherwise the tests may run out of memory on MAC if MPS is used.
    if os.getenv("GITHUB_ACTIONS") == "true":
        return "cpu"
    # Use a cuda enabled gpu if it's available.
    if torch.cuda.is_available():
        device = "cuda"
    # As second priority use mps.
    # See https://pytorch.org/docs/stable/notes/mps.html for details.
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        device = "mps"
    # Use the CPU as fallback.
    else:
        device = "cpu"
    return device


def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
    """Get the torch device.

    If no device is passed the default device for your system is used.
    Otherwise, it is checked whether the device you have passed is supported.

    Args:
        device: The input device.

    Returns:
        The device.
    """
    if device is None or device == "auto":
        device = _get_default_device()
    else:
        device_type = device if isinstance(device, str) else device.type
        if device_type.lower() == "cuda":
            if not torch.cuda.is_available():
                raise RuntimeError("PyTorch CUDA backend is not available.")
        elif device_type.lower() == "mps":
            if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
                raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
        elif device_type.lower() == "cpu":
            pass  # The cpu is always available.
        else:
            raise RuntimeError(f"Unsupported device: {device}. Please choose from 'cpu', 'cuda', or 'mps'.")
    return device
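The internal _Scaler class ties the prediction utilities together: inputs are scaled to the model's voxel size before prediction and the result is scaled back to the original shape afterwards. A minimal sketch of that round trip, using random stand-in data and a placeholder checkpoint path (in practice the scale factor would come from something like _derive_scale):

import numpy as np
from synapse_net.inference.util import _Scaler, get_prediction

volume = np.random.rand(64, 256, 256).astype("float32")  # [z, y, x] tomogram stand-in
scaler = _Scaler(scale=(0.5, 0.5, 0.5), verbose=True)  # e.g. input voxel size is half the training voxel size

scaled = scaler.scale_input(volume)  # rescaled copy of the input
pred = get_prediction(scaled, tiling=None, model_path="/path/to/checkpoints/model/best.pt")  # placeholder path
pred = scaler.rescale_output(pred, is_segmentation=False)  # resized back to the original shape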
def get_prediction(
    input_volume: np.ndarray,
    tiling: Optional[Dict[str, Dict[str, int]]],
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    verbose: bool = True,
    with_channels: bool = False,
    channels_to_standardize: Optional[List[int]] = None,
    mask: Optional[np.ndarray] = None,
) -> np.ndarray
Run prediction on a given volume.
This function will automatically choose the correct prediction implementation, depending on the model type.
Arguments:
- input_volume: The input volume to predict on.
- model_path: The path to the model checkpoint if 'model' is not provided.
- model: Pre-loaded model. Either model_path or model is required.
- tiling: The tiling configuration for the prediction.
- verbose: Whether to print timing information.
- with_channels: Whether to predict with channels.
- channels_to_standardize: List of channels to standardize. If None, all channels are standardized.
- mask: Optional binary mask. If given, the prediction will only be run in the foreground region of the mask.
Returns:
The predicted volume.
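For illustration, a sketch of a direct call with an explicit tiling and a foreground mask; the checkpoint path is a placeholder and the input is random data. Note that the 'tile' entries describe the outer tile size: get_prediction subtracts twice the halo internally before handing the tiling to get_prediction_torch_em.

import numpy as np
from synapse_net.inference.util import get_prediction

volume = np.random.rand(48, 512, 512).astype("float32")  # [z, y, x]
mask = np.zeros(volume.shape, dtype="bool")
mask[:, 128:384, 128:384] = True  # only run prediction in this region

tiling = {
    "tile": {"x": 256, "y": 256, "z": 32},
    "halo": {"x": 64, "y": 64, "z": 8},
}
pred = get_prediction(
    volume, tiling=tiling, model_path="/path/to/checkpoints/model/best.pt", mask=mask
)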
def get_prediction_torch_em(
    input_volume: np.ndarray,
    tiling: Dict[str, Dict[str, int]],
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    verbose: bool = True,
    with_channels: bool = False,
    mask: Optional[np.ndarray] = None,
) -> np.ndarray
Run prediction using torch-em on a given volume.
Arguments:
- input_volume: The input volume to predict on.
- model_path: The path to the model checkpoint if 'model' is not provided.
- model: Pre-loaded model. Either model_path or model is required.
- tiling: The tiling configuration for the prediction.
- verbose: Whether to print timing information.
- with_channels: Whether to predict with channels.
- mask: Optional binary mask. If given, the prediction will only be run in the foreground region of the mask.
Returns:
The predicted volume.
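A sketch of calling the torch-em backend directly with a checkpoint folder (placeholder path). Unlike get_prediction, this function does not standardize the input and uses the 'tile' entries directly as the block shape, so the data should be standardized beforehand:

import numpy as np
import torch_em
from synapse_net.inference.util import get_default_tiling, get_prediction_torch_em

volume = np.random.rand(32, 256, 256).astype("float32")
volume = torch_em.transform.raw.standardize(volume)  # get_prediction would do this for us

tiling = get_default_tiling()
pred = get_prediction_torch_em(volume, tiling, model_path="/path/to/checkpoint_folder")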
def inference_helper(
    input_path: str,
    output_root: str,
    segmentation_function: callable,
    data_ext: str = ".mrc",
    extra_input_path: Optional[str] = None,
    extra_input_ext: str = ".tif",
    mask_input_path: Optional[str] = None,
    mask_input_ext: str = ".tif",
    force: bool = False,
    output_key: Optional[str] = None,
    model_resolution: Optional[Tuple[float, float, float]] = None,
    scale: Optional[Tuple[float, float, float]] = None,
) -> None
Helper function to run segmentation for mrc files.
Arguments:
- input_path: The path to the input data. Can either be a folder, in which case all mrc files below the folder will be segmented, or a single mrc file, in which case only this file will be segmented.
- output_root: The path to the output directory where the segmentation results will be saved.
- segmentation_function: The function performing the segmentation. This function must take the input_volume as its first argument and must return only the segmentation. It is also called with the keyword arguments 'mask' and 'scale'. If you want to pass additional arguments to it, use 'functools.partial'.
- data_ext: File extension for the image data. By default '.mrc' is used.
- extra_input_path: Filepath to extra inputs that need to be concatenated to the raw data loaded from mrc. This enables cristae segmentation with an extra mito channel.
- extra_input_ext: File extension for the extra inputs (by default .tif).
- mask_input_path: Filepath to mask(s) that will be used to restrict the segmentation.
- mask_input_ext: File extension for the mask inputs (by default .tif).
- force: Whether to rerun segmentation for output files that are already present.
- output_key: Output key for the prediction. If None, the result is written as a tif file; otherwise it is written to an hdf5 file under this key.
- model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction. If given, the scaling factor will automatically be determined based on the voxel_size of the input data.
- scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'.
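A sketch of a batch run over a folder of tomograms. The segmentation function here is hypothetical; inference_helper always calls it with the keyword arguments 'mask' and 'scale' in addition to the volume, and extra parameters can be bound with functools.partial. All paths and the resolution value are placeholders.

from functools import partial
from synapse_net.inference.util import inference_helper

def segment_vesicles(input_volume, mask=None, scale=None, threshold=0.5):
    # Hypothetical segmentation routine returning a label volume.
    ...

inference_helper(
    input_path="/data/tomograms",  # all .mrc files below this folder are processed
    output_root="/data/segmentations",  # output mirrors the input folder structure
    segmentation_function=partial(segment_vesicles, threshold=0.6),
    model_resolution=(1.554, 1.554, 1.554),  # placeholder training voxel size in nm, used to derive the scale per file
    force=False,  # skip files whose output already exists
)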
def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]
Determine the tile shape and halo depending on the available VRAM.
Arguments:
- is_2d: Whether to return tiling settings for 2d inference.
Returns:
The default tiling settings for the available computational resources.
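The returned dictionary always has the same structure; the values depend on the detected hardware. For example:

from synapse_net.inference.util import get_default_tiling

tiling = get_default_tiling()
# On a GPU with 40-80 GB of VRAM:
# {"tile": {"x": 512, "y": 512, "z": 64}, "halo": {"x": 64, "y": 64, "z": 16}}

tiling_2d = get_default_tiling(is_2d=True)
# {"tile": {"x": 768, "y": 768, "z": 1}, "halo": {"x": 128, "y": 128, "z": 0}}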
def parse_tiling(
    tile_shape: Optional[Tuple[int, int, int]],
    halo: Optional[Tuple[int, int, int]],
    is_2d: bool = False,
) -> Dict[str, Dict[str, int]]
Helper function to parse tiling parameter input from the command line.
Arguments:
- tile_shape: The tile shape as (z, y, x). If None the default tile shape is used.
- halo: The halo as (z, y, x). If None the default halo is used.
- is_2d: Whether to return tiling for a 2d model.
Returns:
The tiling specification.
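For example, with tile shape and halo given as (z, y, x) tuples, e.g. parsed from command-line arguments:

from synapse_net.inference.util import parse_tiling

tiling = parse_tiling(tile_shape=(32, 256, 256), halo=(8, 64, 64))
# {"tile": {"z": 32, "y": 256, "x": 256}, "halo": {"z": 8, "y": 64, "x": 64}}

# Passing None for both falls back to get_default_tiling().
default_tiling = parse_tiling(tile_shape=None, halo=None)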
def apply_size_filter(
    segmentation: np.ndarray,
    min_size: int,
    verbose: bool = False,
    block_shape: Tuple[int, int, int] = (128, 256, 256),
) -> np.ndarray
Apply size filter to the segmentation to remove small objects.
Arguments:
- segmentation: The segmentation.
- min_size: The minimal object size in pixels.
- verbose: Whether to print runtimes.
- block_shape: Block shape for parallelizing the operations.
Returns:
The size filtered segmentation.
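A small self-contained sketch: objects below min_size voxels are set to zero, the rest is kept. Note that the segmentation is modified in place and also returned.

import numpy as np
from synapse_net.inference.util import apply_size_filter

seg = np.zeros((64, 256, 256), dtype="uint32")
seg[10:12, 10:12, 10:12] = 1  # 8 voxels, removed by the filter
seg[20:40, 20:60, 20:60] = 2  # 32000 voxels, kept

seg = apply_size_filter(seg, min_size=500)
assert 1 not in seg and 2 in seg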
def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]
Get the torch device.
If no device is passed the default device for your system is used. Otherwise, it is checked whether the device you have passed is supported.
Arguments:
- device: The input device.
Returns:
The device.
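A short illustrative snippet:

import torch
from synapse_net.inference.util import get_device

device = get_device()  # "cuda", "mps", or "cpu", depending on what is available
x = torch.zeros(2, 3, device=device)

# Explicit requests are validated: this raises a RuntimeError on a machine without CUDA.
# get_device("cuda")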