micro_sam.evaluation.benchmark_datasets
import os
import time
from glob import glob
from tqdm import tqdm
from natsort import natsorted
from typing import Union, Optional, List, Literal

import numpy as np
import pandas as pd
import imageio.v3 as imageio
from skimage.measure import label as connected_components

from nifty.tools import blocking

import torch

from torch_em.data import datasets

from micro_sam import util

from . import run_evaluation
from ..training.training import _filter_warnings
from .inference import run_inference_with_iterative_prompting
from .evaluation import run_evaluation_for_iterative_prompting
from .multi_dimensional_segmentation import segment_slices_from_ground_truth
from ..automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter


LM_2D_DATASETS = [
    "livecell", "deepbacs", "tissuenet", "neurips_cellseg", "dynamicnuclearnet",
    "hpa", "covid_if", "pannuke", "lizard", "orgasegment", "omnipose", "dic_hepg2",
]

LM_3D_DATASETS = [
    "plantseg_root", "plantseg_ovules", "gonuclear", "mouse_embryo", "embedseg_data", "cellseg_3d",
]

EM_2D_DATASETS = ["mitolab_tem"]

EM_3D_DATASETS = [
    "mitoem_rat", "mitoem_human", "platynereis_nuclei", "lucchi", "mitolab_3d", "nuc_mm_mouse",
    "nuc_mm_zebrafish", "uro_cell", "sponge_em", "platynereis_cilia", "vnc", "asem_mito",
]

DATASET_RETURNS_FOLDER = {
    "deepbacs": "*.tif"
}

DATASET_CONTAINER_KEYS = {
    "lucchi": ["raw", "labels"],
}


def _download_benchmark_datasets(path, dataset_choice):
    """Ensures that the selected datasets have been downloaded.

    Args:
        path: The path to the directory where the supported datasets will be downloaded
            for benchmarking Segment Anything models.
        dataset_choice: The choice of dataset; expects the lower-case name of the dataset.

    Returns:
        The list of chosen dataset(s).
    """
    available_datasets = {
        # Light Microscopy datasets
        "livecell": lambda: datasets.livecell.get_livecell_data(
            path=os.path.join(path, "livecell"), split="test", download=True,
        ),
        "deepbacs": lambda: datasets.deepbacs.get_deepbacs_data(
            path=os.path.join(path, "deepbacs"), bac_type="mixed", download=True,
        ),
        "tissuenet": lambda: datasets.tissuenet.get_tissuenet_data(
            path=os.path.join(path, "tissuenet"), split="test", download=True,
        ),
        "neurips_cellseg": lambda: datasets.neurips_cell_seg.get_neurips_cellseg_data(
            root=os.path.join(path, "neurips_cellseg"), split="test", download=True,
        ),
        "plantseg_root": lambda: datasets.plantseg.get_plantseg_data(
            path=os.path.join(path, "plantseg"), download=True, name="root",
        ),
        "plantseg_ovules": lambda: datasets.plantseg.get_plantseg_data(
            path=os.path.join(path, "plantseg"), download=True, name="ovules",
        ),
        "covid_if": lambda: datasets.covid_if.get_covid_if_data(
            path=os.path.join(path, "covid_if"), download=True,
        ),
        "hpa": lambda: datasets.hpa.get_hpa_segmentation_data(
            path=os.path.join(path, "hpa"), download=True,
        ),
        "dynamicnuclearnet": lambda: datasets.dynamicnuclearnet.get_dynamicnuclearnet_data(
            path=os.path.join(path, "dynamicnuclearnet"), split="test", download=True,
        ),
        "pannuke": lambda: datasets.pannuke.get_pannuke_data(
            path=os.path.join(path, "pannuke"), download=True, folds=["fold_1", "fold_2", "fold_3"],
        ),
        "lizard": lambda: datasets.lizard.get_lizard_data(
            path=os.path.join(path, "lizard"), download=True,
        ),
        "orgasegment": lambda: datasets.orgasegment.get_orgasegment_data(
            path=os.path.join(path, "orgasegment"), split="eval", download=True,
        ),
        "omnipose": lambda: datasets.omnipose.get_omnipose_data(
            path=os.path.join(path, "omnipose"), download=True,
        ),
        "gonuclear": lambda: datasets.gonuclear.get_gonuclear_data(
            path=os.path.join(path, "gonuclear"), download=True,
        ),
        "mouse_embryo": lambda: datasets.mouse_embryo.get_mouse_embryo_data(
            path=os.path.join(path, "mouse_embryo"), download=True,
        ),
        "embedseg_data": lambda: [
            datasets.embedseg_data.get_embedseg_data(path=os.path.join(path, "embedseg_data"), download=True, name=name)
            for name in datasets.embedseg_data.URLS.keys()
        ],
        "cellseg_3d": lambda: datasets.cellseg_3d.get_cellseg_3d_data(
            path=os.path.join(path, "cellseg_3d"), download=True,
        ),
        "dic_hepg2": lambda: datasets.dic_hepg2.get_dic_hepg2_data(
            path=os.path.join(path, "dic_hepg2"), download=True,
        ),

        # Electron Microscopy datasets
        "mitoem_rat": lambda: datasets.mitoem.get_mitoem_data(
            path=os.path.join(path, "mitoem"), samples="rat", split="test", download=True,
        ),
        "mitoem_human": lambda: datasets.mitoem.get_mitoem_data(
            path=os.path.join(path, "mitoem"), samples="human", split="test", download=True,
        ),
        "platynereis_nuclei": lambda: datasets.platynereis.get_platy_data(
            path=os.path.join(path, "platynereis"), name="nuclei", download=True,
        ),
        "platynereis_cilia": lambda: datasets.platynereis.get_platy_data(
            path=os.path.join(path, "platynereis"), name="cilia", download=True,
        ),
        "lucchi": lambda: datasets.lucchi.get_lucchi_data(
            path=os.path.join(path, "lucchi"), split="test", download=True,
        ),
        "mitolab_3d": lambda: [
            datasets.cem.get_benchmark_data(
                path=os.path.join(path, "mitolab"), dataset_id=dataset_id, download=True,
            ) for dataset_id in range(1, 7)
        ],
        "mitolab_tem": lambda: datasets.cem.get_benchmark_data(
            path=os.path.join(path, "mitolab"), dataset_id=7, download=True
        ),
        "nuc_mm_mouse": lambda: datasets.nuc_mm.get_nuc_mm_data(
            path=os.path.join(path, "nuc_mm"), sample="mouse", download=True,
        ),
        "nuc_mm_zebrafish": lambda: datasets.nuc_mm.get_nuc_mm_data(
            path=os.path.join(path, "nuc_mm"), sample="zebrafish", download=True,
        ),
        "uro_cell": lambda: datasets.uro_cell.get_uro_cell_data(
            path=os.path.join(path, "uro_cell"), download=True,
        ),
        "sponge_em": lambda: datasets.sponge_em.get_sponge_em_data(
            path=os.path.join(path, "sponge_em"), download=True,
        ),
        "vnc": lambda: datasets.vnc.get_vnc_data(
            path=os.path.join(path, "vnc"), download=True,
        ),
        "asem_mito": lambda: datasets.asem.get_asem_data(
            path=os.path.join(path, "asem"), volume_ids=datasets.asem.ORGANELLES["mito"], download=True,
        )
    }

    if dataset_choice is None:
        dataset_choice = available_datasets.keys()
    else:
        if not isinstance(dataset_choice, list):
            dataset_choice = [dataset_choice]

    for choice in dataset_choice:
        if choice in available_datasets:
            available_datasets[choice]()
        else:
            raise ValueError(f"'{choice}' is not a supported choice of dataset.")

    return dataset_choice


def _extract_slices_from_dataset(path, dataset_choice, crops_per_input=10):
    """Extracts crops of the desired shapes for performing evaluation in both 2d and 3d using `micro-sam`.

    Args:
        path: The path to the directory where the supported datasets have been downloaded
            for benchmarking Segment Anything models.
        dataset_choice: The name of the dataset to extract crops from.
        crops_per_input: The maximum number of crops to extract per input.

    Returns:
        The number of dimensions of the extracted inputs.
    """
    ndim = 2 if dataset_choice in [*LM_2D_DATASETS, *EM_2D_DATASETS] else 3
    tile_shape = (512, 512) if ndim == 2 else (32, 512, 512)

    # For 3d inputs, we extract both 2d and 3d crops.
    extract_2d_crops_from_volumes = (ndim == 3)

    available_datasets = {
        # Light Microscopy datasets
        "livecell": lambda: datasets.livecell.get_livecell_paths(path=path, split="test"),
        "deepbacs": lambda: datasets.deepbacs.get_deepbacs_paths(path=path, split="test", bac_type="mixed"),
        "tissuenet": lambda: datasets.tissuenet.get_tissuenet_paths(path=path, split="test"),
        "neurips_cellseg": lambda: datasets.neurips_cell_seg.get_neurips_cellseg_paths(root=path, split="test"),
        "plantseg_root": lambda: datasets.plantseg.get_plantseg_paths(path=path, name="root", split="test"),
        "plantseg_ovules": lambda: datasets.plantseg.get_plantseg_paths(path=path, name="ovules", split="test"),
        "covid_if": lambda: datasets.covid_if.get_covid_if_paths(path=path),
        "hpa": lambda: datasets.hpa.get_hpa_segmentation_paths(path=path, split="test"),
        "dynamicnuclearnet": lambda: datasets.dynamicnuclearnet.get_dynamicnuclearnet_paths(path=path, split="test"),
        "pannuke": lambda: datasets.pannuke.get_pannuke_paths(path=path),
        "lizard": lambda: datasets.lizard.get_lizard_paths(path=path),
        "orgasegment": lambda: datasets.orgasegment.get_orgasegment_paths(path=path, split="eval"),
        "omnipose": lambda: datasets.omnipose.get_omnipose_paths(path=path, split="test"),
        "gonuclear": lambda: datasets.gonuclear.get_gonuclear_paths(path=path),
        "mouse_embryo": lambda: datasets.mouse_embryo.get_mouse_embryo_paths(path=path, name="nuclei", split="val"),
        "embedseg_data": lambda: datasets.embedseg_data.get_embedseg_paths(
            path=path, name=list(datasets.embedseg_data.URLS.keys())[0], split="test"
        ),
        "cellseg_3d": lambda: datasets.cellseg_3d.get_cellseg_3d_paths(path=path),
        "dic_hepg2": lambda: datasets.dic_hepg2.get_dic_hepg2_paths(path=path, split="test"),

        # Electron Microscopy datasets
        "mitoem_rat": lambda: datasets.mitoem.get_mitoem_paths(path=path, splits="test", samples="rat"),
        "mitoem_human": lambda: datasets.mitoem.get_mitoem_paths(path=path, splits="test", samples="human"),
        "platynereis_nuclei": lambda: datasets.platynereis.get_platynereis_paths(path, sample_ids=None, name="nuclei"),
        "platynereis_cilia": lambda: datasets.platynereis.get_platynereis_paths(path, sample_ids=None, name="cilia"),
        "lucchi": lambda: datasets.lucchi.get_lucchi_paths(path=path, split="test"),
        "mitolab_3d": lambda: (
            [rpath for i in range(1, 7) for rpath in datasets.cem.get_benchmark_paths(path=path, dataset_id=i)[0]],
            [lpath for i in range(1, 7) for lpath in datasets.cem.get_benchmark_paths(path=path, dataset_id=i)[1]]
        ),
        "mitolab_tem": lambda: datasets.cem.get_benchmark_paths(path=path, dataset_id=7),
        "nuc_mm_mouse": lambda: datasets.nuc_mm.get_nuc_mm_paths(path=path, sample="mouse", split="val"),
        "nuc_mm_zebrafish": lambda: datasets.nuc_mm.get_nuc_mm_paths(path=path, sample="zebrafish", split="val"),
        "uro_cell": lambda: datasets.uro_cell.get_uro_cell_paths(path=path, target="mito"),
        "sponge_em": lambda: datasets.sponge_em.get_sponge_em_paths(path=path, sample_ids=None),
        "vnc": lambda: datasets.vnc.get_vnc_mito_paths(path=path),
        "asem_mito": lambda: datasets.asem.get_asem_paths(path=path, volume_ids=datasets.asem.ORGANELLES["mito"])
    }

    if ndim == 2:
        image_paths, gt_paths = available_datasets[dataset_choice]()

        if dataset_choice in DATASET_RETURNS_FOLDER:
            image_paths = glob(os.path.join(image_paths, DATASET_RETURNS_FOLDER[dataset_choice]))
            gt_paths = glob(os.path.join(gt_paths, DATASET_RETURNS_FOLDER[dataset_choice]))

        image_paths, gt_paths = natsorted(image_paths), natsorted(gt_paths)
        assert len(image_paths) == len(gt_paths)

        paths_set = zip(image_paths, gt_paths)

    else:
        image_paths = available_datasets[dataset_choice]()
        if isinstance(image_paths, str):
            paths_set = [image_paths]
        else:
            paths_set = natsorted(image_paths)

    # Directories where we store the extracted ROIs.
    save_image_dir = [os.path.join(path, f"roi_{ndim}d", "inputs")]
    save_gt_dir = [os.path.join(path, f"roi_{ndim}d", "labels")]
    if extract_2d_crops_from_volumes:
        save_image_dir.append(os.path.join(path, "roi_2d", "inputs"))
        save_gt_dir.append(os.path.join(path, "roi_2d", "labels"))

    _dir_exists = [
        os.path.exists(idir) and os.path.exists(gdir) for idir, gdir in zip(save_image_dir, save_gt_dir)
    ]
    if all(_dir_exists):
        return ndim

    [os.makedirs(idir, exist_ok=True) for idir in save_image_dir]
    [os.makedirs(gdir, exist_ok=True) for gdir in save_gt_dir]

    # Logic to extract relevant patches for inference.
    image_counter = 1
    for per_paths in tqdm(paths_set, desc=f"Extracting patches for {dataset_choice}"):
        if ndim == 2:
            image_path, gt_path = per_paths
            image, gt = util.load_image_data(image_path), util.load_image_data(gt_path)
        else:
            image_path = per_paths
            image = util.load_image_data(image_path, DATASET_CONTAINER_KEYS[dataset_choice][0])
            gt = util.load_image_data(image_path, DATASET_CONTAINER_KEYS[dataset_choice][1])

        skip_smaller_shape = (np.array(image.shape) >= np.array(tile_shape)).all()

        # Ensure that the ground truth has instance labels.
        gt = connected_components(gt)

        if len(np.unique(gt)) == 1:  # There could be labels which do not have any annotated foreground.
            continue

        # Let's extract and save all the crops.
        # NOTE: The first round of extraction is always to match the desired input dimensions.
        image_crops, gt_crops = _get_crops_for_input(image, gt, ndim, tile_shape, skip_smaller_shape, crops_per_input)
        image_counter = _save_image_label_crops(
            image_crops, gt_crops, dataset_choice, ndim, image_counter, save_image_dir[0], save_gt_dir[0]
        )

        # NOTE: The next round of extraction is to get 2d crops from 3d inputs.
        if extract_2d_crops_from_volumes:
            curr_tile_shape = tile_shape[-2:]  # NOTE: We expect a 2d tile shape for this stage.

            curr_image_crops, curr_gt_crops = [], []
            for per_z_im, per_z_gt in zip(image, gt):
                curr_skip_smaller_shape = (np.array(per_z_im.shape) >= np.array(curr_tile_shape)).all()

                image_crops, gt_crops = _get_crops_for_input(
                    image=per_z_im, gt=per_z_gt, ndim=2,
                    tile_shape=curr_tile_shape,
                    skip_smaller_shape=curr_skip_smaller_shape,
                    crops_per_input=crops_per_input,
                )
                curr_image_crops.extend(image_crops)
                curr_gt_crops.extend(gt_crops)

            image_counter = _save_image_label_crops(
                curr_image_crops, curr_gt_crops, dataset_choice, 2, image_counter, save_image_dir[1], save_gt_dir[1]
            )

    return ndim


def _get_crops_for_input(image, gt, ndim, tile_shape, skip_smaller_shape, crops_per_input):
    tiling = blocking([0] * ndim, gt.shape, tile_shape)
    n_tiles = tiling.numberOfBlocks
    tiles = [tiling.getBlock(tile_id) for tile_id in range(n_tiles)]
    crop_boxes = [
        tuple(slice(beg, end) for beg, end in zip(tile.begin, tile.end)) for tile in tiles
    ]
    n_ids = list(range(len(crop_boxes)))
    n_instances = [len(np.unique(gt[crop])) for crop in crop_boxes]

    # Extract the desired number of patches with the highest number of instances.
    image_crops, gt_crops = [], []
    for i, (per_n_instance, per_id) in enumerate(sorted(zip(n_instances, n_ids), reverse=True), start=1):
        crop_box = crop_boxes[per_id]
        crop_image, crop_gt = image[crop_box], gt[crop_box]
        # NOTE: We avoid using crops which do not match the desired tile shape.
        if skip_smaller_shape and crop_image.shape != tile_shape:
            continue

        # NOTE: There could be a case where some later patches are invalid.
        if per_n_instance == 1:
            break

        image_crops.append(crop_image)
        gt_crops.append(crop_gt)

        # NOTE: Once the desired number of patches has been extracted, we stop sampling.
        if len(image_crops) > 0 and i >= crops_per_input:
            break

    return image_crops, gt_crops


def _save_image_label_crops(image_crops, gt_crops, dataset_choice, ndim, image_counter, save_image_dir, save_gt_dir):
    for image_crop, gt_crop in tqdm(
        zip(image_crops, gt_crops), total=len(image_crops), desc=f"Saving {ndim}d crops for {dataset_choice}"
    ):
        fname = f"{dataset_choice}_{image_counter:05}.tif"
        assert image_crop.shape == gt_crop.shape
        imageio.imwrite(os.path.join(save_image_dir, fname), image_crop, compression="zlib")
        imageio.imwrite(os.path.join(save_gt_dir, fname), gt_crop, compression="zlib")
        image_counter += 1

    return image_counter


def _get_image_label_paths(path, ndim):
    image_paths = natsorted(glob(os.path.join(path, f"roi_{ndim}d", "inputs", "*")))
    gt_paths = natsorted(glob(os.path.join(path, f"roi_{ndim}d", "labels", "*")))
    return image_paths, gt_paths


def _run_automatic_segmentation_per_dataset(
    image_paths: List[Union[os.PathLike, str]],
    gt_paths: List[Union[os.PathLike, str]],
    model_type: str,
    output_folder: Union[os.PathLike, str],
    ndim: Optional[int] = None,
    device: Optional[Union[torch.device, str]] = None,
    checkpoint_path: Optional[Union[os.PathLike, str]] = None,
    run_amg: bool = False,
    **auto_seg_kwargs
):
    """Runs automatic segmentation for multiple input files at once.
    Stores the evaluated (quantitative) automatic segmentation results.

    Args:
        image_paths: List of filepaths for the input image data.
        gt_paths: List of filepaths for the corresponding label data.
        model_type: The choice of image encoder for the Segment Anything model.
        output_folder: Filepath to the folder where we store all the results.
        ndim: The number of input dimensions.
        device: The torch device.
        checkpoint_path: The filepath where the model checkpoints are stored.
        run_amg: Whether to run automatic segmentation in AMG mode.
        auto_seg_kwargs: Additional arguments for automatic segmentation parameters.
    """
    experiment_name = "AMG" if run_amg else "AIS"
    fname = f"{experiment_name.lower()}_{ndim}d"

    result_path = os.path.join(output_folder, "results", f"{fname}.csv")
    prediction_dir = os.path.join(output_folder, fname, "inference")
    if os.path.exists(prediction_dir):
        return

    os.makedirs(prediction_dir, exist_ok=True)

    # Get the predictor (and the additional instance segmentation decoder, if available).
    predictor, segmenter = get_predictor_and_segmenter(
        model_type=model_type, checkpoint=checkpoint_path, device=device, amg=run_amg, is_tiled=False,
    )

    for image_path in tqdm(image_paths, desc=f"Run {experiment_name} in {ndim}d"):
        output_path = os.path.join(prediction_dir, os.path.basename(image_path))
        if os.path.exists(output_path):
            continue

        # Run automatic segmentation (AMG or AIS).
        automatic_instance_segmentation(
            predictor=predictor,
            segmenter=segmenter,
            input_path=image_path,
            output_path=output_path,
            ndim=ndim,
            verbose=False,
            **auto_seg_kwargs
        )

    prediction_paths = natsorted(glob(os.path.join(prediction_dir, "*")))
    run_evaluation(gt_paths=gt_paths, prediction_paths=prediction_paths, save_path=result_path)


def _run_interactive_segmentation_per_dataset(
    image_paths: List[Union[os.PathLike, str]],
    gt_paths: List[Union[os.PathLike, str]],
    output_folder: Union[os.PathLike, str],
    model_type: str,
    prompt_choice: Literal["box", "points"],
    device: Optional[Union[torch.device, str]] = None,
    ndim: Optional[int] = None,
    checkpoint_path: Optional[Union[os.PathLike, str]] = None,
):
    """Runs interactive segmentation for multiple input files at once.
    Stores the evaluated interactive segmentation results.

    Args:
        image_paths: List of filepaths for the input image data.
        gt_paths: List of filepaths for the corresponding label data.
        output_folder: Filepath to the folder where we store all the results.
        model_type: The choice of model type for Segment Anything.
        prompt_choice: The choice of initial prompts to begin the interactive segmentation.
        device: The torch device.
        ndim: The number of input dimensions.
        checkpoint_path: The filepath for stored checkpoints.
    """
    if ndim == 2:
        # Get the Segment Anything predictor.
        predictor = util.get_sam_model(model_type=model_type, device=device, checkpoint_path=checkpoint_path)

        # Run interactive instance segmentation
        # (starting with box or point prompts, followed by iterative prompt-based correction).
        run_inference_with_iterative_prompting(
            predictor=predictor,
            image_paths=image_paths,
            gt_paths=gt_paths,
            embedding_dir=None,  # We set this to None to compute embeddings on-the-fly.
            prediction_dir=os.path.join(output_folder, "interactive_segmentation_2d", f"start_with_{prompt_choice}"),
            start_with_box_prompt=(prompt_choice == "box"),
            # TODO: add a parameter to deform the box prompts (to simulate prompts as drawn in practice).
        )

        # Evaluate the interactive instance segmentation.
        run_evaluation_for_iterative_prompting(
            gt_paths=gt_paths,
            prediction_root=os.path.join(output_folder, "interactive_segmentation_2d", f"start_with_{prompt_choice}"),
            experiment_folder=output_folder,
            start_with_box_prompt=(prompt_choice == "box"),
        )

    else:
        save_path = os.path.join(output_folder, "results", f"interactive_segmentation_3d_with_{prompt_choice}.csv")
        if os.path.exists(save_path):
            print(
                f"Results for 3d interactive segmentation with '{prompt_choice}' are already stored at '{save_path}'."
            )
            return

        results = []
        for image_path, gt_path in tqdm(
            zip(image_paths, gt_paths), total=len(image_paths),
            desc=f"Run interactive segmentation in 3d with '{prompt_choice}'"
        ):
            prediction_dir = os.path.join(output_folder, "interactive_segmentation_3d", f"{prompt_choice}")
            os.makedirs(prediction_dir, exist_ok=True)

            prediction_path = os.path.join(prediction_dir, os.path.basename(image_path))
            if os.path.exists(prediction_path):
                continue

            per_vol_result = segment_slices_from_ground_truth(
                volume=imageio.imread(image_path),
                ground_truth=imageio.imread(gt_path),
                model_type=model_type,
                checkpoint_path=checkpoint_path,
                save_path=prediction_path,
                device=device,
                interactive_seg_mode=prompt_choice,
                min_size=10,
            )
            results.append(per_vol_result)

        results = pd.concat(results)
        results = results.groupby(results.index).mean()
        results.to_csv(save_path)


def _run_benchmark_evaluation_series(
    image_paths, gt_paths, model_type, output_folder, ndim, device, checkpoint_path, run_amg,
):
    seg_kwargs = {
        "image_paths": image_paths,
        "gt_paths": gt_paths,
        "output_folder": output_folder,
        "ndim": ndim,
        "model_type": model_type,
        "device": device,
        "checkpoint_path": checkpoint_path,
    }

    # Perform:
    # a. Automatic segmentation (supported in both 2d and 3d, wherever relevant).
    #    The automatic segmentation steps below are configured such that AIS has priority (if a decoder is found),
    #    and AMG is run otherwise. Next, we check whether the user also expects to run AMG (after the AIS run).

    # i. Run the automatic segmentation method supported by the SAM model (AMG or AIS).
    _run_automatic_segmentation_per_dataset(run_amg=False, **seg_kwargs)

    # ii. Run automatic mask generation (AMG) (in case the first run was AIS).
    _run_automatic_segmentation_per_dataset(run_amg=run_amg, **seg_kwargs)

    # b. Run interactive segmentation (supported in both 2d and 3d, wherever relevant).
    _run_interactive_segmentation_per_dataset(prompt_choice="box", **seg_kwargs)
    _run_interactive_segmentation_per_dataset(prompt_choice="points", **seg_kwargs)


def _clear_cached_items(retain, path, output_folder):
    import shutil
    from pathlib import Path

    REMOVE_LIST = ["data", "crops", "auto", "int"]
    if retain is None:
        remove_list = REMOVE_LIST
    else:
        assert isinstance(retain, list)
        remove_list = set(REMOVE_LIST) - set(retain)

    paths = []
    # Stage 1: Remove inputs.
    if "data" in remove_list or "crops" in remove_list:
        all_paths = glob(os.path.join(path, "*"))

        # In case we want to remove both data and crops, we remove the data folder entirely.
        if "data" in remove_list and "crops" in remove_list:
            paths.extend(all_paths)
        else:
            # Otherwise, we only remove either the data or the crops.
            for curr_path in all_paths:
                if os.path.basename(curr_path).startswith("roi") and "crops" in remove_list:
                    paths.append(curr_path)
                elif "data" in remove_list:
                    paths.append(curr_path)

    # Stage 2: Remove predictions.
    if "auto" in remove_list:
        paths.extend(glob(os.path.join(output_folder, "amg_*")))
        paths.extend(glob(os.path.join(output_folder, "ais_*")))

    if "int" in remove_list:
        paths.extend(glob(os.path.join(output_folder, "interactive_segmentation_*")))

    [shutil.rmtree(_path) if Path(_path).is_dir() else os.remove(_path) for _path in paths]


def run_benchmark_evaluations(
    input_folder: Union[os.PathLike, str],
    dataset_choice: str,
    model_type: str = util._DEFAULT_MODEL,
    output_folder: Optional[Union[str, os.PathLike]] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    run_amg: bool = False,
    retain: Optional[List[str]] = None,
    ignore_warnings: bool = False,
):
    """Run evaluation for benchmarking Segment Anything models on microscopy datasets.

    Args:
        input_folder: The path to the directory where all inputs will be stored and preprocessed.
        dataset_choice: The choice of dataset(s) to benchmark on.
        model_type: The model choice for SAM.
        output_folder: The path to the directory where all outputs will be stored.
        checkpoint_path: The path to the model checkpoint.
        run_amg: Whether to run automatic segmentation in AMG mode.
        retain: Whether to retain certain parts of the benchmark runs.
            By default, everything besides the quantitative results is removed.
            There is the choice to retain 'data', 'crops', 'auto', or 'int'.
        ignore_warnings: Whether to ignore warnings.
    """
    start = time.time()

    with _filter_warnings(ignore_warnings):
        device = util._get_default_device()

        # Ensure that all selected datasets have been downloaded.
        dataset_choice = _download_benchmark_datasets(path=input_folder, dataset_choice=dataset_choice)

        for choice in dataset_choice:
            curr_output_folder = os.path.join(output_folder, choice)
            result_dir = os.path.join(curr_output_folder, "results")
            if os.path.exists(result_dir):
                continue

            os.makedirs(result_dir, exist_ok=True)

            data_path = os.path.join(input_folder, choice)

            # Extract the desired evaluation set from the datasets:
            # a. for 2d datasets - 2d patches with the highest number of labels present
            #    (in case of volumetric data, choose 2d patches per slice).
            # b. for 3d datasets - 3d regions of interest with the highest number of labels present.
            ndim = _extract_slices_from_dataset(path=data_path, dataset_choice=choice, crops_per_input=10)

            # Run inference and evaluation scripts on the benchmark datasets.
            image_paths, gt_paths = _get_image_label_paths(path=data_path, ndim=ndim)
            _run_benchmark_evaluation_series(
                image_paths, gt_paths, model_type, curr_output_folder, ndim, device, checkpoint_path, run_amg
            )

            # Run inference and evaluation scripts on '2d' crops for volumetric datasets.
            if ndim == 3:
                image_paths, gt_paths = _get_image_label_paths(path=data_path, ndim=2)
                _run_benchmark_evaluation_series(
                    image_paths, gt_paths, model_type, curr_output_folder, 2, device, checkpoint_path, run_amg
                )

            _clear_cached_items(retain=retain, path=data_path, output_folder=curr_output_folder)

    diff = time.time() - start
    hours, rest = divmod(diff, 3600)
    minutes, seconds = divmod(rest, 60)
    print("Time taken for running benchmarks: ", f"{int(hours)}h {int(minutes)}m {seconds:.2f}s")


def main():
    """@private"""
    import argparse

    available_models = list(util.get_model_names())
    available_models = ", ".join(available_models)

    parser = argparse.ArgumentParser(
        description="Run evaluation for benchmarking Segment Anything models on microscopy datasets."
    )
    parser.add_argument(
        "-i", "--input_folder", type=str, required=True,
        help="The path to a directory where the microscopy datasets are / will be stored."
    )
    parser.add_argument(
        "-m", "--model_type", type=str, default=util._DEFAULT_MODEL,
        help=f"The Segment Anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint_path", type=str, default=None,
        help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "-d", "--dataset_choice", type=str, nargs='*', default=None,
        help="The choice(s) of dataset for evaluating SAM models. Multiple datasets can be specified."
    )
    parser.add_argument(
        "-o", "--output_folder", type=str, required=True,
        help="The path where the results for automatic and interactive instance segmentation will be stored as 'csv'."
    )
    parser.add_argument(
        "--amg", action="store_true",
        help="Whether to run automatic segmentation in AMG mode (i.e. the default auto-seg approach for SAM)."
    )
    parser.add_argument(
        "--retain", nargs="*", default=None,
        help="By default, everything besides the quantitative results required for running benchmarks is removed. "
             "In case you would like to retain parts of the benchmark evaluation for visualization / reproducibility, "
             "you should choose one or multiple of 'data', 'crops', 'auto', 'int', "
             "which retain the original inputs / extracted crops / predictions of automatic segmentation / "
             "predictions of interactive segmentation, respectively."
    )
    args = parser.parse_args()

    run_benchmark_evaluations(
        input_folder=args.input_folder,
        dataset_choice=args.dataset_choice,
        model_type=args.model_type,
        output_folder=args.output_folder,
        checkpoint_path=args.checkpoint_path,
        run_amg=args.amg,
        retain=args.retain,
        ignore_warnings=True,
    )


if __name__ == "__main__":
    main()
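For orientation, a minimal usage sketch of the public entry point; the folder paths below are placeholders, and "livecell" is one of the choices listed in LM_2D_DATASETS. Since the module guards main() behind __main__, the same run is available on the command line as python -m micro_sam.evaluation.benchmark_datasets -i /scratch/benchmark_data -o /scratch/benchmark_results -d livecell.

from micro_sam.evaluation.benchmark_datasets import run_benchmark_evaluations

# Download LIVECell, extract up to 10 label-dense 512x512 crops per image,
# then run automatic (AIS or AMG) and interactive (box and point) evaluation.
run_benchmark_evaluations(
    input_folder="/scratch/benchmark_data",      # placeholder: datasets and extracted crops are stored here
    dataset_choice="livecell",
    output_folder="/scratch/benchmark_results",  # placeholder: per-dataset 'results' csv files land here
    run_amg=False,        # AIS takes priority when an instance decoder is available
    retain=["auto"],      # optionally keep the automatic segmentation predictions
    ignore_warnings=True,
)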
LM_2D_DATASETS =
['livecell', 'deepbacs', 'tissuenet', 'neurips_cellseg', 'dynamicnuclearnet', 'hpa', 'covid_if', 'pannuke', 'lizard', 'orgasegment', 'omnipose', 'dic_hepg2']
LM_3D_DATASETS =
['plantseg_root', 'plantseg_ovules', 'gonuclear', 'mouse_embryo', 'embedseg_data', 'cellseg_3d']
EM_2D_DATASETS =
['mitolab_tem']
EM_3D_DATASETS =
['mitoem_rat', 'mitoem_human', 'platynereis_nuclei', 'lucchi', 'mitolab_3d', 'nuc_mm_mouse', 'nuc_mm_zebrafish', 'uro_cell', 'sponge_em', 'platynereis_cilia', 'vnc', 'asem_mito']
DATASET_RETURNS_FOLDER =
{'deepbacs': '*.tif'}
DATASET_CONTAINER_KEYS =
{'lucchi': ['raw', 'labels']}
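These groupings drive the crop extraction: membership in the 2d lists selects (512, 512) tiles, everything else is treated as volumetric with (32, 512, 512) tiles, and container datasets are read via the keys above. A small illustration mirroring the module's internal checks (the dataset value is just an example):

from micro_sam.evaluation.benchmark_datasets import (
    LM_2D_DATASETS, EM_2D_DATASETS, DATASET_CONTAINER_KEYS
)

dataset = "lucchi"  # example: a 3d EM dataset

# Mirrors the ndim / tile-shape selection in _extract_slices_from_dataset.
ndim = 2 if dataset in [*LM_2D_DATASETS, *EM_2D_DATASETS] else 3
tile_shape = (512, 512) if ndim == 2 else (32, 512, 512)

# Container datasets (currently only "lucchi") are loaded via internal dataset keys.
container_keys = DATASET_CONTAINER_KEYS.get(dataset)  # -> ["raw", "labels"]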
def run_benchmark_evaluations(
    input_folder: Union[os.PathLike, str],
    dataset_choice: str,
    model_type: str = 'vit_l',
    output_folder: Optional[Union[os.PathLike, str]] = None,
    checkpoint_path: Optional[Union[os.PathLike, str]] = None,
    run_amg: bool = False,
    retain: Optional[List[str]] = None,
    ignore_warnings: bool = False,
):
Run evaluation for benchmarking Segment Anything models on microscopy datasets.
Arguments:
- input_folder: The path to the directory where all inputs will be stored and preprocessed.
- dataset_choice: The choice of dataset(s) to benchmark on.
- model_type: The model choice for SAM.
- output_folder: The path to the directory where all outputs will be stored.
- checkpoint_path: The path to the model checkpoint.
- run_amg: Whether to run automatic segmentation in AMG mode.
- retain: Whether to retain certain parts of the benchmark runs. By default, everything besides the quantitative results is removed. There is the choice to retain 'data', 'crops', 'auto', or 'int'.
- ignore_warnings: Whether to ignore warnings.
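The quantitative results are written as csv files underneath <output_folder>/<dataset>/results (e.g. ais_2d.csv, amg_3d.csv, or interactive_segmentation_3d_with_box.csv). A short sketch, assuming the placeholder output folder from the usage example above, for gathering them with pandas:

import os
from glob import glob

import pandas as pd

output_folder = "/scratch/benchmark_results"  # placeholder: same value as passed to run_benchmark_evaluations

# Collect and print every result table written by the benchmark runs.
for csv_path in sorted(glob(os.path.join(output_folder, "*", "results", "*.csv"))):
    name = os.path.relpath(csv_path, output_folder)
    print(name)
    print(pd.read_csv(csv_path).to_string(index=False))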