micro_sam.evaluation.benchmark_datasets
import os
import time
from glob import glob
from tqdm import tqdm
from natsort import natsorted
from typing import Union, Optional, List, Literal

import numpy as np
import pandas as pd
import imageio.v3 as imageio
from skimage.measure import label as connected_components

from nifty.tools import blocking

import torch

from torch_em.data import datasets

from micro_sam import util

from . import run_evaluation
from ..training.training import _filter_warnings
from .inference import run_inference_with_iterative_prompting
from .evaluation import run_evaluation_for_iterative_prompting
from .multi_dimensional_segmentation import segment_slices_from_ground_truth
from ..automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter


LM_2D_DATASETS = [
    # in-domain
    "livecell",  # cell segmentation in PhC (has a TEST-set)
    "deepbacs",  # bacteria segmentation in label-free microscopy (has a TEST-set)
    "tissuenet",  # cell segmentation in tissue microscopy images (has a TEST-set)
    "neurips_cellseg",  # cell segmentation in various modalities (has a TEST-set)
    "cellpose",  # cell segmentation in FM (has 'cyto2', on which we can TEST)
    "dynamicnuclearnet",  # nuclei segmentation in FM (has a TEST-set)
    "orgasegment",  # organoid segmentation in BF (has a TEST-set)
    "yeaz",  # yeast segmentation in BF (has a TEST-set)

    # out-of-domain
    "arvidsson",  # nuclei segmentation in HCS FM
    "bitdepth_nucseg",  # nuclei segmentation in FM
    "cellbindb",  # cell segmentation in various microscopy
    "covid_if",  # cell segmentation in IF
    "deepseas",  # cell segmentation in PhC
    "hpa",  # cell segmentation in confocal
    "ifnuclei",  # nuclei segmentation in IFM
    "lizard",  # nuclei segmentation in H&E histopathology
    "organoidnet",  # organoid segmentation in BF
    "toiam",  # microbial cell segmentation in PhC
    "vicar",  # cell segmentation in label-free microscopy
]

LM_3D_DATASETS = [
    # in-domain
    "plantseg_root",  # cell segmentation in lightsheet (has a TEST-set)

    # out-of-domain
    "plantseg_ovules",  # cell segmentation in confocal
    "gonuclear",  # nuclei segmentation in FM
    "mouse_embryo",  # cell segmentation in lightsheet
    "cellseg_3d",  # nuclei segmentation in FM
]

EM_2D_DATASETS = ["mitolab_tem"]

EM_3D_DATASETS = [
    # out-of-domain
    "lucchi",  # mitochondria segmentation in vEM
    "mitolab_3d",  # mitochondria segmentation in various modalities
    "uro_cell",  # mitochondria segmentation (and other organelles) in FIB-SEM
    "sponge_em",  # microvilli segmentation (and other structures) in sponge chamber vEM
    "vnc",  # mitochondria segmentation in drosophila brain TEM
    "nuc_mm_mouse",  # nuclei segmentation in microCT
    "nuc_mm_zebrafish",  # nuclei segmentation in EM
    "platynereis_cilia",  # cilia segmentation (and other structures) in platynereis larvae vEM
    "asem_mito",  # mitochondria segmentation (and other organelles) in FIB-SEM
]

DATASET_RETURNS_FOLDER = {
    "deepbacs": "*.tif",
    "mitolab_tem": "*.tiff",
}

DATASET_CONTAINER_KEYS = {
    # 2d (LM)
    "tissuenet": ["raw/rgb", "labels/cell"],
    "covid_if": ["raw/serum_IgG/s0", "labels/cells/s0"],
    "dynamicnuclearnet": ["raw", "labels"],
    "hpa": [["raw/protein", "raw/microtubules", "raw/er"], "labels"],
    "lizard": ["image", "labels/segmentation"],

    # 3d (LM)
    "plantseg_root": ["raw", "label"],
    "plantseg_ovules": ["raw", "label"],
    "gonuclear": ["raw/nuclei", "labels/nuclei"],
    "mouse_embryo": ["raw", "label"],
98 "cellseg_3d": [None, None], 99 100 # 3d (EM) 101 "lucchi": ["raw", "labels"], 102 "uro_cell": ["raw", "labels/mito"], 103 "mitolab_3d": [None, None], 104 "sponge_em": ["volumes/raw", "volumes/labels/instances"], 105 "vnc": ["raw", "labels/mitochondria"] 106} 107 108 109def _download_benchmark_datasets(path, dataset_choice): 110 """Ensures whether all the datasets have been downloaded or not. 111 112 Args: 113 path: The path to directory where the supported datasets will be downloaded 114 for benchmarking Segment Anything models. 115 dataset_choice: The choice of dataset, expects the lower case name for the dataset. 116 117 Returns: 118 List of choice of dataset(s). 119 """ 120 available_datasets = { 121 # Light Microscopy datasets 122 123 # 2d datasets: in-domain 124 "livecell": lambda: datasets.livecell.get_livecell_data( 125 path=os.path.join(path, "livecell"), download=True, 126 ), 127 "deepbacs": lambda: datasets.deepbacs.get_deepbacs_data( 128 path=os.path.join(path, "deepbacs"), bac_type="mixed", download=True, 129 ), 130 "tissuenet": lambda: datasets.tissuenet.get_tissuenet_data( 131 path=os.path.join(path, "tissuenet"), split="test", download=True, 132 ), 133 "neurips_cellseg": lambda: datasets.neurips_cell_seg.get_neurips_cellseg_data( 134 root=os.path.join(path, "neurips_cellseg"), split="test", download=True, 135 ), 136 "cellpose": lambda: datasets.cellpose.get_cellpose_data( 137 path=os.path.join(path, "cellpose"), split="train", choice="cyto2", download=True, 138 ), 139 "dynamicnuclearnet": lambda: datasets.dynamicnuclearnet.get_dynamicnuclearnet_data( 140 path=os.path.join(path, "dynamicnuclearnet"), split="test", download=True, 141 ), 142 "orgasegment": lambda: datasets.orgasegment.get_orgasegment_data( 143 path=os.path.join(path, "orgasegment"), split="eval", download=True, 144 ), 145 "yeaz": lambda: datasets.yeaz.get_yeaz_data( 146 path=os.path.join(path, "yeaz"), choice="bf", download=True, 147 ), 148 149 # 2d datasets: out-of-domain 150 "arvidsson": lambda: datasets.arvidsson.get_arvidsson_data( 151 path=os.path.join(path, "arvidsson"), split="test", download=True, 152 ), 153 "bitdepth_nucseg": lambda: datasets.bitdepth_nucseg.get_bitdepth_nucseg_data( 154 path=os.path.join(path, "bitdepth_nucseg"), download=True, 155 ), 156 "cellbindb": lambda: datasets.cellbindb.get_cellbindb_data( 157 path=os.path.join(path, "cellbindb"), download=True, 158 ), 159 "covid_if": lambda: datasets.covid_if.get_covid_if_data( 160 path=os.path.join(path, "covid_if"), download=True, 161 ), 162 "deepseas": lambda: datasets.deepseas.get_deepseas_data( 163 path=os.path.join(path, "deepseas"), split="test", 164 ), 165 "hpa": lambda: datasets.hpa.get_hpa_segmentation_data( 166 path=os.path.join(path, "hpa"), download=True, 167 ), 168 "ifnuclei": lambda: datasets.ifnuclei.get_ifnuclei_data( 169 path=os.path.join(path, "ifnuclei"), download=True, 170 ), 171 "lizard": lambda: datasets.lizard.get_lizard_data( 172 path=os.path.join(path, "lizard"), split="test", download=True, 173 ), 174 "organoidnet": lambda: datasets.organoidnet.get_organoidnet_data( 175 path=os.path.join(path, "organoidnet"), split="Test", download=True, 176 ), 177 "toiam": lambda: datasets.toiam.get_toiam_data( 178 path=os.path.join(path, "toiam"), download=True, 179 ), 180 "vicar": lambda: datasets.vicar.get_vicar_data( 181 path=os.path.join(path, "vicar"), download=True, 182 ), 183 184 # 3d datasets: in-domain 185 "plantseg_root": lambda: datasets.plantseg.get_plantseg_data( 186 path=os.path.join(path, "plantseg_root"), 
split="test", download=True, name="root", 187 ), 188 189 # 3d datasets: out-of-domain 190 "plantseg_ovules": lambda: datasets.plantseg.get_plantseg_data( 191 path=os.path.join(path, "plantseg_ovules"), split="test", download=True, name="ovules", 192 ), 193 "gonuclear": lambda: datasets.gonuclear.get_gonuclear_data( 194 path=os.path.join(path, "gonuclear"), download=True, 195 ), 196 "mouse_embryo": lambda: datasets.mouse_embryo.get_mouse_embryo_data( 197 path=os.path.join(path, "mouse_embryo"), download=True, 198 ), 199 "cellseg_3d": lambda: datasets.cellseg_3d.get_cellseg_3d_data( 200 path=os.path.join(path, "cellseg_3d"), download=True, 201 ), 202 203 # Electron Microscopy datasets 204 205 # 2d datasets: out-of-domain 206 "mitolab_tem": lambda: datasets.cem.get_benchmark_data( 207 path=os.path.join(path, "mitolab"), dataset_id=7, download=True 208 ), 209 210 # 3d datasets: out-of-domain 211 "lucchi": lambda: datasets.lucchi.get_lucchi_data( 212 path=os.path.join(path, "lucchi"), split="test", download=True, 213 ), 214 "mitolab_3d": lambda: [ 215 datasets.cem.get_benchmark_data( 216 path=os.path.join(path, "mitolab"), dataset_id=dataset_id, download=True, 217 ) for dataset_id in range(1, 7) 218 ], 219 "uro_cell": lambda: datasets.uro_cell.get_uro_cell_data( 220 path=os.path.join(path, "uro_cell"), download=True, 221 ), 222 "vnc": lambda: datasets.vnc.get_vnc_data( 223 path=os.path.join(path, "vnc"), download=True, 224 ), 225 "sponge_em": lambda: datasets.sponge_em.get_sponge_em_data( 226 path=os.path.join(path, "sponge_em"), download=True, 227 ), 228 "nuc_mm_mouse": lambda: datasets.nuc_mm.get_nuc_mm_data( 229 path=os.path.join(path, "nuc_mm"), sample="mouse", download=True, 230 ), 231 "nuc_mm_zebrafish": lambda: datasets.nuc_mm.get_nuc_mm_data( 232 path=os.path.join(path, "nuc_mm"), sample="zebrafish", download=True, 233 ), 234 "asem_mito": lambda: datasets.asem.get_asem_data( 235 path=os.path.join(path, "asem"), volume_ids=datasets.asem.ORGANELLES["mito"], download=True, 236 ), 237 "platynereis_cilia": lambda: datasets.platynereis.get_platynereis_data( 238 path=os.path.join(path, "platynereis"), name="cilia", download=True, 239 ), 240 } 241 242 if dataset_choice is None: 243 dataset_choice = available_datasets.keys() 244 else: 245 if not isinstance(dataset_choice, list): 246 dataset_choice = [dataset_choice] 247 248 for choice in dataset_choice: 249 if choice in available_datasets: 250 available_datasets[choice]() 251 else: 252 raise ValueError(f"'{choice}' is not a supported choice of dataset.") 253 254 return dataset_choice 255 256 257def _extract_slices_from_dataset(path, dataset_choice, crops_per_input=10): 258 """Extracts crops of desired shapes for performing evaluation in both 2d and 3d using `micro-sam`. 259 260 Args: 261 path: The path to directory where the supported datasets have be downloaded 262 for benchmarking Segment Anything models. 263 dataset_choice: The name of the dataset of choice to extract crops. 264 crops_per_input: The maximum number of crops to extract per inputs. 265 extract_2d: Whether to extract 2d crops from 3d patches. 266 267 Returns: 268 Filepath to the folder where extracted images are stored. 269 Filepath to the folder where corresponding extracted labels are stored. 270 The number of dimensions supported by the input. 271 """ 272 ndim = 2 if dataset_choice in [*LM_2D_DATASETS, *EM_2D_DATASETS] else 3 273 tile_shape = (512, 512) if ndim == 2 else (32, 512, 512) 274 275 # For 3d inputs, we extract both 2d and 3d crops. 
    extract_2d_crops_from_volumes = (ndim == 3)

    available_datasets = {
        # Light Microscopy datasets

        # 2d: in-domain
        "livecell": lambda: datasets.livecell.get_livecell_paths(path=path, split="test"),
        "deepbacs": lambda: datasets.deepbacs.get_deepbacs_paths(path=path, split="test", bac_type="mixed"),
        "tissuenet": lambda: datasets.tissuenet.get_tissuenet_paths(path=path, split="test"),
        "neurips_cellseg": lambda: datasets.neurips_cell_seg.get_neurips_cellseg_paths(root=path, split="test"),
        "cellpose": lambda: datasets.cellpose.get_cellpose_paths(path=path, split="train", choice="cyto2"),
        "dynamicnuclearnet": lambda: datasets.dynamicnuclearnet.get_dynamicnuclearnet_paths(path=path, split="test"),
        "orgasegment": lambda: datasets.orgasegment.get_orgasegment_paths(path=path, split="eval"),
        "yeaz": lambda: datasets.yeaz.get_yeaz_paths(path=path, choice="bf", split="test"),

        # 2d: out-of-domain
        "arvidsson": lambda: datasets.arvidsson.get_arvidsson_paths(path=path, split="test"),
        "bitdepth_nucseg": lambda: datasets.bitdepth_nucseg.get_bitdepth_nucseg_paths(path=path, magnification="20x"),
        "cellbindb": lambda: datasets.cellbindb.get_cellbindb_paths(
            path=path, data_choice=["10×Genomics_DAPI", "DAPI", "mIF"]
        ),
        "covid_if": lambda: datasets.covid_if.get_covid_if_paths(path=path),
        "deepseas": lambda: datasets.deepseas.get_deepseas_paths(path=path, split="test"),
        "hpa": lambda: datasets.hpa.get_hpa_segmentation_paths(path=path, split="val"),
        "ifnuclei": lambda: datasets.ifnuclei.get_ifnuclei_paths(path=path),
        "lizard": lambda: datasets.lizard.get_lizard_paths(path=path, split="test"),
        "organoidnet": lambda: datasets.organoidnet.get_organoidnet_paths(path=path, split="Test"),
        "toiam": lambda: datasets.toiam.get_toiam_paths(path=path),
        "vicar": lambda: datasets.vicar.get_vicar_paths(path=path),

        # 3d: in-domain
        "plantseg_root": lambda: datasets.plantseg.get_plantseg_paths(path=path, name="root", split="test"),

        # 3d: out-of-domain
        "plantseg_ovules": lambda: datasets.plantseg.get_plantseg_paths(path=path, name="ovules", split="test"),
        "gonuclear": lambda: datasets.gonuclear.get_gonuclear_paths(path=path),
        "mouse_embryo": lambda: datasets.mouse_embryo.get_mouse_embryo_paths(path=path, name="nuclei", split="val"),
        "cellseg_3d": lambda: datasets.cellseg_3d.get_cellseg_3d_paths(path=path),

        # Electron Microscopy datasets

        # 2d: out-of-domain
        "mitolab_tem": lambda: datasets.cem.get_benchmark_paths(
            path=os.path.join(os.path.dirname(path), "mitolab"), dataset_id=7
        )[:2],

        # 3d: out-of-domain
        "lucchi": lambda: datasets.lucchi.get_lucchi_paths(path=path, split="test"),
        "platynereis_cilia": lambda: datasets.platynereis.get_platynereis_paths(path, sample_ids=None, name="cilia"),
        "uro_cell": lambda: datasets.uro_cell.get_uro_cell_paths(path=path, target="mito"),
        "vnc": lambda: datasets.vnc.get_vnc_mito_paths(path=path),
        "sponge_em": lambda: datasets.sponge_em.get_sponge_em_paths(path=path, sample_ids=None),
        "mitolab_3d": lambda: (
            [
                datasets.cem.get_benchmark_paths(
                    path=os.path.join(os.path.dirname(path), "mitolab"), dataset_id=i
                )[0] for i in range(1, 7)
            ],
            [
                datasets.cem.get_benchmark_paths(
                    path=os.path.join(os.path.dirname(path), "mitolab"), dataset_id=i
                )[1] for i in range(1, 7)
            ]
        ),
        "nuc_mm_mouse": lambda: datasets.nuc_mm.get_nuc_mm_paths(path=path, sample="mouse", split="val"),
split="val"), 340 "nuc_mm_zebrafish": lambda: datasets.nuc_mm.get_nuc_mm_paths(path=path, sample="zebrafish", split="val"), 341 "asem_mito": lambda: datasets.asem.get_asem_paths(path=path, volume_ids=datasets.asem.ORGANELLES["mito"]) 342 } 343 344 if (ndim == 2 and dataset_choice not in DATASET_CONTAINER_KEYS) or dataset_choice in ["cellseg_3d", "mitolab_3d"]: 345 image_paths, gt_paths = available_datasets[dataset_choice]() 346 347 if dataset_choice in DATASET_RETURNS_FOLDER: 348 image_paths = glob(os.path.join(image_paths, DATASET_RETURNS_FOLDER[dataset_choice])) 349 gt_paths = glob(os.path.join(gt_paths, DATASET_RETURNS_FOLDER[dataset_choice])) 350 351 image_paths, gt_paths = natsorted(image_paths), natsorted(gt_paths) 352 assert len(image_paths) == len(gt_paths) 353 354 paths_set = zip(image_paths, gt_paths) 355 356 else: 357 image_paths = available_datasets[dataset_choice]() 358 if isinstance(image_paths, str): 359 paths_set = [image_paths] 360 else: 361 paths_set = natsorted(image_paths) 362 363 # Directory where we store the extracted ROIs. 364 save_image_dir = [os.path.join(path, f"roi_{ndim}d", "inputs")] 365 save_gt_dir = [os.path.join(path, f"roi_{ndim}d", "labels")] 366 if extract_2d_crops_from_volumes: 367 save_image_dir.append(os.path.join(path, "roi_2d", "inputs")) 368 save_gt_dir.append(os.path.join(path, "roi_2d", "labels")) 369 370 _dir_exists = [os.path.exists(idir) and os.path.exists(gdir) for idir, gdir in zip(save_image_dir, save_gt_dir)] 371 if all(_dir_exists): 372 return ndim 373 374 [os.makedirs(idir, exist_ok=True) for idir in save_image_dir] 375 [os.makedirs(gdir, exist_ok=True) for gdir in save_gt_dir] 376 377 # Logic to extract relevant patches for inference 378 image_counter = 1 379 for per_paths in tqdm(paths_set, desc=f"Extracting {ndim}d patches for {dataset_choice}"): 380 if (ndim == 2 and dataset_choice not in DATASET_CONTAINER_KEYS) or dataset_choice in ["cellseg_3d", "mitolab_3d"]: # noqa 381 image_path, gt_path = per_paths 382 image, gt = util.load_image_data(image_path), util.load_image_data(gt_path) 383 384 else: 385 image_path = per_paths 386 gt = util.load_image_data(image_path, DATASET_CONTAINER_KEYS[dataset_choice][1]) 387 if dataset_choice == "hpa": 388 # Get inputs per channel and stack them together to make the desired 3 channel image. 389 image = np.stack( 390 [util.load_image_data(image_path, k) for k in DATASET_CONTAINER_KEYS[dataset_choice][0]], axis=0, 391 ) 392 # Resize inputs to desired tile shape, in favor of working with the shape of foreground. 393 from torch_em.transform.generic import ResizeLongestSideInputs 394 raw_transform = ResizeLongestSideInputs(target_shape=tile_shape, is_rgb=True) 395 label_transform = ResizeLongestSideInputs(target_shape=tile_shape, is_label=True) 396 image, gt = raw_transform(image).transpose(1, 2, 0), label_transform(gt) 397 398 else: 399 image = util.load_image_data(image_path, DATASET_CONTAINER_KEYS[dataset_choice][0]) 400 401 if dataset_choice in ["tissuenet", "lizard"]: 402 if image.ndim == 3 and image.shape[0] == 3: # Make channels last for tissuenet RGB-style images. 403 image = image.transpose(1, 2, 0) 404 405 # Allow RGBs to stay as it is with channels last 406 if image.ndim == 3 and image.shape[-1] == 3: 407 skip_smaller_shape = (np.array(image.shape) >= np.array((*tile_shape, 3))).all() 408 else: 409 skip_smaller_shape = (np.array(image.shape) >= np.array(tile_shape)).all() 410 411 # Ensure ground truth has instance labels. 
        gt = connected_components(gt)

        if len(np.unique(gt)) == 1:  # There could be labels which do not have any annotated foreground.
            continue

        # Let's extract and save all the crops.
        # The first round of extraction always matches the desired input dimensions.
        image_crops, gt_crops = _get_crops_for_input(image, gt, ndim, tile_shape, skip_smaller_shape, crops_per_input)
        image_counter = _save_image_label_crops(
            image_crops, gt_crops, dataset_choice, ndim, image_counter, save_image_dir[0], save_gt_dir[0]
        )

        # The next round of extraction gets 2d crops from 3d inputs.
        if extract_2d_crops_from_volumes:
            curr_tile_shape = tile_shape[1:]  # We expect a 2d tile shape for this stage.

            curr_image_crops, curr_gt_crops = [], []
            for per_z_im, per_z_gt in zip(image, gt):
                curr_skip_smaller_shape = (np.array(per_z_im.shape) >= np.array(curr_tile_shape)).all()

                image_crops, gt_crops = _get_crops_for_input(
                    image=per_z_im, gt=per_z_gt, ndim=2,
                    tile_shape=curr_tile_shape,
                    skip_smaller_shape=curr_skip_smaller_shape,
                    crops_per_input=crops_per_input,
                )
                curr_image_crops.extend(image_crops)
                curr_gt_crops.extend(gt_crops)

            image_counter = _save_image_label_crops(
                curr_image_crops, curr_gt_crops, dataset_choice, 2, image_counter, save_image_dir[1], save_gt_dir[1]
            )

    return ndim


def _get_crops_for_input(image, gt, ndim, tile_shape, skip_smaller_shape, crops_per_input):
    tiling = blocking([0] * ndim, gt.shape, tile_shape)
    n_tiles = tiling.numberOfBlocks
    tiles = [tiling.getBlock(tile_id) for tile_id in range(n_tiles)]
    crop_boxes = [
        tuple(slice(beg, end) for beg, end in zip(tile.begin, tile.end)) for tile in tiles
    ]
    n_ids = list(range(len(crop_boxes)))
    n_instances = [len(np.unique(gt[crop])) for crop in crop_boxes]

    # Extract the desired number of patches, preferring crops with the highest number of instances.
    image_crops, gt_crops = [], []
    for i, (per_n_instance, per_id) in enumerate(sorted(zip(n_instances, n_ids), reverse=True), start=1):
        crop_box = crop_boxes[per_id]
        crop_image, crop_gt = image[crop_box], gt[crop_box]

        # NOTE: We avoid using crops which do not match the desired tile shape.
        _rtile_shape = (*tile_shape, 3) if image.ndim == 3 and image.shape[-1] == 3 else tile_shape  # For RGB images.
        if skip_smaller_shape and crop_image.shape != _rtile_shape:
            continue

        # NOTE: Since the crops are sorted by instance count, once we reach a crop with only one
        # unique label (i.e. no foreground instances), all remaining crops are invalid as well, so we stop.
        if per_n_instance == 1:
            break

        image_crops.append(crop_image)
        gt_crops.append(crop_gt)

        # NOTE: Once the desired number of patches has been extracted, we stop sampling.
        if len(image_crops) > 0 and i >= crops_per_input:
            break

    return image_crops, gt_crops


def _save_image_label_crops(image_crops, gt_crops, dataset_choice, ndim, image_counter, save_image_dir, save_gt_dir):
    for image_crop, gt_crop in tqdm(
        zip(image_crops, gt_crops), total=len(image_crops), desc=f"Saving {ndim}d crops for {dataset_choice}"
    ):
        fname = f"{dataset_choice}_{image_counter:05}.tif"

        if image_crop.ndim == 3 and image_crop.shape[-1] == 3:
            assert image_crop.shape[:2] == gt_crop.shape
        else:
            assert image_crop.shape == gt_crop.shape

        imageio.imwrite(os.path.join(save_image_dir, fname), image_crop, compression="zlib")
        imageio.imwrite(os.path.join(save_gt_dir, fname), gt_crop, compression="zlib")

        image_counter += 1

    return image_counter


def _get_image_label_paths(path, ndim):
    image_paths = natsorted(glob(os.path.join(path, f"roi_{ndim}d", "inputs", "*")))
    gt_paths = natsorted(glob(os.path.join(path, f"roi_{ndim}d", "labels", "*")))
    return image_paths, gt_paths


def _run_automatic_segmentation_per_dataset(
    image_paths: List[Union[os.PathLike, str]],
    gt_paths: List[Union[os.PathLike, str]],
    model_type: str,
    output_folder: Union[os.PathLike, str],
    ndim: Optional[int] = None,
    device: Optional[Union[torch.device, str]] = None,
    checkpoint_path: Optional[Union[os.PathLike, str]] = None,
    run_amg: Optional[bool] = None,
    **auto_seg_kwargs
):
    """Runs automatic segmentation for multiple input files at once.
    It stores the evaluated (quantitative) automatic segmentation results.

    Args:
        image_paths: List of filepaths for the input image data.
        gt_paths: List of filepaths for the corresponding label data.
        model_type: The choice of image encoder for the Segment Anything model.
        output_folder: Filepath to the folder where we store all the results.
        ndim: The number of input dimensions.
        device: The torch device.
        checkpoint_path: The filepath where the model checkpoints are stored.
        run_amg: Whether to run automatic segmentation in AMG mode.
        auto_seg_kwargs: Additional arguments for automatic segmentation parameters.
    """
    # If 'run_amg' is not set, we check whether an instance segmentation decoder is available.
    # If it is, we default to the best automatic segmentation mode (i.e. AIS over AMG).
    if run_amg is None or (not run_amg):
        _, state = util.get_sam_model(
            model_type=model_type, checkpoint_path=checkpoint_path, device=device, return_state=True
        )
        run_amg = ("decoder_state" not in state)

    experiment_name = "AMG" if run_amg else "AIS"
    fname = f"{experiment_name.lower()}_{ndim}d"

    result_path = os.path.join(output_folder, "results", f"{fname}.csv")
    if os.path.exists(result_path):
        return

    prediction_dir = os.path.join(output_folder, fname, "inference")
    os.makedirs(prediction_dir, exist_ok=True)

    # Get the predictor (and the additional instance segmentation decoder, if available).
    predictor, segmenter = get_predictor_and_segmenter(
        model_type=model_type, checkpoint=checkpoint_path, device=device, amg=run_amg, is_tiled=False,
    )

    for image_path in tqdm(image_paths, desc=f"Run {experiment_name} in {ndim}d"):
        output_path = os.path.join(prediction_dir, os.path.basename(image_path))
        if os.path.exists(output_path):
            continue

        # Run automatic segmentation (AMG or AIS).
        automatic_instance_segmentation(
            predictor=predictor,
            segmenter=segmenter,
            input_path=image_path,
            output_path=output_path,
            ndim=ndim,
            verbose=False,
            **auto_seg_kwargs
        )

    prediction_paths = natsorted(glob(os.path.join(prediction_dir, "*")))
    run_evaluation(gt_paths=gt_paths, prediction_paths=prediction_paths, save_path=result_path)


def _run_interactive_segmentation_per_dataset(
    image_paths: List[Union[os.PathLike, str]],
    gt_paths: List[Union[os.PathLike, str]],
    output_folder: Union[os.PathLike, str],
    model_type: str,
    prompt_choice: Literal["box", "points"],
    device: Optional[Union[torch.device, str]] = None,
    ndim: Optional[int] = None,
    checkpoint_path: Optional[Union[os.PathLike, str]] = None,
    use_masks: bool = False,
):
    """Runs interactive segmentation for multiple input files at once.
    It stores the evaluated interactive segmentation results.

    Args:
        image_paths: List of filepaths for the input image data.
        gt_paths: List of filepaths for the corresponding label data.
        output_folder: Filepath to the folder where we store all the results.
        model_type: The choice of model type for Segment Anything.
        prompt_choice: The choice of initial prompts with which to begin the interactive segmentation.
        device: The torch device.
        ndim: The number of input dimensions.
        checkpoint_path: The filepath for stored checkpoints.
        use_masks: Whether to use masks for iterative prompting.
    """
    if ndim == 2:
        # Get the Segment Anything predictor.
        predictor = util.get_sam_model(model_type=model_type, device=device, checkpoint_path=checkpoint_path)

        prediction_root = os.path.join(
            output_folder, "interactive_segmentation_2d", f"start_with_{prompt_choice}",
            "iterative_prompting_" + ("with_masks" if use_masks else "without_masks")
        )

        # Run interactive instance segmentation
        # (starting with box or point prompts, followed by iterative prompt-based correction).
        run_inference_with_iterative_prompting(
            predictor=predictor,
            image_paths=image_paths,
            gt_paths=gt_paths,
            embedding_dir=None,  # We set this to None to compute embeddings on-the-fly.
            prediction_dir=prediction_root,
            start_with_box_prompt=(prompt_choice == "box"),
            use_masks=use_masks,
            # TODO: Add a parameter to deform the box prompts (to simulate prompts in practice).
        )

        # Evaluate the interactive instance segmentation.
        run_evaluation_for_iterative_prompting(
            gt_paths=gt_paths,
            prediction_root=prediction_root,
            experiment_folder=output_folder,
            start_with_box_prompt=(prompt_choice == "box"),
            use_masks=use_masks,
        )

    else:
        save_path = os.path.join(output_folder, "results", f"interactive_segmentation_3d_with_{prompt_choice}.csv")
        if os.path.exists(save_path):
            print(
                f"Results for 3d interactive segmentation with '{prompt_choice}' are already stored at '{save_path}'."
            )
            return
        results = []
        for image_path, gt_path in tqdm(
            zip(image_paths, gt_paths), total=len(image_paths),
            desc=f"Run interactive segmentation in 3d with '{prompt_choice}'"
        ):
            prediction_dir = os.path.join(output_folder, "interactive_segmentation_3d", f"{prompt_choice}")
            os.makedirs(prediction_dir, exist_ok=True)

            prediction_path = os.path.join(prediction_dir, os.path.basename(image_path))
            if os.path.exists(prediction_path):
                continue

            per_vol_result = segment_slices_from_ground_truth(
                volume=imageio.imread(image_path),
                ground_truth=imageio.imread(gt_path),
                model_type=model_type,
                checkpoint_path=checkpoint_path,
                save_path=prediction_path,
                device=device,
                interactive_seg_mode=prompt_choice,
                min_size=10,
            )
            results.append(per_vol_result)

        results = pd.concat(results)
        results = results.groupby(results.index).mean()
        results.to_csv(save_path)


def _run_benchmark_evaluation_series(
    image_paths, gt_paths, model_type, output_folder, ndim, device, checkpoint_path, run_amg, evaluation_methods,
):
    seg_kwargs = {
        "image_paths": image_paths,
        "gt_paths": gt_paths,
        "output_folder": output_folder,
        "ndim": ndim,
        "model_type": model_type,
        "device": device,
        "checkpoint_path": checkpoint_path,
    }

    # Perform:
    # a. Automatic segmentation (supported in both 2d and 3d, wherever relevant).
    #    The steps below are configured such that AIS has priority (if a decoder is found);
    #    otherwise, AMG is run. Afterwards, we check whether the user wants to run AMG as well.
    if evaluation_methods != "interactive":  # Skip automatic segmentation for the 'interactive'-only run choice.
        # i. Run the automatic segmentation method supported by the SAM model (AMG or AIS).
        _run_automatic_segmentation_per_dataset(run_amg=None, **seg_kwargs)

        # ii. Run automatic mask generation (AMG).
        # NOTE: This only runs if the user requests it. By default, it is set to 'False'.
        _run_automatic_segmentation_per_dataset(run_amg=run_amg, **seg_kwargs)

    if evaluation_methods != "automatic":  # Skip interactive segmentation for the 'automatic'-only run choice.
        # b. Run interactive segmentation (supported in both 2d and 3d, wherever relevant).
        _run_interactive_segmentation_per_dataset(prompt_choice="box", **seg_kwargs)
        _run_interactive_segmentation_per_dataset(prompt_choice="box", use_masks=True, **seg_kwargs)
        _run_interactive_segmentation_per_dataset(prompt_choice="points", **seg_kwargs)
        _run_interactive_segmentation_per_dataset(prompt_choice="points", use_masks=True, **seg_kwargs)


def _clear_cached_items(retain, path, output_folder):
    import shutil
    from pathlib import Path

    REMOVE_LIST = ["data", "crops", "automatic", "interactive"]
    if retain is None:
        remove_list = REMOVE_LIST
    else:
        assert isinstance(retain, list)
        remove_list = set(REMOVE_LIST) - set(retain)

    paths = []
    # Stage 1: Remove the inputs.
    if "data" in remove_list or "crops" in remove_list:
        all_paths = glob(os.path.join(path, "*"))

        if "data" in remove_list and "crops" in remove_list:
            # In case we want to remove both the data and the crops, we remove the data folder contents entirely.
            paths.extend(all_paths)
        else:
            # Otherwise, we remove only the extracted crops (the 'roi' folders) or only the original data.
            for curr_path in all_paths:
                is_crop_dir = os.path.basename(curr_path).startswith("roi")
                if is_crop_dir and "crops" in remove_list:
                    paths.append(curr_path)
                elif not is_crop_dir and "data" in remove_list:
                    paths.append(curr_path)

    # Stage 2: Remove the predictions.
    if "automatic" in remove_list:
        paths.extend(glob(os.path.join(output_folder, "amg_*")))
        paths.extend(glob(os.path.join(output_folder, "ais_*")))

    if "interactive" in remove_list:
        paths.extend(glob(os.path.join(output_folder, "interactive_segmentation_*")))

    for _path in paths:
        if Path(_path).is_dir():
            shutil.rmtree(_path)
        else:
            os.remove(_path)


def run_benchmark_evaluations(
    input_folder: Union[os.PathLike, str],
    dataset_choice: str,
    model_type: str = util._DEFAULT_MODEL,
    output_folder: Optional[Union[str, os.PathLike]] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    run_amg: bool = False,
    retain: Optional[List[str]] = None,
    evaluation_methods: Literal["all", "automatic", "interactive"] = "all",
    ignore_warnings: bool = False,
):
    """Run evaluation for benchmarking Segment Anything models on microscopy datasets.

    Args:
        input_folder: The path to the directory where all inputs will be stored and preprocessed.
        dataset_choice: The choice of dataset(s) to evaluate on.
        model_type: The model choice for SAM.
        output_folder: The path to the directory where all outputs will be stored.
        checkpoint_path: The path to the model checkpoint.
        run_amg: Whether to run automatic segmentation in AMG mode.
        retain: Which parts of the benchmark runs to retain.
            By default, everything besides the quantitative results is removed.
            There is the choice to retain 'data', 'crops', 'automatic', or 'interactive'.
        evaluation_methods: The choice of evaluation methods.
            By default, runs 'all' evaluation methods (i.e. both 'automatic' and 'interactive').
            Otherwise, specify either 'automatic' or 'interactive' for specific evaluation runs.
        ignore_warnings: Whether to ignore warnings.
    """
    start = time.time()

    with _filter_warnings(ignore_warnings):
        device = util._get_default_device()

        # Ensure that all chosen datasets have been downloaded.
        dataset_choice = _download_benchmark_datasets(path=input_folder, dataset_choice=dataset_choice)

        for choice in dataset_choice:
            dataset_output_folder = os.path.join(output_folder, choice)
            result_dir = os.path.join(dataset_output_folder, "results")
            os.makedirs(result_dir, exist_ok=True)

            data_path = os.path.join(input_folder, choice)

            # Extract the desired evaluation set from the datasets:
            # a. For 2d datasets: 2d patches with the highest number of labels present
            #    (in case of volumetric data, choose 2d patches per slice).
            # b. For 3d datasets: 3d regions of interest with the highest number of labels present.
            ndim = _extract_slices_from_dataset(path=data_path, dataset_choice=choice, crops_per_input=10)

            # Run the inference and evaluation scripts on the benchmark datasets.
            image_paths, gt_paths = _get_image_label_paths(path=data_path, ndim=ndim)
            _run_benchmark_evaluation_series(
                image_paths=image_paths,
                gt_paths=gt_paths,
                model_type=model_type,
                output_folder=dataset_output_folder,
                ndim=ndim,
                device=device,
                checkpoint_path=checkpoint_path,
                run_amg=run_amg,
                evaluation_methods=evaluation_methods,
            )

            # Run the inference and evaluation scripts on 2d crops for volumetric datasets.
            if ndim == 3:
                image_paths, gt_paths = _get_image_label_paths(path=data_path, ndim=2)
                _run_benchmark_evaluation_series(
                    image_paths=image_paths,
                    gt_paths=gt_paths,
                    model_type=model_type,
                    output_folder=dataset_output_folder,
                    ndim=2,
                    device=device,
                    checkpoint_path=checkpoint_path,
                    run_amg=run_amg,
                    evaluation_methods=evaluation_methods,
                )

            _clear_cached_items(retain=retain, path=data_path, output_folder=dataset_output_folder)

    diff = time.time() - start
    hours, rest = divmod(diff, 3600)
    minutes, seconds = divmod(rest, 60)
    print("Time taken for running benchmarks: ", f"{int(hours)}h {int(minutes)}m {int(seconds)}s")


def main():
    """@private"""
    import argparse

    available_models = list(util.get_model_names())
    available_models = ", ".join(available_models)

    parser = argparse.ArgumentParser(
        description="Run evaluation for benchmarking Segment Anything models on microscopy datasets."
    )
    parser.add_argument(
        "-i", "--input_folder", type=str, required=True,
        help="The path to a directory where the microscopy datasets are and/or will be stored."
    )
    parser.add_argument(
        "-m", "--model_type", type=str, default=util._DEFAULT_MODEL,
        help=f"The Segment Anything model that will be used, one of '{available_models}'. "
             f"By default, it uses '{util._DEFAULT_MODEL}'."
    )
    parser.add_argument(
        "-c", "--checkpoint_path", type=str, default=None,
        help="The filepath to the checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "-d", "--dataset_choice", type=str, nargs='*', default=None,
        help="The choice(s) of dataset for evaluating SAM models. Multiple datasets can be specified. "
             "By default, it evaluates on all datasets."
    )
    parser.add_argument(
        "-o", "--output_folder", type=str, required=True,
        help="The path where the results for (automatic and interactive) instance segmentation "
             "will be stored as 'csv' files."
    )
    parser.add_argument(
        "--amg", action="store_true",
        help="Whether to run automatic segmentation in AMG mode (i.e. the default automatic "
             "segmentation approach for SAM)."
    )
    parser.add_argument(
        "--retain", nargs="*", default=None,
        help="By default, everything besides the quantitative results required for the benchmarks is removed. "
             "In case you would like to retain parts of the benchmark evaluation for visualization or "
             "reproducibility, choose one or multiple of 'data', 'crops', 'automatic', 'interactive', "
             "which retain the original inputs, the extracted crops, the predictions of automatic segmentation, "
             "or the predictions of interactive segmentation, respectively."
    )
    parser.add_argument(
        "--evaluate", type=str, default="all", choices=["all", "automatic", "interactive"],
        help="The choice of evaluation methods to run, for reproducibility. "
             "By default, 'all' evaluations are run. If 'automatic' is chosen, only automatic segmentation is run; "
             "'interactive' runs interactive segmentation (starting from a box or a single point) "
             "with iterative prompting."
    )
    args = parser.parse_args()

    run_benchmark_evaluations(
        input_folder=args.input_folder,
        dataset_choice=args.dataset_choice,
        model_type=args.model_type,
        output_folder=args.output_folder,
        checkpoint_path=args.checkpoint_path,
        run_amg=args.amg,
        retain=args.retain,
        evaluation_methods=args.evaluate,
        ignore_warnings=True,
    )


if __name__ == "__main__":
    main()
LM_2D_DATASETS =
['livecell', 'deepbacs', 'tissuenet', 'neurips_cellseg', 'cellpose', 'dynamicnuclearnet', 'orgasegment', 'yeaz', 'arvidsson', 'bitdepth_nucseg', 'cellbindb', 'covid_if', 'deepseas', 'hpa', 'ifnuclei', 'lizard', 'organoidnet', 'toiam', 'vicar']
LM_3D_DATASETS =
['plantseg_root', 'plantseg_ovules', 'gonuclear', 'mouse_embryo', 'cellseg_3d']
EM_2D_DATASETS =
['mitolab_tem']
EM_3D_DATASETS =
['lucchi', 'mitolab_3d', 'uro_cell', 'sponge_em', 'vnc', 'nuc_mm_mouse', 'nuc_mm_zebrafish', 'platynereis_cilia', 'asem_mito']
DATASET_RETURNS_FOLDER =
{'deepbacs': '*.tif', 'mitolab_tem': '*.tiff'}
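
For these two datasets the path getters return folders rather than file lists, so the extraction step collects the files with the stored glob pattern, as in the source above. A minimal sketch of that lookup (the local folder path is hypothetical):

    import os
    from glob import glob

    pattern = DATASET_RETURNS_FOLDER["deepbacs"]  # "*.tif"
    # Hypothetical folder returned by the dataset's path getter.
    image_files = sorted(glob(os.path.join("/data/deepbacs/images", pattern)))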
DATASET_CONTAINER_KEYS =
{'tissuenet': ['raw/rgb', 'labels/cell'], 'covid_if': ['raw/serum_IgG/s0', 'labels/cells/s0'], 'dynamicnuclearnet': ['raw', 'labels'], 'hpa': [['raw/protein', 'raw/microtubules', 'raw/er'], 'labels'], 'lizard': ['image', 'labels/segmentation'], 'plantseg_root': ['raw', 'label'], 'plantseg_ovules': ['raw', 'label'], 'gonuclear': ['raw/nuclei', 'labels/nuclei'], 'mouse_embryo': ['raw', 'label'], 'cellseg_3d': [None, None], 'lucchi': ['raw', 'labels'], 'uro_cell': ['raw', 'labels/mito'], 'mitolab_3d': [None, None], 'sponge_em': ['volumes/raw', 'volumes/labels/instances'], 'vnc': ['raw', 'labels/mitochondria']}
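
Each entry maps a dataset to the internal keys for the raw data and the labels inside its container file, which are passed to micro_sam.util.load_image_data (as done in the extraction loop above). A minimal sketch, assuming a downloaded container file (the filename is hypothetical):

    from micro_sam import util

    raw_key, label_key = DATASET_CONTAINER_KEYS["covid_if"]  # "raw/serum_IgG/s0", "labels/cells/s0"
    image = util.load_image_data("covid_if_volume.h5", raw_key)    # hypothetical container file
    labels = util.load_image_data("covid_if_volume.h5", label_key)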
def
run_benchmark_evaluations( input_folder: Union[os.PathLike, str], dataset_choice: str, model_type: str = 'vit_b_lm', output_folder: Union[os.PathLike, str, NoneType] = None, checkpoint_path: Union[os.PathLike, str, NoneType] = None, run_amg: bool = False, retain: Optional[List[str]] = None, evaluation_methods: Literal['all', 'automatic', 'interactive'] = 'all', ignore_warnings: bool = False):
Run evaluation for benchmarking Segment Anything models on microscopy datasets.
Arguments:
- input_folder: The path to the directory where all inputs will be stored and preprocessed.
- dataset_choice: The choice of dataset(s) to evaluate on.
- model_type: The model choice for SAM.
- output_folder: The path to the directory where all outputs will be stored.
- checkpoint_path: The path to the model checkpoint.
- run_amg: Whether to run automatic segmentation in AMG mode.
- retain: Which parts of the benchmark runs to retain. By default, everything besides the quantitative results is removed. There is the choice to retain 'data', 'crops', 'automatic', or 'interactive'.
- evaluation_methods: The choice of evaluation methods. By default, runs 'all' evaluation methods (i.e. both 'automatic' and 'interactive'). Otherwise, specify either 'automatic' or 'interactive' for specific evaluation runs.
- ignore_warnings: Whether to ignore warnings.
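
A minimal usage sketch, assuming the default model and a single dataset choice (the folder paths are placeholders):

    from micro_sam.evaluation.benchmark_datasets import run_benchmark_evaluations

    run_benchmark_evaluations(
        input_folder="./benchmark_inputs",    # placeholder: where the datasets are downloaded
        dataset_choice="covid_if",            # one of the supported dataset names above
        model_type="vit_b_lm",
        output_folder="./benchmark_results",  # placeholder: where the 'csv' results are written
        retain=["automatic"],                 # keep the automatic segmentation predictions
    )

The same run is available from the command line via the module's main() entry point, with matching arguments.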