micro_sam.evaluation.benchmark_datasets
import os
import time
from glob import glob
from tqdm import tqdm
from natsort import natsorted
from typing import Union, Optional, List, Literal

import numpy as np
import pandas as pd
import imageio.v3 as imageio
from skimage.measure import label as connected_components

from nifty.tools import blocking

import torch

from torch_em.data import datasets

from micro_sam import util

from . import run_evaluation
from ..training.training import _filter_warnings
from .inference import run_inference_with_iterative_prompting
from .evaluation import run_evaluation_for_iterative_prompting
from .multi_dimensional_segmentation import segment_slices_from_ground_truth
from ..automatic_segmentation import (
    automatic_instance_segmentation, get_predictor_and_segmenter, DEFAULT_SEGMENTATION_MODE_WITH_DECODER,
)


LM_2D_DATASETS = [
    # in-domain
    "livecell",  # cell segmentation in PhC (has a TEST-set)
    "deepbacs",  # bacteria segmentation in label-free microscopy (has a TEST-set)
    "tissuenet",  # cell segmentation in tissue microscopy images (has a TEST-set)
    "neurips_cellseg",  # cell segmentation in various modalities (has a TEST-set)
    "cellpose",  # cell segmentation in FM (has 'cyto2', which we can use for testing)
    "dynamicnuclearnet",  # nuclei segmentation in FM (has a TEST-set)
    "orgasegment",  # organoid segmentation in BF (has a TEST-set)
    "yeaz",  # yeast segmentation in BF (has a TEST-set)

    # out-of-domain
    "arvidsson",  # nuclei segmentation in HCS FM
    "bitdepth_nucseg",  # nuclei segmentation in FM
    "cellbindb",  # cell segmentation in various microscopy modalities
    "covid_if",  # cell segmentation in IF
    "deepseas",  # cell segmentation in PhC
    "hpa",  # cell segmentation in confocal
    "ifnuclei",  # nuclei segmentation in IFM
    "lizard",  # nuclei segmentation in H&E histopathology
    "organoidnet",  # organoid segmentation in BF
    "toiam",  # microbial cell segmentation in PhC
    "vicar",  # cell segmentation in label-free microscopy
]

LM_3D_DATASETS = [
    # in-domain
    "plantseg_root",  # cell segmentation in lightsheet (has a TEST-set)

    # out-of-domain
    "plantseg_ovules",  # cell segmentation in confocal
    "gonuclear",  # nuclei segmentation in FM
    "mouse_embryo",  # cell segmentation in lightsheet
    "cellseg_3d",  # nuclei segmentation in FM
]

EM_2D_DATASETS = ["mitolab_tem"]

EM_3D_DATASETS = [
    # out-of-domain
    "lucchi",  # mitochondria segmentation in vEM
    "mitolab_3d",  # mitochondria segmentation in various modalities
    "uro_cell",  # mitochondria segmentation (and other organelles) in FIB-SEM
    "sponge_em",  # microvilli segmentation (and other structures) in sponge chamber vEM
    "vnc",  # mitochondria segmentation in drosophila brain TEM
    "nuc_mm_mouse",  # nuclei segmentation in microCT
    "nuc_mm_zebrafish",  # nuclei segmentation in EM
    "platynereis_cilia",  # cilia segmentation (and other structures) in platynereis larvae vEM
    "asem_mito",  # mitochondria segmentation (and other organelles) in FIB-SEM
]

# Datasets whose path getters return a folder rather than a list of files;
# the glob pattern collects the individual image files.
DATASET_RETURNS_FOLDER = {
    "deepbacs": "*.tif",
    "mitolab_tem": "*.tiff",
}

# Keys for accessing the raw data and labels in container file formats (e.g. hdf5).
DATASET_CONTAINER_KEYS = {
    # 2d (LM)
    "tissuenet": ["raw/rgb", "labels/cell"],
    "covid_if": ["raw/serum_IgG/s0", "labels/cells/s0"],
    "dynamicnuclearnet": ["raw", "labels"],
    "hpa": [["raw/protein", "raw/microtubules", "raw/er"], "labels"],
    "lizard": ["image", "labels/segmentation"],

    # 3d (LM)
    "plantseg_root": ["raw", "label"],
    "plantseg_ovules": ["raw", "label"],
    "gonuclear": ["raw/nuclei", "labels/nuclei"],
    "mouse_embryo": ["raw", "label"],
    "cellseg_3d": [None, None],

    # 3d (EM)
    "lucchi": ["raw", "labels"],
    "uro_cell": ["raw", "labels/mito"],
    "mitolab_3d": [None, None],
    "sponge_em": ["volumes/raw", "volumes/labels/instances"],
    "vnc": ["raw", "labels/mitochondria"],
}


def _download_benchmark_datasets(path, dataset_choice):
    """Ensures that the chosen benchmark datasets have been downloaded.

    Args:
        path: The path to the directory where the supported datasets will be downloaded
            for benchmarking Segment Anything models.
        dataset_choice: The choice of dataset; expects the lower-case name of the dataset.

    Returns:
        List of the chosen dataset(s).
    """
    available_datasets = {
        # Light Microscopy datasets

        # 2d datasets: in-domain
        "livecell": lambda: datasets.livecell.get_livecell_data(
            path=os.path.join(path, "livecell"), download=True,
        ),
        "deepbacs": lambda: datasets.deepbacs.get_deepbacs_data(
            path=os.path.join(path, "deepbacs"), bac_type="mixed", download=True,
        ),
        "tissuenet": lambda: datasets.tissuenet.get_tissuenet_data(
            path=os.path.join(path, "tissuenet"), split="test", download=True,
        ),
        "neurips_cellseg": lambda: datasets.neurips_cell_seg.get_neurips_cellseg_data(
            root=os.path.join(path, "neurips_cellseg"), split="test", download=True,
        ),
        "cellpose": lambda: datasets.cellpose.get_cellpose_data(
            path=os.path.join(path, "cellpose"), split="train", choice="cyto2", download=True,
        ),
        "dynamicnuclearnet": lambda: datasets.dynamicnuclearnet.get_dynamicnuclearnet_data(
            path=os.path.join(path, "dynamicnuclearnet"), split="test", download=True,
        ),
        "orgasegment": lambda: datasets.orgasegment.get_orgasegment_data(
            path=os.path.join(path, "orgasegment"), split="eval", download=True,
        ),
        "yeaz": lambda: datasets.yeaz.get_yeaz_data(
            path=os.path.join(path, "yeaz"), choice="bf", download=True,
        ),

        # 2d datasets: out-of-domain
        "arvidsson": lambda: datasets.arvidsson.get_arvidsson_data(
            path=os.path.join(path, "arvidsson"), split="test", download=True,
        ),
        "bitdepth_nucseg": lambda: datasets.bitdepth_nucseg.get_bitdepth_nucseg_data(
            path=os.path.join(path, "bitdepth_nucseg"), download=True,
        ),
        "cellbindb": lambda: datasets.cellbindb.get_cellbindb_data(
            path=os.path.join(path, "cellbindb"), download=True,
        ),
        "covid_if": lambda: datasets.covid_if.get_covid_if_data(
            path=os.path.join(path, "covid_if"), download=True,
        ),
        "deepseas": lambda: datasets.deepseas.get_deepseas_data(
            path=os.path.join(path, "deepseas"), split="test",
        ),
        "hpa": lambda: datasets.hpa.get_hpa_segmentation_data(
            path=os.path.join(path, "hpa"), download=True,
        ),
        "ifnuclei": lambda: datasets.ifnuclei.get_ifnuclei_data(
            path=os.path.join(path, "ifnuclei"), download=True,
        ),
        "lizard": lambda: datasets.lizard.get_lizard_data(
            path=os.path.join(path, "lizard"), split="test", download=True,
        ),
        "organoidnet": lambda: datasets.organoidnet.get_organoidnet_data(
            path=os.path.join(path, "organoidnet"), split="Test", download=True,
        ),
        "toiam": lambda: datasets.toiam.get_toiam_data(
            path=os.path.join(path, "toiam"), download=True,
        ),
        "vicar": lambda: datasets.vicar.get_vicar_data(
            path=os.path.join(path, "vicar"), download=True,
        ),

        # 3d datasets: in-domain
        "plantseg_root": lambda: datasets.plantseg.get_plantseg_data(
            path=os.path.join(path, "plantseg_root"), split="test", download=True, name="root",
        ),

        # 3d datasets: out-of-domain
        "plantseg_ovules": lambda: datasets.plantseg.get_plantseg_data(
            path=os.path.join(path, "plantseg_ovules"), split="test", download=True, name="ovules",
        ),
        "gonuclear": lambda: datasets.gonuclear.get_gonuclear_data(
            path=os.path.join(path, "gonuclear"), download=True,
        ),
        "mouse_embryo": lambda: datasets.mouse_embryo.get_mouse_embryo_data(
            path=os.path.join(path, "mouse_embryo"), download=True,
        ),
        "cellseg_3d": lambda: datasets.cellseg_3d.get_cellseg_3d_data(
            path=os.path.join(path, "cellseg_3d"), download=True,
        ),

        # Electron Microscopy datasets

        # 2d datasets: out-of-domain
        "mitolab_tem": lambda: datasets.cem.get_benchmark_data(
            path=os.path.join(path, "mitolab"), dataset_id=7, download=True,
        ),

        # 3d datasets: out-of-domain
        "lucchi": lambda: datasets.lucchi.get_lucchi_data(
            path=os.path.join(path, "lucchi"), split="test", download=True,
        ),
        "mitolab_3d": lambda: [
            datasets.cem.get_benchmark_data(
                path=os.path.join(path, "mitolab"), dataset_id=dataset_id, download=True,
            ) for dataset_id in range(1, 7)
        ],
        "uro_cell": lambda: datasets.uro_cell.get_uro_cell_data(
            path=os.path.join(path, "uro_cell"), download=True,
        ),
        "vnc": lambda: datasets.vnc.get_vnc_data(
            path=os.path.join(path, "vnc"), download=True,
        ),
        "sponge_em": lambda: datasets.sponge_em.get_sponge_em_data(
            path=os.path.join(path, "sponge_em"), download=True,
        ),
        "nuc_mm_mouse": lambda: datasets.nuc_mm.get_nuc_mm_data(
            path=os.path.join(path, "nuc_mm"), sample="mouse", download=True,
        ),
        "nuc_mm_zebrafish": lambda: datasets.nuc_mm.get_nuc_mm_data(
            path=os.path.join(path, "nuc_mm"), sample="zebrafish", download=True,
        ),
        "asem_mito": lambda: datasets.asem.get_asem_data(
            path=os.path.join(path, "asem"), volume_ids=datasets.asem.ORGANELLES["mito"], download=True,
        ),
        "platynereis_cilia": lambda: datasets.platynereis.get_platynereis_data(
            path=os.path.join(path, "platynereis"), name="cilia", download=True,
        ),
    }

    if dataset_choice is None:
        dataset_choice = available_datasets.keys()
    else:
        if not isinstance(dataset_choice, list):
            dataset_choice = [dataset_choice]

    for choice in dataset_choice:
        if choice in available_datasets:
            available_datasets[choice]()
        else:
            raise ValueError(f"'{choice}' is not a supported choice of dataset.")

    return dataset_choice


def _extract_slices_from_dataset(path, dataset_choice, crops_per_input=10):
    """Extracts crops of the desired shapes for evaluation in both 2d and 3d using `micro-sam`.

    Args:
        path: The path to the directory where the supported datasets have been downloaded
            for benchmarking Segment Anything models.
        dataset_choice: The name of the dataset for which to extract crops.
        crops_per_input: The maximum number of crops to extract per input.

    Returns:
        The number of dimensions of the extracted data.
    """
273 """ 274 ndim = 2 if dataset_choice in [*LM_2D_DATASETS, *EM_2D_DATASETS] else 3 275 tile_shape = (512, 512) if ndim == 2 else (32, 512, 512) 276 277 # For 3d inputs, we extract both 2d and 3d crops. 278 extract_2d_crops_from_volumes = (ndim == 3) 279 280 available_datasets = { 281 # Light Microscopy datasets 282 283 # 2d: in-domain 284 "livecell": lambda: datasets.livecell.get_livecell_paths(path=path, split="test"), 285 "deepbacs": lambda: datasets.deepbacs.get_deepbacs_paths(path=path, split="test", bac_type="mixed"), 286 "tissuenet": lambda: datasets.tissuenet.get_tissuenet_paths(path=path, split="test"), 287 "neurips_cellseg": lambda: datasets.neurips_cell_seg.get_neurips_cellseg_paths(root=path, split="test"), 288 "cellpose": lambda: datasets.cellpose.get_cellpose_paths(path=path, split="train", choice="cyto2"), 289 "dynamicnuclearnet": lambda: datasets.dynamicnuclearnet.get_dynamicnuclearnet_paths(path=path, split="test"), 290 "orgasegment": lambda: datasets.orgasegment.get_orgasegment_paths(path=path, split="eval"), 291 "yeaz": lambda: datasets.yeaz.get_yeaz_paths(path=path, choice="bf", split="test"), 292 293 # 2d: out-of-domain 294 "arvidsson": lambda: datasets.arvidsson.get_arvidsson_paths(path=path, split="test"), 295 "bitdepth_nucseg": lambda: datasets.bitdepth_nucseg.get_bitdepth_nucseg_paths(path=path, magnification="20x"), 296 "cellbindb": lambda: datasets.cellbindb.get_cellbindb_paths( 297 path=path, data_choice=["10×Genomics_DAPI", "DAPI", "mIF"] 298 ), 299 "covid_if": lambda: datasets.covid_if.get_covid_if_paths(path=path), 300 "deepseas": lambda: datasets.deepseas.get_deepseas_paths(path=path, split="test"), 301 "hpa": lambda: datasets.hpa.get_hpa_segmentation_paths(path=path, split="val"), 302 "ifnuclei": lambda: datasets.ifnuclei.get_ifnuclei_paths(path=path), 303 "lizard": lambda: datasets.lizard.get_lizard_paths(path=path, split="test"), 304 "organoidnet": lambda: datasets.organoidnet.get_organoidnet_paths(path=path, split="Test"), 305 "toiam": lambda: datasets.toiam.get_toiam_paths(path=path), 306 "vicar": lambda: datasets.vicar.get_vicar_paths(path=path), 307 308 # 3d: in-domain 309 "plantseg_root": lambda: datasets.plantseg.get_plantseg_paths(path=path, name="root", split="test"), 310 311 # 3d: out-of-domain 312 "plantseg_ovules": lambda: datasets.plantseg.get_plantseg_paths(path=path, name="ovules", split="test"), 313 "gonuclear": lambda: datasets.gonuclear.get_gonuclear_paths(path=path), 314 "mouse_embryo": lambda: datasets.mouse_embryo.get_mouse_embryo_paths(path=path, name="nuclei", split="val"), 315 "cellseg_3d": lambda: datasets.cellseg_3d.get_cellseg_3d_paths(path=path), 316 317 # Electron Microscopy datasets 318 319 # 2d: out-of-domain 320 "mitolab_tem": lambda: datasets.cem.get_benchmark_paths( 321 path=os.path.join(os.path.dirname(path), "mitolab"), dataset_id=7 322 )[:2], 323 324 # 3d: out-of-domain"lucchi": lambda: datasets.lucchi.get_lucchi_paths(path=path, split="test"), 325 "platynereis_cilia": lambda: datasets.platynereis.get_platynereis_paths(path, sample_ids=None, name="cilia"), 326 "uro_cell": lambda: datasets.uro_cell.get_uro_cell_paths(path=path, target="mito"), 327 "vnc": lambda: datasets.vnc.get_vnc_mito_paths(path=path), 328 "sponge_em": lambda: datasets.sponge_em.get_sponge_em_paths(path=path, sample_ids=None), 329 "mitolab_3d": lambda: ( 330 [ 331 datasets.cem.get_benchmark_paths( 332 path=os.path.join(os.path.dirname(path), "mitolab"), dataset_id=i 333 )[0] for i in range(1, 7) 334 ], 335 [ 336 datasets.cem.get_benchmark_paths( 337 
path=os.path.join(os.path.dirname(path), "mitolab"), dataset_id=i 338 )[1] for i in range(1, 7) 339 ] 340 ), 341 "nuc_mm_mouse": lambda: datasets.nuc_mm.get_nuc_mm_paths(path=path, sample="mouse", split="val"), 342 "nuc_mm_zebrafish": lambda: datasets.nuc_mm.get_nuc_mm_paths(path=path, sample="zebrafish", split="val"), 343 "asem_mito": lambda: datasets.asem.get_asem_paths(path=path, volume_ids=datasets.asem.ORGANELLES["mito"]) 344 } 345 346 if (ndim == 2 and dataset_choice not in DATASET_CONTAINER_KEYS) or dataset_choice in ["cellseg_3d", "mitolab_3d"]: 347 image_paths, gt_paths = available_datasets[dataset_choice]() 348 349 if dataset_choice in DATASET_RETURNS_FOLDER: 350 image_paths = glob(os.path.join(image_paths, DATASET_RETURNS_FOLDER[dataset_choice])) 351 gt_paths = glob(os.path.join(gt_paths, DATASET_RETURNS_FOLDER[dataset_choice])) 352 353 image_paths, gt_paths = natsorted(image_paths), natsorted(gt_paths) 354 assert len(image_paths) == len(gt_paths) 355 356 paths_set = zip(image_paths, gt_paths) 357 358 else: 359 image_paths = available_datasets[dataset_choice]() 360 if isinstance(image_paths, str): 361 paths_set = [image_paths] 362 else: 363 paths_set = natsorted(image_paths) 364 365 # Directory where we store the extracted ROIs. 366 save_image_dir = [os.path.join(path, f"roi_{ndim}d", "inputs")] 367 save_gt_dir = [os.path.join(path, f"roi_{ndim}d", "labels")] 368 if extract_2d_crops_from_volumes: 369 save_image_dir.append(os.path.join(path, "roi_2d", "inputs")) 370 save_gt_dir.append(os.path.join(path, "roi_2d", "labels")) 371 372 _dir_exists = [os.path.exists(idir) and os.path.exists(gdir) for idir, gdir in zip(save_image_dir, save_gt_dir)] 373 if all(_dir_exists): 374 return ndim 375 376 [os.makedirs(idir, exist_ok=True) for idir in save_image_dir] 377 [os.makedirs(gdir, exist_ok=True) for gdir in save_gt_dir] 378 379 # Logic to extract relevant patches for inference 380 image_counter = 1 381 for per_paths in tqdm(paths_set, desc=f"Extracting {ndim}d patches for {dataset_choice}"): 382 if (ndim == 2 and dataset_choice not in DATASET_CONTAINER_KEYS) or dataset_choice in ["cellseg_3d", "mitolab_3d"]: # noqa 383 image_path, gt_path = per_paths 384 image, gt = util.load_image_data(image_path), util.load_image_data(gt_path) 385 386 else: 387 image_path = per_paths 388 gt = util.load_image_data(image_path, DATASET_CONTAINER_KEYS[dataset_choice][1]) 389 if dataset_choice == "hpa": 390 # Get inputs per channel and stack them together to make the desired 3 channel image. 391 image = np.stack( 392 [util.load_image_data(image_path, k) for k in DATASET_CONTAINER_KEYS[dataset_choice][0]], axis=0, 393 ) 394 # Resize inputs to desired tile shape, in favor of working with the shape of foreground. 395 from torch_em.transform.generic import ResizeLongestSideInputs 396 raw_transform = ResizeLongestSideInputs(target_shape=tile_shape, is_rgb=True) 397 label_transform = ResizeLongestSideInputs(target_shape=tile_shape, is_label=True) 398 image, gt = raw_transform(image).transpose(1, 2, 0), label_transform(gt) 399 400 else: 401 image = util.load_image_data(image_path, DATASET_CONTAINER_KEYS[dataset_choice][0]) 402 403 if dataset_choice in ["tissuenet", "lizard"]: 404 if image.ndim == 3 and image.shape[0] == 3: # Make channels last for tissuenet RGB-style images. 
405 image = image.transpose(1, 2, 0) 406 407 # Allow RGBs to stay as it is with channels last 408 if image.ndim == 3 and image.shape[-1] == 3: 409 skip_smaller_shape = (np.array(image.shape) >= np.array((*tile_shape, 3))).all() 410 else: 411 skip_smaller_shape = (np.array(image.shape) >= np.array(tile_shape)).all() 412 413 # Ensure ground truth has instance labels. 414 gt = connected_components(gt) 415 416 if len(np.unique(gt)) == 1: # There could be labels which does not have any annotated foreground. 417 continue 418 419 # Let's extract and save all the crops. 420 # The first round of extraction is always to match the desired input dimensions. 421 image_crops, gt_crops = _get_crops_for_input(image, gt, ndim, tile_shape, skip_smaller_shape, crops_per_input) 422 image_counter = _save_image_label_crops( 423 image_crops, gt_crops, dataset_choice, ndim, image_counter, save_image_dir[0], save_gt_dir[0] 424 ) 425 426 # The next round of extraction is to get 2d crops from 3d inputs. 427 if extract_2d_crops_from_volumes: 428 curr_tile_shape = tile_shape[1:] # We expect 2d tile shape for this stage. 429 430 curr_image_crops, curr_gt_crops = [], [] 431 for per_z_im, per_z_gt in zip(image, gt): 432 curr_skip_smaller_shape = (np.array(per_z_im.shape) >= np.array(curr_tile_shape)).all() 433 434 image_crops, gt_crops = _get_crops_for_input( 435 image=per_z_im, gt=per_z_gt, ndim=2, 436 tile_shape=curr_tile_shape, 437 skip_smaller_shape=curr_skip_smaller_shape, 438 crops_per_input=crops_per_input, 439 ) 440 curr_image_crops.extend(image_crops) 441 curr_gt_crops.extend(gt_crops) 442 443 image_counter = _save_image_label_crops( 444 curr_image_crops, curr_gt_crops, dataset_choice, 2, image_counter, save_image_dir[1], save_gt_dir[1] 445 ) 446 447 return ndim 448 449 450def _get_crops_for_input(image, gt, ndim, tile_shape, skip_smaller_shape, crops_per_input): 451 tiling = blocking([0] * ndim, gt.shape, tile_shape) 452 n_tiles = tiling.numberOfBlocks 453 tiles = [tiling.getBlock(tile_id) for tile_id in range(n_tiles)] 454 crop_boxes = [ 455 tuple(slice(beg, end) for beg, end in zip(tile.begin, tile.end)) for tile in tiles 456 ] 457 n_ids = [idx for idx in range(len(crop_boxes))] 458 n_instances = [len(np.unique(gt[crop])) for crop in crop_boxes] 459 460 # Extract the desired number of patches with higher number of instances. 461 image_crops, gt_crops = [], [] 462 for i, (per_n_instance, per_id) in enumerate(sorted(zip(n_instances, n_ids), reverse=True), start=1): 463 crop_box = crop_boxes[per_id] 464 crop_image, crop_gt = image[crop_box], gt[crop_box] 465 466 # NOTE: We avoid using the crops which do not match the desired tile shape. 467 _rtile_shape = (*tile_shape, 3) if image.ndim == 3 and image.shape[-1] == 3 else tile_shape # For RGB images. 468 if skip_smaller_shape and crop_image.shape != _rtile_shape: 469 continue 470 471 # NOTE: There could be a case where some later patches are invalid. 472 if per_n_instance == 1: 473 break 474 475 image_crops.append(crop_image) 476 gt_crops.append(crop_gt) 477 478 # NOTE: If the number of patches extracted have been fulfiled, we stop sampling patches. 
        if len(image_crops) > 0 and i >= crops_per_input:
            break

    return image_crops, gt_crops


def _save_image_label_crops(image_crops, gt_crops, dataset_choice, ndim, image_counter, save_image_dir, save_gt_dir):
    for image_crop, gt_crop in tqdm(
        zip(image_crops, gt_crops), total=len(image_crops), desc=f"Saving {ndim}d crops for {dataset_choice}"
    ):
        fname = f"{dataset_choice}_{image_counter:05}.tif"

        if image_crop.ndim == 3 and image_crop.shape[-1] == 3:
            assert image_crop.shape[:2] == gt_crop.shape
        else:
            assert image_crop.shape == gt_crop.shape

        imageio.imwrite(os.path.join(save_image_dir, fname), image_crop, compression="zlib")
        imageio.imwrite(os.path.join(save_gt_dir, fname), gt_crop, compression="zlib")

        image_counter += 1

    return image_counter


def _get_image_label_paths(path, ndim):
    image_paths = natsorted(glob(os.path.join(path, f"roi_{ndim}d", "inputs", "*")))
    gt_paths = natsorted(glob(os.path.join(path, f"roi_{ndim}d", "labels", "*")))
    return image_paths, gt_paths


def _run_automatic_segmentation_per_dataset(
    image_paths: List[Union[os.PathLike, str]],
    gt_paths: List[Union[os.PathLike, str]],
    model_type: str,
    output_folder: Union[os.PathLike, str],
    ndim: Optional[int] = None,
    device: Optional[Union[torch.device, str]] = None,
    checkpoint_path: Optional[Union[os.PathLike, str]] = None,
    segmentation_mode: Optional[Literal["amg", "ais", "apg"]] = "ais",
    **auto_seg_kwargs
):
    """Runs automatic segmentation for multiple input files at once
    and stores the quantitative evaluation results.

    Args:
        image_paths: List of filepaths for the input image data.
        gt_paths: List of filepaths for the corresponding label data.
        model_type: The choice of image encoder for the Segment Anything model.
        output_folder: Filepath to the folder where we store all the results.
        ndim: The number of input dimensions.
        device: The torch device.
        checkpoint_path: The filepath where the model checkpoint is stored.
        segmentation_mode: The mode for automatic segmentation. If None, it is derived from
            the model state (AIS if a decoder state is found, otherwise AMG).
        auto_seg_kwargs: Additional arguments for the automatic segmentation parameters.
    """
    if segmentation_mode is None:  # If no mode is given, check whether a decoder state exists and choose accordingly.
        _, state = util.get_sam_model(
            model_type=model_type, checkpoint_path=checkpoint_path, device=device, return_state=True
        )
        segmentation_mode = DEFAULT_SEGMENTATION_MODE_WITH_DECODER if "decoder_state" in state else "amg"

    fname = f"{segmentation_mode}_{ndim}d"

    result_path = os.path.join(output_folder, "results", f"{fname}.csv")
    if os.path.exists(result_path):
        return

    prediction_dir = os.path.join(output_folder, fname, "inference")
    os.makedirs(prediction_dir, exist_ok=True)

    # Get the predictor (and the additional instance segmentation decoder, if available).
    predictor, segmenter = get_predictor_and_segmenter(
        model_type=model_type, checkpoint=checkpoint_path, device=device,
        segmentation_mode=segmentation_mode, is_tiled=False,
    )

    for image_path in tqdm(image_paths, desc=f"Run {segmentation_mode} in {ndim}d"):
        output_path = os.path.join(prediction_dir, os.path.basename(image_path))
        if os.path.exists(output_path):
            continue

        # Run automatic segmentation (AMG or AIS).
        automatic_instance_segmentation(
            predictor=predictor,
            segmenter=segmenter,
            input_path=image_path,
            output_path=output_path,
            ndim=ndim,
            verbose=False,
            **auto_seg_kwargs
        )

    prediction_paths = natsorted(glob(os.path.join(prediction_dir, "*")))
    run_evaluation(gt_paths=gt_paths, prediction_paths=prediction_paths, save_path=result_path)


def _run_interactive_segmentation_per_dataset(
    image_paths: List[Union[os.PathLike, str]],
    gt_paths: List[Union[os.PathLike, str]],
    output_folder: Union[os.PathLike, str],
    model_type: str,
    prompt_choice: Literal["box", "points"],
    device: Optional[Union[torch.device, str]] = None,
    ndim: Optional[int] = None,
    checkpoint_path: Optional[Union[os.PathLike, str]] = None,
    use_masks: bool = False,
):
    """Runs interactive segmentation for multiple input files at once
    and stores the evaluated interactive segmentation results.

    Args:
        image_paths: List of filepaths for the input image data.
        gt_paths: List of filepaths for the corresponding label data.
        output_folder: Filepath to the folder where we store all the results.
        model_type: The choice of model type for Segment Anything.
        prompt_choice: The choice of initial prompts with which to begin the interactive segmentation.
        device: The torch device.
        ndim: The number of input dimensions.
        checkpoint_path: The filepath for stored checkpoints.
        use_masks: Whether to use masks for iterative prompting.
    """
    if ndim == 2:
        # Get the Segment Anything predictor.
        predictor = util.get_sam_model(model_type=model_type, device=device, checkpoint_path=checkpoint_path)

        prediction_root = os.path.join(
            output_folder, "interactive_segmentation_2d", f"start_with_{prompt_choice}",
            "iterative_prompting_" + ("with_masks" if use_masks else "without_masks")
        )

        # Run interactive instance segmentation
        # (starting with box or point prompts, followed by iterative prompt-based correction).
        run_inference_with_iterative_prompting(
            predictor=predictor,
            image_paths=image_paths,
            gt_paths=gt_paths,
            embedding_dir=None,  # We set this to None to compute embeddings on-the-fly.
            prediction_dir=prediction_root,
            start_with_box_prompt=(prompt_choice == "box"),
            use_masks=use_masks,
            # TODO: add parameter for deform over box prompts (to simulate prompts in practice).
        )

        # Evaluate the interactive instance segmentation.
        run_evaluation_for_iterative_prompting(
            gt_paths=gt_paths,
            prediction_root=prediction_root,
            experiment_folder=output_folder,
            start_with_box_prompt=(prompt_choice == "box"),
            use_masks=use_masks,
        )

    else:
        save_path = os.path.join(output_folder, "results", f"interactive_segmentation_3d_with_{prompt_choice}.csv")
        if os.path.exists(save_path):
            print(
                f"Results for 3d interactive segmentation with '{prompt_choice}' are already stored at '{save_path}'."
            )
            return

        results = []
        for image_path, gt_path in tqdm(
            zip(image_paths, gt_paths), total=len(image_paths),
            desc=f"Run interactive segmentation in 3d with '{prompt_choice}'"
        ):
            prediction_dir = os.path.join(output_folder, "interactive_segmentation_3d", f"{prompt_choice}")
            os.makedirs(prediction_dir, exist_ok=True)

            prediction_path = os.path.join(prediction_dir, os.path.basename(image_path))
            if os.path.exists(prediction_path):
                continue

            per_vol_result = segment_slices_from_ground_truth(
                volume=imageio.imread(image_path),
                ground_truth=imageio.imread(gt_path),
                model_type=model_type,
                checkpoint_path=checkpoint_path,
                save_path=prediction_path,
                device=device,
                interactive_seg_mode=prompt_choice,
                min_size=10,
            )
            results.append(per_vol_result)

        # Average the per-volume results.
        results = pd.concat(results)
        results = results.groupby(results.index).mean()
        results.to_csv(save_path)


def _run_benchmark_evaluation_series(
    image_paths, gt_paths, model_type, output_folder, ndim, device, checkpoint_path,
    segmentation_mode, evaluation_methods,
):
    seg_kwargs = {
        "image_paths": image_paths,
        "gt_paths": gt_paths,
        "output_folder": output_folder,
        "ndim": ndim,
        "model_type": model_type,
        "device": device,
        "checkpoint_path": checkpoint_path,
    }

    # Perform:
    # a. automatic segmentation (supported in both 2d and 3d, wherever relevant).
    #    The steps below are configured such that AIS has priority (if a decoder is found),
    #    otherwise AMG is run. Afterwards, the user-specified segmentation mode is run as well, if given.

    if evaluation_methods != "interactive":  # Avoid automatic segmentation evaluation for 'interactive'-only runs.
        # i. Run the automatic segmentation method supported by the SAM model (AMG or AIS).
        _run_automatic_segmentation_per_dataset(segmentation_mode=None, **seg_kwargs)

        # ii. Run automatic segmentation in the user-specified mode.
        # NOTE: If 'segmentation_mode' is None (the default), this is a no-op,
        # since the results from step i. already exist.
        _run_automatic_segmentation_per_dataset(segmentation_mode=segmentation_mode, **seg_kwargs)

    if evaluation_methods != "automatic":  # Avoid interactive segmentation evaluation for 'automatic'-only runs.
        # b. Run interactive segmentation (supported in both 2d and 3d, wherever relevant).
        _run_interactive_segmentation_per_dataset(prompt_choice="box", **seg_kwargs)
        _run_interactive_segmentation_per_dataset(prompt_choice="box", use_masks=True, **seg_kwargs)
        _run_interactive_segmentation_per_dataset(prompt_choice="points", **seg_kwargs)
        _run_interactive_segmentation_per_dataset(prompt_choice="points", use_masks=True, **seg_kwargs)


def _clear_cached_items(retain, path, output_folder):
    import shutil
    from pathlib import Path

    REMOVE_LIST = ["data", "crops", "automatic", "interactive"]
    if retain is None:
        remove_list = REMOVE_LIST
    else:
        assert isinstance(retain, list)
        remove_list = set(REMOVE_LIST) - set(retain)

    paths = []
    # Stage 1: Remove inputs.
    if "data" in remove_list or "crops" in remove_list:
        all_paths = glob(os.path.join(path, "*"))

        # In case we want to remove both data and crops, we remove the data folder entirely.
        if "data" in remove_list and "crops" in remove_list:
            paths.extend(all_paths)

        # Otherwise, we remove only the original data or only the extracted crops.
        else:
            for curr_path in all_paths:
                is_crop_dir = os.path.basename(curr_path).startswith("roi")
                if is_crop_dir and "crops" in remove_list:
                    paths.append(curr_path)
                elif not is_crop_dir and "data" in remove_list:
                    paths.append(curr_path)

    # Stage 2: Remove predictions.
    if "automatic" in remove_list:
        paths.extend(glob(os.path.join(output_folder, "amg_*")))
        paths.extend(glob(os.path.join(output_folder, "ais_*")))

    if "interactive" in remove_list:
        paths.extend(glob(os.path.join(output_folder, "interactive_segmentation_*")))

    for _path in paths:
        if Path(_path).is_dir():
            shutil.rmtree(_path)
        else:
            os.remove(_path)


def run_benchmark_evaluations(
    input_folder: Union[os.PathLike, str],
    dataset_choice: str,
    model_type: str = util._DEFAULT_MODEL,
    output_folder: Optional[Union[str, os.PathLike]] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    segmentation_mode: Optional[Literal["amg", "ais", "apg"]] = None,
    retain: Optional[List[str]] = None,
    evaluation_methods: Literal["all", "automatic", "interactive"] = "all",
    ignore_warnings: bool = False,
):
    """Run evaluation for benchmarking Segment Anything models on microscopy datasets.

    Args:
        input_folder: The path to the directory where all inputs will be stored and preprocessed.
        dataset_choice: The dataset choice.
        model_type: The model choice for SAM.
        output_folder: The path to the directory where all outputs will be stored.
        checkpoint_path: The path to the model checkpoint.
        segmentation_mode: The segmentation mode. One of 'amg', 'ais', or 'apg'.
        retain: Whether to retain certain parts of the benchmark runs.
            By default, removes everything besides quantitative results.
            There is the choice to retain 'data', 'crops', 'automatic', or 'interactive'.
        evaluation_methods: The choice of evaluation methods.
            By default, runs 'all' evaluation methods (i.e. both 'automatic' and 'interactive').
            Otherwise, specify either 'automatic' or 'interactive' for a specific evaluation run.
        ignore_warnings: Whether to ignore warnings.
    """
    start = time.time()

    with _filter_warnings(ignore_warnings):
        device = util._get_default_device()

        # Ensure that all requested datasets have been downloaded.
        dataset_choice = _download_benchmark_datasets(path=input_folder, dataset_choice=dataset_choice)

        for choice in dataset_choice:
            # Use a per-dataset output folder, so that results for multiple datasets do not nest.
            curr_output_folder = os.path.join(output_folder, choice)
            result_dir = os.path.join(curr_output_folder, "results")
            os.makedirs(result_dir, exist_ok=True)

            data_path = os.path.join(input_folder, choice)

            # Extract the desired crops from the datasets:
            # a. for 2d datasets - 2d patches with the most labels present
            #    (in case of volumetric data, choose 2d patches per slice).
            # b. for 3d datasets - 3d regions of interest with the most labels present.
            ndim = _extract_slices_from_dataset(path=data_path, dataset_choice=choice, crops_per_input=10)

            # Run inference and evaluation scripts on the benchmark datasets.
            image_paths, gt_paths = _get_image_label_paths(path=data_path, ndim=ndim)
            _run_benchmark_evaluation_series(
                image_paths=image_paths,
                gt_paths=gt_paths,
                model_type=model_type,
                output_folder=curr_output_folder,
                ndim=ndim,
                device=device,
                checkpoint_path=checkpoint_path,
                segmentation_mode=segmentation_mode,
                evaluation_methods=evaluation_methods,
            )

            # Run inference and evaluation scripts on '2d' crops for volumetric datasets.
            if ndim == 3:
                image_paths, gt_paths = _get_image_label_paths(path=data_path, ndim=2)
                _run_benchmark_evaluation_series(
                    image_paths=image_paths,
                    gt_paths=gt_paths,
                    model_type=model_type,
                    output_folder=curr_output_folder,
                    ndim=2,
                    device=device,
                    checkpoint_path=checkpoint_path,
                    segmentation_mode=segmentation_mode,
                    evaluation_methods=evaluation_methods,
                )

            _clear_cached_items(retain=retain, path=data_path, output_folder=curr_output_folder)

    diff = time.time() - start
    hours, rest = divmod(diff, 3600)
    minutes, seconds = divmod(rest, 60)
    print("Time taken for running benchmarks: ", f"{int(hours)}h {int(minutes)}m {int(seconds)}s")


def main():
    """@private"""
    import argparse

    available_models = list(util.get_model_names())
    available_models = ", ".join(available_models)

    parser = argparse.ArgumentParser(
        description="Run evaluation for benchmarking Segment Anything models on microscopy datasets."
    )
    parser.add_argument(
        "-i", "--input_folder", type=str, required=True,
        help="The path to a directory where the microscopy datasets are and/or will be stored."
    )
    parser.add_argument(
        "-m", "--model_type", type=str, default=util._DEFAULT_MODEL,
        help=f"The Segment Anything model that will be used, one of '{available_models}'. "
             f"By default, it uses '{util._DEFAULT_MODEL}'."
    )
    parser.add_argument(
        "-c", "--checkpoint_path", type=str, default=None,
        help="The filepath to the checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "-d", "--dataset_choice", type=str, nargs='*', default=None,
        help="The choice(s) of dataset for evaluating SAM models. Multiple datasets can be specified. "
             "By default, it evaluates on all datasets."
    )
    parser.add_argument(
        "-o", "--output_folder", type=str, required=True,
        help="The path where the results for (automatic and interactive) instance segmentation "
             "will be stored as 'csv' files."
    )
    parser.add_argument(
        "--amg", action="store_true",
        help="Whether to run automatic segmentation in AMG mode (i.e. the default auto-seg approach for SAM)."
    )
    parser.add_argument(
        "--retain", nargs="*", default=None,
        help="By default, everything besides the quantitative results is removed after the benchmark runs. "
             "In case you would like to retain parts of the benchmark evaluation for visualization / reproducibility, "
             "choose one or multiple of 'data', 'crops', 'automatic', 'interactive', "
             "which retain the original inputs / extracted crops / predictions of automatic segmentation / "
             "predictions of interactive segmentation, respectively."
    )
    parser.add_argument(
        "--evaluate", type=str, default="all", choices=["all", "automatic", "interactive"],
        help="The choice of evaluation methods for the benchmarks. By default, 'all' runs every evaluation. "
             "'automatic' runs automatic segmentation only, and 'interactive' runs interactive segmentation "
             "(starting from box and single point prompts) with iterative prompting."
    )
    args = parser.parse_args()

    run_benchmark_evaluations(
        input_folder=args.input_folder,
        dataset_choice=args.dataset_choice,
        model_type=args.model_type,
        output_folder=args.output_folder,
        checkpoint_path=args.checkpoint_path,
        # The parser only exposes the '--amg' flag, so we map it to the corresponding segmentation mode here.
        segmentation_mode="amg" if args.amg else None,
        retain=args.retain,
        evaluation_methods=args.evaluate,
        ignore_warnings=True,
    )


if __name__ == "__main__":
    main()
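
Since the module guards a `main` entry point with `if __name__ == "__main__"`, the benchmarks can also be launched from the command line. A minimal invocation sketch (the dataset name and folder names are placeholders):

python -m micro_sam.evaluation.benchmark_datasets -i data -o results -d covid_if -m vit_b_lm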
LM_2D_DATASETS =
['livecell', 'deepbacs', 'tissuenet', 'neurips_cellseg', 'cellpose', 'dynamicnuclearnet', 'orgasegment', 'yeaz', 'arvidsson', 'bitdepth_nucseg', 'cellbindb', 'covid_if', 'deepseas', 'hpa', 'ifnuclei', 'lizard', 'organoidnet', 'toiam', 'vicar']
LM_3D_DATASETS =
['plantseg_root', 'plantseg_ovules', 'gonuclear', 'mouse_embryo', 'cellseg_3d']
EM_2D_DATASETS =
['mitolab_tem']
EM_3D_DATASETS =
['lucchi', 'mitolab_3d', 'uro_cell', 'sponge_em', 'vnc', 'nuc_mm_mouse', 'nuc_mm_zebrafish', 'platynereis_cilia', 'asem_mito']
DATASET_RETURNS_FOLDER =
{'deepbacs': '*.tif', 'mitolab_tem': '*.tiff'}
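
For these two datasets the underlying path getters return a directory rather than a list of files, so the patterns above are used to collect the individual image files. A minimal sketch mirroring the glob call in `_extract_slices_from_dataset` (the `image_dir` value is a hypothetical placeholder):

import os
from glob import glob

image_dir = "deepbacs/test/source"  # hypothetical folder returned by the path getter
image_paths = glob(os.path.join(image_dir, DATASET_RETURNS_FOLDER["deepbacs"]))  # collects all '*.tif' files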
DATASET_CONTAINER_KEYS =
{'tissuenet': ['raw/rgb', 'labels/cell'], 'covid_if': ['raw/serum_IgG/s0', 'labels/cells/s0'], 'dynamicnuclearnet': ['raw', 'labels'], 'hpa': [['raw/protein', 'raw/microtubules', 'raw/er'], 'labels'], 'lizard': ['image', 'labels/segmentation'], 'plantseg_root': ['raw', 'label'], 'plantseg_ovules': ['raw', 'label'], 'gonuclear': ['raw/nuclei', 'labels/nuclei'], 'mouse_embryo': ['raw', 'label'], 'cellseg_3d': [None, None], 'lucchi': ['raw', 'labels'], 'uro_cell': ['raw', 'labels/mito'], 'mitolab_3d': [None, None], 'sponge_em': ['volumes/raw', 'volumes/labels/instances'], 'vnc': ['raw', 'labels/mitochondria']}
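
These entries name the internal keys for the raw data and labels in container file formats (e.g. hdf5). A minimal sketch of how they are consumed, mirroring the `util.load_image_data` calls in `_extract_slices_from_dataset` (the container filepath is hypothetical):

from micro_sam import util

image_path = "covid_if/gt_image_000.h5"  # hypothetical container file
raw_key, label_key = DATASET_CONTAINER_KEYS["covid_if"]
image = util.load_image_data(image_path, raw_key)  # raw data stored at 'raw/serum_IgG/s0'
gt = util.load_image_data(image_path, label_key)  # instance labels stored at 'labels/cells/s0'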
def
run_benchmark_evaluations(
    input_folder: Union[os.PathLike, str],
    dataset_choice: str,
    model_type: str = 'vit_b_lm',
    output_folder: Optional[Union[os.PathLike, str]] = None,
    checkpoint_path: Optional[Union[os.PathLike, str]] = None,
    segmentation_mode: Optional[Literal['amg', 'ais', 'apg']] = None,
    retain: Optional[List[str]] = None,
    evaluation_methods: Literal['all', 'automatic', 'interactive'] = 'all',
    ignore_warnings: bool = False,
)
Run evaluation for benchmarking Segment Anything models on microscopy datasets.
Arguments:
- input_folder: The path to the directory where all inputs will be stored and preprocessed.
- dataset_choice: The dataset choice.
- model_type: The model choice for SAM.
- output_folder: The path to the directory where all outputs will be stored.
- checkpoint_path: The path to the model checkpoint.
- segmentation_mode: The segmentation mode. One of 'amg', 'ais', or 'apg'.
- retain: Whether to retain certain parts of the benchmark runs. By default, removes everything besides quantitative results. There is the choice to retain 'data', 'crops', 'automatic', or 'interactive'.
- evaluation_methods: The choice of evaluation methods. By default, runs 'all' evaluation methods (i.e. both 'automatic' and 'interactive'). Otherwise, specify either 'automatic' or 'interactive' for a specific evaluation run.
- ignore_warnings: Whether to ignore warnings.
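
A minimal usage sketch of the Python API (folder names and the dataset choice are placeholders):

from micro_sam.evaluation.benchmark_datasets import run_benchmark_evaluations

run_benchmark_evaluations(
    input_folder="data",  # datasets are downloaded here and crops are extracted below it
    dataset_choice="covid_if",  # any supported dataset name; None runs all supported datasets
    model_type="vit_b_lm",
    output_folder="results",  # per-dataset result 'csv' files are written below this folder
    evaluation_methods="automatic",  # restrict the run to the automatic segmentation evaluation
)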