synapse_net.inference.util
import os
import time
import warnings
from glob import glob
from typing import Dict, Optional, Tuple, Union

# # Suppress annoying import warnings.
# with warnings.catch_warnings():
#     warnings.simplefilter("ignore")
#     import bioimageio.core

import imageio.v3 as imageio
import elf.parallel as parallel
import mrcfile
import numpy as np
import torch
import torch_em
# import xarray

from elf.io import open_file
from scipy.ndimage import binary_closing
from skimage.measure import regionprops
from skimage.morphology import remove_small_holes
from skimage.transform import rescale, resize
from torch_em.util.prediction import predict_with_halo
from tqdm import tqdm


#
# Utils for prediction.
#


class _Scaler:
    def __init__(self, scale, verbose):
        self.scale = scale
        self.verbose = verbose
        self._original_shape = None

    def scale_input(self, input_volume, is_segmentation=False):
        if self.scale is None:
            return input_volume

        if self._original_shape is None:
            self._original_shape = input_volume.shape
        elif self._original_shape != input_volume.shape:
            raise RuntimeError(
                "Scaler was called with different input shapes. "
                "This is not supported, please create a new instance of the class for it."
            )

        if is_segmentation:
            input_volume = rescale(
                input_volume, self.scale, preserve_range=True, order=0, anti_aliasing=False,
            ).astype(input_volume.dtype)
        else:
            input_volume = rescale(input_volume, self.scale, preserve_range=True).astype(input_volume.dtype)

        if self.verbose:
            print("Rescaled volume from", self._original_shape, "to", input_volume.shape)
        return input_volume

    def rescale_output(self, output, is_segmentation):
        if self.scale is None:
            return output

        assert self._original_shape is not None
        out_shape = self._original_shape
        if output.ndim > len(out_shape):
            assert output.ndim == len(out_shape) + 1
            out_shape = (output.shape[0],) + out_shape

        if is_segmentation:
            output = resize(output, out_shape, preserve_range=True, order=0, anti_aliasing=False).astype(output.dtype)
        else:
            output = resize(output, out_shape, preserve_range=True).astype(output.dtype)

        return output


def get_prediction(
    input_volume: np.ndarray,  # [z, y, x]
    tiling: Optional[Dict[str, Dict[str, int]]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    verbose: bool = True,
    with_channels: bool = False,
    mask: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Run prediction on a given volume.

    This function will automatically choose the correct prediction implementation,
    depending on the model type.

    Args:
        input_volume: The input volume to predict on.
        model_path: The path to the model checkpoint if 'model' is not provided.
        model: Pre-loaded model. Either model_path or model is required.
        tiling: The tiling configuration for the prediction.
        verbose: Whether to print timing information.
        with_channels: Whether to predict with channels.
        mask: Optional binary mask. If given, the prediction will only be run in
            the foreground region of the mask.

    Returns:
        The predicted volume.
    """
    # Make sure that either the model path or the model is passed.
    if model is None and model_path is None:
        raise ValueError("Either 'model_path' or 'model' must be provided.")

    if model is not None:
        is_bioimageio = None
    else:
        is_bioimageio = model_path.endswith(".zip")

    if tiling is None:
        tiling = get_default_tiling()

    # We standardize the data for the whole volume beforehand.
    # If we have channels then the standardization is done independently per channel.
    if with_channels:
        # TODO Check that this is the correct axis.
        input_volume = torch_em.transform.raw.standardize(input_volume, axis=(1, 2, 3))
    else:
        input_volume = torch_em.transform.raw.standardize(input_volume)

    # Run prediction with the bioimage.io library.
    if is_bioimageio:
        if mask is not None:
            raise NotImplementedError
        raise NotImplementedError

    # Run prediction with the torch-em library.
    else:
        if model is None:
            # torch_em expects the root folder of a checkpoint path instead of the checkpoint itself.
            if model_path.endswith("best.pt"):
                model_path = os.path.split(model_path)[0]
        # Create updated_tiling with the same structure.
        updated_tiling = {
            "tile": {},
            "halo": tiling["halo"],  # Keep the halo part unchanged.
        }
        # Shrink each tile dimension by twice the halo, so that tile plus halo matches the configured tile size.
        for dim in tiling["tile"]:
            updated_tiling["tile"][dim] = tiling["tile"][dim] - 2 * tiling["halo"][dim]
        pred = get_prediction_torch_em(
            input_volume, updated_tiling, model_path, model, verbose, with_channels, mask=mask
        )

    return pred


def get_prediction_torch_em(
    input_volume: np.ndarray,  # [z, y, x]
    tiling: Dict[str, Dict[str, int]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    verbose: bool = True,
    with_channels: bool = False,
    mask: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Run prediction using torch-em on a given volume.

    Args:
        input_volume: The input volume to predict on.
        model_path: The path to the model checkpoint if 'model' is not provided.
        model: Pre-loaded model. Either model_path or model is required.
        tiling: The tiling configuration for the prediction.
        verbose: Whether to print timing information.
        with_channels: Whether to predict with channels.
        mask: Optional binary mask. If given, the prediction will only be run in
            the foreground region of the mask.

    Returns:
        The predicted volume.
    """
    # Get the block shape and halo from the tiling.
    block_shape = [tiling["tile"]["z"], tiling["tile"]["x"], tiling["tile"]["y"]]
    halo = [tiling["halo"]["z"], tiling["halo"]["x"], tiling["halo"]["y"]]

    t0 = time.time()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Suppress warnings when loading the model.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        if model is None:
            if os.path.isdir(model_path):  # Load the model from a torch_em checkpoint.
                model = torch_em.util.load_model(checkpoint=model_path, device=device)
            else:  # Load the model directly from a serialized pytorch model.
                model = torch.load(model_path)

    # Run prediction with the model.
    with torch.no_grad():

        # Deal with the 2D segmentation case.
        if len(input_volume.shape) == 2:
            block_shape = [block_shape[1], block_shape[2]]
            halo = [halo[1], halo[2]]

        if mask is not None:
            if verbose:
                print("Run prediction with mask.")
            mask = mask.astype("bool")

        pred = predict_with_halo(
            input_volume, model, gpu_ids=[device],
            block_shape=block_shape, halo=halo,
            preprocess=None, with_channels=with_channels, mask=mask,
        )
    if verbose:
        print("Prediction time in", time.time() - t0, "s")
    return pred


def _get_file_paths(input_path, ext=".mrc"):
    if not os.path.exists(input_path):
        raise Exception(f"Input path not found {input_path}")

    if os.path.isfile(input_path):
        input_files = [input_path]
        input_root = None
    else:
        input_files = sorted(glob(os.path.join(input_path, "**", f"*{ext}"), recursive=True))
        input_root = input_path

    return input_files, input_root


def _load_input(img_path, extra_files, i):
    # Load the input data.
    if os.path.splitext(img_path)[-1] == ".tif":
        input_volume = imageio.imread(img_path)

    else:
        with open_file(img_path, "r") as f:
            # Try to automatically derive the key with the raw data.
            keys = list(f.keys())
            if len(keys) == 1:
                key = keys[0]
            elif "data" in keys:
                key = "data"
            elif "raw" in keys:
                key = "raw"
            else:
                raise ValueError(f"Could not derive the key for the raw data in {img_path}.")
            input_volume = f[key][:]

    assert input_volume.ndim in (2, 3)
    # For now we assume this is always tif.
    if extra_files is not None:
        extra_input = imageio.imread(extra_files[i])
        assert extra_input.shape == input_volume.shape
        input_volume = np.stack([input_volume, extra_input], axis=0)

    return input_volume


def _derive_scale(img_path, model_resolution):
    try:
        with mrcfile.open(img_path, "r") as f:
            voxel_size = f.voxel_size
            if len(model_resolution) == 2:
                voxel_size = [voxel_size.y, voxel_size.x]
            else:
                voxel_size = [voxel_size.z, voxel_size.y, voxel_size.x]

        assert len(voxel_size) == len(model_resolution)
        # The voxel size is given in Angstrom and we need to translate it to nanometer.
        voxel_size = [vsize / 10 for vsize in voxel_size]

        # Compute the correct scale factor.
        scale = tuple(vsize / res for vsize, res in zip(voxel_size, model_resolution))
        print("Rescaling the data at", img_path, "by", scale, "to match the training voxel size", model_resolution)

    except Exception:
        warnings.warn(
            f"The voxel size could not be read from the data for {img_path}. "
            "This data will not be scaled for prediction."
        )
        scale = None

    return scale


def inference_helper(
    input_path: str,
    output_root: str,
    segmentation_function: callable,
    data_ext: str = ".mrc",
    extra_input_path: Optional[str] = None,
    extra_input_ext: str = ".tif",
    mask_input_path: Optional[str] = None,
    mask_input_ext: str = ".tif",
    force: bool = False,
    output_key: Optional[str] = None,
    model_resolution: Optional[Tuple[float, float, float]] = None,
    scale: Optional[Tuple[float, float, float]] = None,
) -> None:
    """Helper function to run segmentation for mrc files.

    Args:
        input_path: The path to the input data.
            Can either be a folder, in which case all mrc files below the folder will be segmented,
            or a single mrc file, in which case only this file will be segmented.
        output_root: The path to the output directory where the segmentation results will be saved.
        segmentation_function: The function performing the segmentation.
            This function must take the input_volume as its first argument, accept 'mask' and 'scale'
            keyword arguments, and return only the segmentation.
            If you want to pass additional arguments to this function then use 'functools.partial'.
        data_ext: File extension for the image data. By default '.mrc' is used.
        extra_input_path: Filepath to extra inputs that need to be concatenated to the raw data loaded from mrc.
            This enables cristae segmentation with an extra mito channel.
        extra_input_ext: File extension for the extra inputs (by default '.tif').
        mask_input_path: Filepath to mask(s) that will be used to restrict the segmentation.
        mask_input_ext: File extension for the mask inputs (by default '.tif').
        force: Whether to rerun segmentation for output files that are already present.
        output_key: Output key for the prediction. If given, the result is written to an hdf5 file
            under this key; otherwise it is written to a tif file.
        model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction.
            If given, the scaling factor will automatically be determined based on the voxel_size of the input data.
        scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'.
    """
    if (scale is not None) and (model_resolution is not None):
        raise ValueError("You must not provide both 'scale' and 'model_resolution' arguments.")

    # Get the input files. If input_path is a folder then this will load all
    # the mrc files beneath it. Otherwise we assume this is an mrc file already
    # and just return the path to this mrc file.
    input_files, input_root = _get_file_paths(input_path, data_ext)

    # Load extra inputs if the extra_input_path was specified.
    if extra_input_path is None:
        extra_files = None
    else:
        extra_files, _ = _get_file_paths(extra_input_path, extra_input_ext)
        assert len(input_files) == len(extra_files)

    # Load the masks if they were specified.
    if mask_input_path is None:
        mask_files = None
    else:
        mask_files, _ = _get_file_paths(mask_input_path, mask_input_ext)
        assert len(input_files) == len(mask_files)

    for i, img_path in tqdm(enumerate(input_files), total=len(input_files), desc="Processing files"):
        # Determine the output file name.
        input_folder, input_name = os.path.split(img_path)

        if output_key is None:
            fname = os.path.splitext(input_name)[0] + "_prediction.tif"
        else:
            fname = os.path.splitext(input_name)[0] + "_prediction.h5"

        if input_root is None:
            output_path = os.path.join(output_root, fname)
        else:  # If we have nested input folders then we preserve the folder structure in the output.
            rel_folder = os.path.relpath(input_folder, input_root)
            output_path = os.path.join(output_root, rel_folder, fname)

        # Check if the output path is already present.
        # If it is we skip the prediction, unless force was set to true.
        if os.path.exists(output_path) and not force:
            if output_key is None:
                continue
            else:
                with open_file(output_path, "r") as f:
                    if output_key in f:
                        continue

        # Load the input volume. If we have extra_files then this concatenates the
        # data across a new first axis (= channel axis).
        input_volume = _load_input(img_path, extra_files, i)
        # Load the mask (if given).
        mask = None if mask_files is None else imageio.imread(mask_files[i])

        # Determine the scale factor:
        # If neither the 'scale' nor the 'model_resolution' argument was passed then set it to None.
        if scale is None and model_resolution is None:
            this_scale = None
        elif scale is not None:  # If 'scale' was passed then use it.
            this_scale = scale
        else:  # Otherwise 'model_resolution' was passed, use it to derive the scaling from the data.
            assert model_resolution is not None
            this_scale = _derive_scale(img_path, model_resolution)

        # Run the segmentation.
        segmentation = segmentation_function(input_volume, mask=mask, scale=this_scale)

        # Write the result to tif or h5.
        os.makedirs(os.path.split(output_path)[0], exist_ok=True)

        if output_key is None:
            imageio.imwrite(output_path, segmentation, compression="zlib")
        else:
            with open_file(output_path, "a") as f:
                f.create_dataset(output_key, data=segmentation, compression="gzip")

        print(f"Saved segmentation to {output_path}.")


def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]:
    """Determine the tile shape and halo depending on the available VRAM.

    Args:
        is_2d: Whether to return tiling settings for 2d inference.

    Returns:
        The default tiling settings for the available computational resources.
    """
    if is_2d:
        tile = {"x": 768, "y": 768, "z": 1}
        halo = {"x": 128, "y": 128, "z": 0}
        return {"tile": tile, "halo": halo}

    if torch.cuda.is_available():
        # The default halo size.
        halo = {"x": 64, "y": 64, "z": 16}

        # Determine the GPU RAM and derive a suitable tiling.
        vram = torch.cuda.get_device_properties(0).total_memory / 1e9

        if vram >= 80:
            tile = {"x": 640, "y": 640, "z": 80}
        elif vram >= 40:
            tile = {"x": 512, "y": 512, "z": 64}
        elif vram >= 20:
            tile = {"x": 352, "y": 352, "z": 48}
        elif vram >= 10:
            tile = {"x": 256, "y": 256, "z": 32}
            halo = {"x": 64, "y": 64, "z": 8}  # Choose a smaller halo in z.
        else:
            raise NotImplementedError(f"Inference with a GPU with {vram} GB VRAM is not supported.")

        print(f"Determined tile size: {tile}")
        tiling = {"tile": tile, "halo": halo}

    # I am not sure what is reasonable on a cpu. For now choosing very small tiling.
    # (This will not work well on a CPU in any case.)
    else:
        print("Determining default tiling")
        tiling = {
            "tile": {"x": 96, "y": 96, "z": 16},
            "halo": {"x": 16, "y": 16, "z": 4},
        }

    return tiling


def parse_tiling(
    tile_shape: Tuple[int, int, int],
    halo: Tuple[int, int, int],
    is_2d: bool = False,
) -> Dict[str, Dict[str, int]]:
    """Helper function to parse tiling parameter input from the command line.

    Args:
        tile_shape: The tile shape. If None the default tile shape is used.
        halo: The halo. If None the default halo is used.
        is_2d: Whether to return tiling for a 2d model.

    Returns:
        The tiling specification.
    """
    default_tiling = get_default_tiling(is_2d=is_2d)

    if tile_shape is None:
        tile_shape = default_tiling["tile"]
    else:
        assert len(tile_shape) == 3
        tile_shape = dict(zip("zyx", tile_shape))

    if halo is None:
        halo = default_tiling["halo"]
    else:
        assert len(halo) == 3
        halo = dict(zip("zyx", halo))

    tiling = {"tile": tile_shape, "halo": halo}
    return tiling


#
# Utils for post-processing.
#


def apply_size_filter(
    segmentation: np.ndarray,
    min_size: int,
    verbose: bool = False,
    block_shape: Tuple[int, int, int] = (128, 256, 256),
) -> np.ndarray:
    """Apply size filter to the segmentation to remove small objects.

    Args:
        segmentation: The segmentation.
        min_size: The minimal object size in pixels.
        verbose: Whether to print runtimes.
        block_shape: Block shape for parallelizing the operations.

    Returns:
        The size filtered segmentation.
    """
    if min_size == 0:
        return segmentation
    t0 = time.time()
    if segmentation.ndim == 2 and len(block_shape) == 3:
        block_shape_ = block_shape[1:]
    else:
        block_shape_ = block_shape
    ids, sizes = parallel.unique(segmentation, return_counts=True, block_shape=block_shape_, verbose=verbose)
    filter_ids = ids[sizes < min_size]
    segmentation[np.isin(segmentation, filter_ids)] = 0
    if verbose:
        print("Size filter in", time.time() - t0, "s")
    return segmentation


def _postprocess_seg_3d(seg, area_threshold=1000, iterations=4, iterations_3d=8):
    # Structuring element for 2d dilation in 3d.
    structure_element = np.ones((3, 3))  # 3x3 structure for XY plane
    structure_3d = np.zeros((1, 3, 3))  # Only applied in the XY plane
    structure_3d[0] = structure_element

    props = regionprops(seg)
    for prop in props:
        # Get bounding box and mask.
        bb = tuple(slice(start, stop) for start, stop in zip(prop.bbox[:3], prop.bbox[3:]))
        mask = seg[bb] == prop.label

        # Fill small holes and apply closing.
        mask = remove_small_holes(mask, area_threshold=area_threshold)
        mask = np.logical_or(binary_closing(mask, iterations=iterations), mask)
        mask = np.logical_or(binary_closing(mask, iterations=iterations_3d, structure=structure_3d), mask)
        seg[bb][mask] = prop.label

    return seg


#
# Utils for torch device.
#

def _get_default_device():
    # Check if we are running in CI and use the CPU if we are.
    # Otherwise the tests may run out of memory on MAC if MPS is used.
    if os.getenv("GITHUB_ACTIONS") == "true":
        return "cpu"
    # Use cuda enabled gpu if it's available.
    if torch.cuda.is_available():
        device = "cuda"
    # As second priority use mps.
    # See https://pytorch.org/docs/stable/notes/mps.html for details
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        device = "mps"
    # Use the CPU as fallback.
    else:
        device = "cpu"
    return device


def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
    """Get the torch device.

    If no device is passed the default device for your system is used.
    Otherwise it will be checked if the device you passed is supported.

    Args:
        device: The input device.

    Returns:
        The device.
    """
    if device is None or device == "auto":
        device = _get_default_device()
    else:
        device_type = device if isinstance(device, str) else device.type
        if device_type.lower() == "cuda":
            if not torch.cuda.is_available():
                raise RuntimeError("PyTorch CUDA backend is not available.")
        elif device_type.lower() == "mps":
            if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
                raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
        elif device_type.lower() == "cpu":
            pass  # cpu is always available
        else:
            raise RuntimeError(f"Unsupported device: {device}. Please choose from 'cpu', 'cuda', or 'mps'.")
    return device
def get_prediction(
    input_volume: np.ndarray,  # [z, y, x]
    tiling: Optional[Dict[str, Dict[str, int]]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    verbose: bool = True,
    with_channels: bool = False,
    mask: Optional[np.ndarray] = None,
) -> np.ndarray:
Run prediction on a given volume.
This function will automatically choose the correct prediction implementation, depending on the model type.
Arguments:
- input_volume: The input volume to predict on.
- model_path: The path to the model checkpoint if 'model' is not provided.
- model: Pre-loaded model. Either model_path or model is required.
- tiling: The tiling configuration for the prediction.
- verbose: Whether to print timing information.
- with_channels: Whether to predict with channels.
- mask: Optional binary mask. If given, the prediction will only be run in the foreground region of the mask.
Returns:
The predicted volume.
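As a rough usage sketch (the checkpoint path and input shape below are placeholders, not part of the library):

    import numpy as np
    from synapse_net.inference.util import get_prediction

    # Hypothetical tomogram and checkpoint path, for illustration only.
    volume = np.random.rand(48, 384, 384).astype("float32")  # [z, y, x]
    pred = get_prediction(
        volume,
        tiling=None,  # None falls back to get_default_tiling()
        model_path="checkpoints/vesicle_model/best.pt",  # placeholder path
    )
    print(pred.shape)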
def get_prediction_torch_em(
    input_volume: np.ndarray,  # [z, y, x]
    tiling: Dict[str, Dict[str, int]],  # {"tile": {"z": int, ...}, "halo": {"z": int, ...}}
    model_path: Optional[str] = None,
    model: Optional[torch.nn.Module] = None,
    verbose: bool = True,
    with_channels: bool = False,
    mask: Optional[np.ndarray] = None,
) -> np.ndarray:
Run prediction using torch-em on a given volume.
Arguments:
- input_volume: The input volume to predict on.
- model_path: The path to the model checkpoint if 'model' is not provided.
- model: Pre-loaded model. Either model_path or model is required.
- tiling: The tiling configuration for the prediction.
- verbose: Whether to print timing information.
- with_channels: Whether to predict with channels.
- mask: Optional binary mask. If given, the prediction will only be run in the foreground region of the mask.
Returns:
The predicted volume.
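A minimal sketch of calling this function directly with an explicit tiling. Note that, unlike get_prediction, it does not standardize the input volume beforehand; the checkpoint path is a placeholder:

    import numpy as np
    from synapse_net.inference.util import get_prediction_torch_em, parse_tiling

    volume = np.random.rand(48, 384, 384).astype("float32")  # [z, y, x]
    tiling = parse_tiling(tile_shape=(32, 256, 256), halo=(8, 64, 64))
    pred = get_prediction_torch_em(
        volume, tiling, model_path="checkpoints/vesicle_model"  # placeholder checkpoint folder
    )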
def inference_helper(
    input_path: str,
    output_root: str,
    segmentation_function: callable,
    data_ext: str = ".mrc",
    extra_input_path: Optional[str] = None,
    extra_input_ext: str = ".tif",
    mask_input_path: Optional[str] = None,
    mask_input_ext: str = ".tif",
    force: bool = False,
    output_key: Optional[str] = None,
    model_resolution: Optional[Tuple[float, float, float]] = None,
    scale: Optional[Tuple[float, float, float]] = None,
) -> None:
Helper function to run segmentation for mrc files.
Arguments:
- input_path: The path to the input data. Can either be a folder, in which case all mrc files below the folder will be segmented, or a single mrc file, in which case only this file will be segmented.
- output_root: The path to the output directory where the segmentation results will be saved.
- segmentation_function: The function performing the segmentation. This function must take the input_volume as its first argument, accept 'mask' and 'scale' keyword arguments, and return only the segmentation. If you want to pass additional arguments to this function then use 'functools.partial', as in the sketch below this list.
- data_ext: File extension for the image data. By default '.mrc' is used.
- extra_input_path: Filepath to extra inputs that need to be concatenated to the raw data loaded from mrc. This enables cristae segmentation with an extra mito channel.
- extra_input_ext: File extension for the extra inputs (by default .tif).
- mask_input_path: Filepath to mask(s) that will be used to restrict the segmentation.
- mask_input_ext: File extension for the mask inputs (by default .tif).
- force: Whether to rerun segmentation for output files that are already present.
- output_key: Output key for the prediction. If given, the result is written to an hdf5 file under this key; otherwise it is written to a tif file.
- model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction. If given, the scaling factor will automatically be determined based on the voxel_size of the input data.
- scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'.
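For example, a sketch with a toy thresholding function standing in for a real segmentation function; inference_helper calls it with 'mask' and 'scale' keyword arguments (see the source above), and 'functools.partial' binds any extra parameters. The paths, threshold, and voxel size are placeholders:

    from functools import partial
    from synapse_net.inference.util import inference_helper

    def segment_threshold(input_volume, threshold, mask=None, scale=None):
        # Toy stand-in for a model-based segmentation function.
        seg = (input_volume > threshold).astype("uint32")
        if mask is not None:
            seg[~(mask.astype(bool))] = 0
        return seg

    inference_helper(
        input_path="/data/tomograms",        # placeholder folder with .mrc files
        output_root="/data/segmentations",   # placeholder output folder
        segmentation_function=partial(segment_threshold, threshold=0.5),
        model_resolution=(1.0, 1.0, 1.0),    # placeholder voxel size in nm
    )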
def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]:
Determine the tile shape and halo depending on the available VRAM.
Arguments:
- is_2d: Whether to return tiling settings for 2d inference.
Returns:
The default tiling settings for the available computational resources.
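The returned dictionary always has the same structure, for example (values taken from the 40 GB VRAM case in the source above):

    from synapse_net.inference.util import get_default_tiling

    tiling = get_default_tiling()
    # On a GPU with 40 GB VRAM this returns:
    # {"tile": {"x": 512, "y": 512, "z": 64}, "halo": {"x": 64, "y": 64, "z": 16}}
    print(tiling["tile"], tiling["halo"])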
def parse_tiling(
    tile_shape: Tuple[int, int, int],
    halo: Tuple[int, int, int],
    is_2d: bool = False,
) -> Dict[str, Dict[str, int]]:
Helper function to parse tiling parameter input from the command line.
Arguments:
- tile_shape: The tile shape. If None the default tile shape is used.
- halo: The halo. If None the default halo is used.
- is_2d: Whether to return tiling for a 2d model.
Returns:
The tiling specification.
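A sketch of how this might be wired to argparse options; the option names here are illustrative, not a documented CLI:

    import argparse
    from synapse_net.inference.util import parse_tiling

    parser = argparse.ArgumentParser()
    parser.add_argument("--tile_shape", type=int, nargs=3, default=None)
    parser.add_argument("--halo", type=int, nargs=3, default=None)
    args = parser.parse_args(["--tile_shape", "32", "256", "256", "--halo", "8", "64", "64"])

    tiling = parse_tiling(args.tile_shape, args.halo)
    # tiling == {"tile": {"z": 32, "y": 256, "x": 256}, "halo": {"z": 8, "y": 64, "x": 64}}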
def apply_size_filter(
    segmentation: np.ndarray,
    min_size: int,
    verbose: bool = False,
    block_shape: Tuple[int, int, int] = (128, 256, 256),
) -> np.ndarray:
Apply size filter to the segmentation to remove small objects.
Arguments:
- segmentation: The segmentation.
- min_size: The minimal object size in pixels.
- verbose: Whether to print runtimes.
- block_shape: Block shape for parallelizing the operations.
Returns:
The size filtered segmentation.
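A small, self-contained sketch: the second object has only 10 voxels and is removed by the filter (block_shape is shrunk to match the toy volume):

    import numpy as np
    from synapse_net.inference.util import apply_size_filter

    seg = np.zeros((32, 64, 64), dtype="uint32")
    seg[:, :32, :32] = 1       # large object, kept
    seg[0, 40, :10] = 2        # 10-voxel object, removed for min_size=100
    filtered = apply_size_filter(seg, min_size=100, block_shape=(32, 64, 64))
    assert 2 not in filtered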
def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
Get the torch device.
If no device is passed, the default device for your system is used. Otherwise it is checked whether the device you passed is supported.
Arguments:
- device: The input device.
Returns:
The device.
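For example (a minimal sketch; the chosen device depends on the machine):

    import torch
    from synapse_net.inference.util import get_device

    device = get_device()          # "cuda", "mps", or "cpu", depending on availability
    model = torch.nn.Conv3d(1, 8, kernel_size=3).to(device)
    get_device("cuda")             # raises a RuntimeError if no CUDA device is available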