micro_sam.evaluation.benchmark_datasets
1import os 2import time 3from glob import glob 4from tqdm import tqdm 5from natsort import natsorted 6from typing import Union, Optional, List, Literal 7 8import numpy as np 9import pandas as pd 10import imageio.v3 as imageio 11from bioimage_cpp.segmentation import label as connected_components 12from bioimage_cpp.utils import Blocking 13 14import torch 15 16from torch_em.data import datasets 17 18from micro_sam import util 19 20from . import run_evaluation 21from ..training.training import _filter_warnings 22from .inference import run_inference_with_iterative_prompting 23from .evaluation import run_evaluation_for_iterative_prompting 24from .multi_dimensional_segmentation import segment_slices_from_ground_truth 25from ..automatic_segmentation import ( 26 automatic_instance_segmentation, get_predictor_and_segmenter, DEFAULT_SEGMENTATION_MODE_WITH_DECODER, 27) 28 29 30LM_2D_DATASETS = [ 31 # in-domain 32 "livecell", # cell segmentation in PhC (has a TEST-set) 33 "deepbacs", # bacteria segmentation in label-free microscopy (has a TEST-set), 34 "tissuenet", # cell segmentation in tissue microscopy images (has a TEST-set), 35 "neurips_cellseg", # cell segmentation in various (has a TEST-set), 36 "cellpose", # cell segmentation in FM (has 'cyto2' on which we can TEST on), 37 "dynamicnuclearnet", # nuclei segmentation in FM (has a TEST-set) 38 "orgasegment", # organoid segmentation in BF (has a TEST-set) 39 "yeaz", # yeast segmentation in BF (has a TEST-set) 40 41 # out-of-domain 42 "arvidsson", # nuclei segmentation in HCS FM 43 "bitdepth_nucseg", # nuclei segmentation in FM 44 "cellbindb", # cell segmentation in various microscopy 45 "covid_if", # cell segmentation in IF 46 "deepseas", # cell segmentation in PhC, 47 "hpa", # cell segmentation in confocal, 48 "ifnuclei", # nuclei segmentation in IFM 49 "lizard", # nuclei segmentation in H&E histopathology, 50 "organoidnet", # organoid segmentation in BF 51 "toiam", # microbial cell segmentation in PhC 52 "vicar", # cell 
segmentation in label-free 53] 54 55LM_3D_DATASETS = [ 56 # in-domain 57 "plantseg_root", # cell segmentation in lightsheet (has a TEST-set) 58 59 # out-of-domain 60 "plantseg_ovules", # cell segmentation in confocal 61 "gonuclear", # nuclei segmentation in FM 62 "mouse_embryo", # cell segmentation in lightsheet 63 "cellseg3d", # nuclei segmentation in FM 64] 65 66EM_2D_DATASETS = ["mitolab_tem"] 67 68EM_3D_DATASETS = [ 69 # out-of-domain 70 "lucchi", # mitochondria segmentation in vEM 71 "mitolab", # mitochondria segmentation in various 72 "uro_cell", # mitochondria segmentation (and other organelles) in FIB-SEM 73 "sponge_em", # microvili segmentation (and other organelles) in sponge chamber vEM 74 "vnc", # mitochondria segmentation in drosophila brain TEM 75 "nuc_mm_mouse", # nuclei segmentation in microCT 76 "num_mm_zebrafish", # nuclei segmentation in EM 77 "platynereis_cilia", # cilia segmentation (and other structures) in platynereis larvae vEM 78 "asem_mito", # mitochondria segmentation (and other organelles) in FIB-SEM 79] 80 81DATASET_RETURNS_FOLDER = { 82 "deepbacs": "*.tif", 83 "mitolab_tem": "*.tiff" 84} 85 86DATASET_CONTAINER_KEYS = { 87 # 2d (LM) 88 "tissuenet": ["raw/rgb", "labels/cell"], 89 "covid_if": ["raw/serum_IgG/s0", "labels/cells/s0"], 90 "dynamicnuclearnet": ["raw", "labels"], 91 "hpa": [["raw/protein", "raw/microtubules", "raw/er"], "labels"], 92 "lizard": ["image", "labels/segmentation"], 93 94 # 3d (LM) 95 "plantseg_root": ["raw", "label"], 96 "plantseg_ovules": ["raw", "label"], 97 "gonuclear": ["raw/nuclei", "labels/nuclei"], 98 "mouse_embryo": ["raw", "label"], 99 "cellseg_3d": [None, None], 100 101 # 3d (EM) 102 "lucchi": ["raw", "labels"], 103 "uro_cell": ["raw", "labels/mito"], 104 "mitolab_3d": [None, None], 105 "sponge_em": ["volumes/raw", "volumes/labels/instances"], 106 "vnc": ["raw", "labels/mitochondria"] 107} 108 109 110def _download_benchmark_datasets(path, dataset_choice): 111 """Ensures whether all the datasets have been 
downloaded or not. 112 113 Args: 114 path: The path to directory where the supported datasets will be downloaded 115 for benchmarking Segment Anything models. 116 dataset_choice: The choice of dataset, expects the lower case name for the dataset. 117 118 Returns: 119 List of choice of dataset(s). 120 """ 121 available_datasets = { 122 # Light Microscopy datasets 123 124 # 2d datasets: in-domain 125 "livecell": lambda: datasets.livecell.get_livecell_data( 126 path=os.path.join(path, "livecell"), download=True, 127 ), 128 "deepbacs": lambda: datasets.deepbacs.get_deepbacs_data( 129 path=os.path.join(path, "deepbacs"), bac_type="mixed", download=True, 130 ), 131 "tissuenet": lambda: datasets.tissuenet.get_tissuenet_data( 132 path=os.path.join(path, "tissuenet"), split="test", download=True, 133 ), 134 "neurips_cellseg": lambda: datasets.neurips_cell_seg.get_neurips_cellseg_data( 135 root=os.path.join(path, "neurips_cellseg"), split="test", download=True, 136 ), 137 "cellpose": lambda: datasets.cellpose.get_cellpose_data( 138 path=os.path.join(path, "cellpose"), split="train", choice="cyto2", download=True, 139 ), 140 "dynamicnuclearnet": lambda: datasets.dynamicnuclearnet.get_dynamicnuclearnet_data( 141 path=os.path.join(path, "dynamicnuclearnet"), split="test", download=True, 142 ), 143 "orgasegment": lambda: datasets.orgasegment.get_orgasegment_data( 144 path=os.path.join(path, "orgasegment"), split="eval", download=True, 145 ), 146 "yeaz": lambda: datasets.yeaz.get_yeaz_data( 147 path=os.path.join(path, "yeaz"), choice="bf", download=True, 148 ), 149 150 # 2d datasets: out-of-domain 151 "arvidsson": lambda: datasets.arvidsson.get_arvidsson_data( 152 path=os.path.join(path, "arvidsson"), split="test", download=True, 153 ), 154 "bitdepth_nucseg": lambda: datasets.bitdepth_nucseg.get_bitdepth_nucseg_data( 155 path=os.path.join(path, "bitdepth_nucseg"), download=True, 156 ), 157 "cellbindb": lambda: datasets.cellbindb.get_cellbindb_data( 158 path=os.path.join(path, 
"cellbindb"), download=True, 159 ), 160 "covid_if": lambda: datasets.covid_if.get_covid_if_data( 161 path=os.path.join(path, "covid_if"), download=True, 162 ), 163 "deepseas": lambda: datasets.deepseas.get_deepseas_data( 164 path=os.path.join(path, "deepseas"), split="test", 165 ), 166 "hpa": lambda: datasets.hpa.get_hpa_segmentation_data( 167 path=os.path.join(path, "hpa"), download=True, 168 ), 169 "ifnuclei": lambda: datasets.ifnuclei.get_ifnuclei_data( 170 path=os.path.join(path, "ifnuclei"), download=True, 171 ), 172 "lizard": lambda: datasets.lizard.get_lizard_data( 173 path=os.path.join(path, "lizard"), split="test", download=True, 174 ), 175 "organoidnet": lambda: datasets.organoidnet.get_organoidnet_data( 176 path=os.path.join(path, "organoidnet"), split="Test", download=True, 177 ), 178 "toiam": lambda: datasets.toiam.get_toiam_data( 179 path=os.path.join(path, "toiam"), download=True, 180 ), 181 "vicar": lambda: datasets.vicar.get_vicar_data( 182 path=os.path.join(path, "vicar"), download=True, 183 ), 184 185 # 3d datasets: in-domain 186 "plantseg_root": lambda: datasets.plantseg.get_plantseg_data( 187 path=os.path.join(path, "plantseg_root"), split="test", download=True, name="root", 188 ), 189 190 # 3d datasets: out-of-domain 191 "plantseg_ovules": lambda: datasets.plantseg.get_plantseg_data( 192 path=os.path.join(path, "plantseg_ovules"), split="test", download=True, name="ovules", 193 ), 194 "gonuclear": lambda: datasets.gonuclear.get_gonuclear_data( 195 path=os.path.join(path, "gonuclear"), download=True, 196 ), 197 "mouse_embryo": lambda: datasets.mouse_embryo.get_mouse_embryo_data( 198 path=os.path.join(path, "mouse_embryo"), download=True, 199 ), 200 "cellseg_3d": lambda: datasets.cellseg_3d.get_cellseg_3d_data( 201 path=os.path.join(path, "cellseg_3d"), download=True, 202 ), 203 204 # Electron Microscopy datasets 205 206 # 2d datasets: out-of-domain 207 "mitolab_tem": lambda: datasets.cem.get_benchmark_data( 208 path=os.path.join(path, 
"mitolab"), dataset_id=7, download=True 209 ), 210 211 # 3d datasets: out-of-domain 212 "lucchi": lambda: datasets.lucchi.get_lucchi_data( 213 path=os.path.join(path, "lucchi"), split="test", download=True, 214 ), 215 "mitolab_3d": lambda: [ 216 datasets.cem.get_benchmark_data( 217 path=os.path.join(path, "mitolab"), dataset_id=dataset_id, download=True, 218 ) for dataset_id in range(1, 7) 219 ], 220 "uro_cell": lambda: datasets.uro_cell.get_uro_cell_data( 221 path=os.path.join(path, "uro_cell"), download=True, 222 ), 223 "vnc": lambda: datasets.vnc.get_vnc_data( 224 path=os.path.join(path, "vnc"), download=True, 225 ), 226 "sponge_em": lambda: datasets.sponge_em.get_sponge_em_data( 227 path=os.path.join(path, "sponge_em"), download=True, 228 ), 229 "nuc_mm_mouse": lambda: datasets.nuc_mm.get_nuc_mm_data( 230 path=os.path.join(path, "nuc_mm"), sample="mouse", download=True, 231 ), 232 "nuc_mm_zebrafish": lambda: datasets.nuc_mm.get_nuc_mm_data( 233 path=os.path.join(path, "nuc_mm"), sample="zebrafish", download=True, 234 ), 235 "asem_mito": lambda: datasets.asem.get_asem_data( 236 path=os.path.join(path, "asem"), volume_ids=datasets.asem.ORGANELLES["mito"], download=True, 237 ), 238 "platynereis_cilia": lambda: datasets.platynereis.get_platynereis_data( 239 path=os.path.join(path, "platynereis"), name="cilia", download=True, 240 ), 241 } 242 243 if dataset_choice is None: 244 dataset_choice = available_datasets.keys() 245 else: 246 if not isinstance(dataset_choice, list): 247 dataset_choice = [dataset_choice] 248 249 for choice in dataset_choice: 250 if choice in available_datasets: 251 available_datasets[choice]() 252 else: 253 raise ValueError(f"'{choice}' is not a supported choice of dataset.") 254 255 return dataset_choice 256 257 258def _extract_slices_from_dataset(path, dataset_choice, crops_per_input=10): 259 """Extracts crops of desired shapes for performing evaluation in both 2d and 3d using `micro-sam`. 
260 261 Args: 262 path: The path to directory where the supported datasets have be downloaded 263 for benchmarking Segment Anything models. 264 dataset_choice: The name of the dataset of choice to extract crops. 265 crops_per_input: The maximum number of crops to extract per inputs. 266 extract_2d: Whether to extract 2d crops from 3d patches. 267 268 Returns: 269 Filepath to the folder where extracted images are stored. 270 Filepath to the folder where corresponding extracted labels are stored. 271 The number of dimensions supported by the input. 272 """ 273 ndim = 2 if dataset_choice in [*LM_2D_DATASETS, *EM_2D_DATASETS] else 3 274 tile_shape = (512, 512) if ndim == 2 else (32, 512, 512) 275 276 # For 3d inputs, we extract both 2d and 3d crops. 277 extract_2d_crops_from_volumes = (ndim == 3) 278 279 available_datasets = { 280 # Light Microscopy datasets 281 282 # 2d: in-domain 283 "livecell": lambda: datasets.livecell.get_livecell_paths(path=path, split="test"), 284 "deepbacs": lambda: datasets.deepbacs.get_deepbacs_paths(path=path, split="test", bac_type="mixed"), 285 "tissuenet": lambda: datasets.tissuenet.get_tissuenet_paths(path=path, split="test"), 286 "neurips_cellseg": lambda: datasets.neurips_cell_seg.get_neurips_cellseg_paths(root=path, split="test"), 287 "cellpose": lambda: datasets.cellpose.get_cellpose_paths(path=path, split="train", choice="cyto2"), 288 "dynamicnuclearnet": lambda: datasets.dynamicnuclearnet.get_dynamicnuclearnet_paths(path=path, split="test"), 289 "orgasegment": lambda: datasets.orgasegment.get_orgasegment_paths(path=path, split="eval"), 290 "yeaz": lambda: datasets.yeaz.get_yeaz_paths(path=path, choice="bf", split="test"), 291 292 # 2d: out-of-domain 293 "arvidsson": lambda: datasets.arvidsson.get_arvidsson_paths(path=path, split="test"), 294 "bitdepth_nucseg": lambda: datasets.bitdepth_nucseg.get_bitdepth_nucseg_paths(path=path, magnification="20x"), 295 "cellbindb": lambda: datasets.cellbindb.get_cellbindb_paths( 296 path=path, 
data_choice=["10×Genomics_DAPI", "DAPI", "mIF"] 297 ), 298 "covid_if": lambda: datasets.covid_if.get_covid_if_paths(path=path), 299 "deepseas": lambda: datasets.deepseas.get_deepseas_paths(path=path, split="test"), 300 "hpa": lambda: datasets.hpa.get_hpa_segmentation_paths(path=path, split="val"), 301 "ifnuclei": lambda: datasets.ifnuclei.get_ifnuclei_paths(path=path), 302 "lizard": lambda: datasets.lizard.get_lizard_paths(path=path, split="test"), 303 "organoidnet": lambda: datasets.organoidnet.get_organoidnet_paths(path=path, split="Test"), 304 "toiam": lambda: datasets.toiam.get_toiam_paths(path=path), 305 "vicar": lambda: datasets.vicar.get_vicar_paths(path=path), 306 307 # 3d: in-domain 308 "plantseg_root": lambda: datasets.plantseg.get_plantseg_paths(path=path, name="root", split="test"), 309 310 # 3d: out-of-domain 311 "plantseg_ovules": lambda: datasets.plantseg.get_plantseg_paths(path=path, name="ovules", split="test"), 312 "gonuclear": lambda: datasets.gonuclear.get_gonuclear_paths(path=path), 313 "mouse_embryo": lambda: datasets.mouse_embryo.get_mouse_embryo_paths(path=path, name="nuclei", split="val"), 314 "cellseg_3d": lambda: datasets.cellseg_3d.get_cellseg_3d_paths(path=path), 315 316 # Electron Microscopy datasets 317 318 # 2d: out-of-domain 319 "mitolab_tem": lambda: datasets.cem.get_benchmark_paths( 320 path=os.path.join(os.path.dirname(path), "mitolab"), dataset_id=7 321 )[:2], 322 323 # 3d: out-of-domain"lucchi": lambda: datasets.lucchi.get_lucchi_paths(path=path, split="test"), 324 "platynereis_cilia": lambda: datasets.platynereis.get_platynereis_paths(path, sample_ids=None, name="cilia"), 325 "uro_cell": lambda: datasets.uro_cell.get_uro_cell_paths(path=path, target="mito"), 326 "vnc": lambda: datasets.vnc.get_vnc_mito_paths(path=path), 327 "sponge_em": lambda: datasets.sponge_em.get_sponge_em_paths(path=path, sample_ids=None), 328 "mitolab_3d": lambda: ( 329 [ 330 datasets.cem.get_benchmark_paths( 331 path=os.path.join(os.path.dirname(path), 
"mitolab"), dataset_id=i 332 )[0] for i in range(1, 7) 333 ], 334 [ 335 datasets.cem.get_benchmark_paths( 336 path=os.path.join(os.path.dirname(path), "mitolab"), dataset_id=i 337 )[1] for i in range(1, 7) 338 ] 339 ), 340 "nuc_mm_mouse": lambda: datasets.nuc_mm.get_nuc_mm_paths(path=path, sample="mouse", split="val"), 341 "nuc_mm_zebrafish": lambda: datasets.nuc_mm.get_nuc_mm_paths(path=path, sample="zebrafish", split="val"), 342 "asem_mito": lambda: datasets.asem.get_asem_paths(path=path, volume_ids=datasets.asem.ORGANELLES["mito"]) 343 } 344 345 if (ndim == 2 and dataset_choice not in DATASET_CONTAINER_KEYS) or dataset_choice in ["cellseg_3d", "mitolab_3d"]: 346 image_paths, gt_paths = available_datasets[dataset_choice]() 347 348 if dataset_choice in DATASET_RETURNS_FOLDER: 349 image_paths = glob(os.path.join(image_paths, DATASET_RETURNS_FOLDER[dataset_choice])) 350 gt_paths = glob(os.path.join(gt_paths, DATASET_RETURNS_FOLDER[dataset_choice])) 351 352 image_paths, gt_paths = natsorted(image_paths), natsorted(gt_paths) 353 assert len(image_paths) == len(gt_paths) 354 355 paths_set = zip(image_paths, gt_paths) 356 357 else: 358 image_paths = available_datasets[dataset_choice]() 359 if isinstance(image_paths, str): 360 paths_set = [image_paths] 361 else: 362 paths_set = natsorted(image_paths) 363 364 # Directory where we store the extracted ROIs. 
365 save_image_dir = [os.path.join(path, f"roi_{ndim}d", "inputs")] 366 save_gt_dir = [os.path.join(path, f"roi_{ndim}d", "labels")] 367 if extract_2d_crops_from_volumes: 368 save_image_dir.append(os.path.join(path, "roi_2d", "inputs")) 369 save_gt_dir.append(os.path.join(path, "roi_2d", "labels")) 370 371 _dir_exists = [os.path.exists(idir) and os.path.exists(gdir) for idir, gdir in zip(save_image_dir, save_gt_dir)] 372 if all(_dir_exists): 373 return ndim 374 375 [os.makedirs(idir, exist_ok=True) for idir in save_image_dir] 376 [os.makedirs(gdir, exist_ok=True) for gdir in save_gt_dir] 377 378 # Logic to extract relevant patches for inference 379 image_counter = 1 380 for per_paths in tqdm(paths_set, desc=f"Extracting {ndim}d patches for {dataset_choice}"): 381 if (ndim == 2 and dataset_choice not in DATASET_CONTAINER_KEYS) or dataset_choice in ["cellseg_3d", "mitolab_3d"]: # noqa 382 image_path, gt_path = per_paths 383 image, gt = util.load_image_data(image_path), util.load_image_data(gt_path) 384 385 else: 386 image_path = per_paths 387 gt = util.load_image_data(image_path, DATASET_CONTAINER_KEYS[dataset_choice][1]) 388 if dataset_choice == "hpa": 389 # Get inputs per channel and stack them together to make the desired 3 channel image. 390 image = np.stack( 391 [util.load_image_data(image_path, k) for k in DATASET_CONTAINER_KEYS[dataset_choice][0]], axis=0, 392 ) 393 # Resize inputs to desired tile shape, in favor of working with the shape of foreground. 
394 from torch_em.transform.generic import ResizeLongestSideInputs 395 raw_transform = ResizeLongestSideInputs(target_shape=tile_shape, is_rgb=True) 396 label_transform = ResizeLongestSideInputs(target_shape=tile_shape, is_label=True) 397 image, gt = raw_transform(image).transpose(1, 2, 0), label_transform(gt) 398 399 else: 400 image = util.load_image_data(image_path, DATASET_CONTAINER_KEYS[dataset_choice][0]) 401 402 if dataset_choice in ["tissuenet", "lizard"]: 403 if image.ndim == 3 and image.shape[0] == 3: # Make channels last for tissuenet RGB-style images. 404 image = image.transpose(1, 2, 0) 405 406 # Allow RGBs to stay as it is with channels last 407 if image.ndim == 3 and image.shape[-1] == 3: 408 skip_smaller_shape = (np.array(image.shape) >= np.array((*tile_shape, 3))).all() 409 else: 410 skip_smaller_shape = (np.array(image.shape) >= np.array(tile_shape)).all() 411 412 # Ensure ground truth has instance labels. 413 gt = connected_components(gt) 414 415 if len(np.unique(gt)) == 1: # There could be labels which does not have any annotated foreground. 416 continue 417 418 # Let's extract and save all the crops. 419 # The first round of extraction is always to match the desired input dimensions. 420 image_crops, gt_crops = _get_crops_for_input(image, gt, ndim, tile_shape, skip_smaller_shape, crops_per_input) 421 image_counter = _save_image_label_crops( 422 image_crops, gt_crops, dataset_choice, ndim, image_counter, save_image_dir[0], save_gt_dir[0] 423 ) 424 425 # The next round of extraction is to get 2d crops from 3d inputs. 426 if extract_2d_crops_from_volumes: 427 curr_tile_shape = tile_shape[1:] # We expect 2d tile shape for this stage. 
428 429 curr_image_crops, curr_gt_crops = [], [] 430 for per_z_im, per_z_gt in zip(image, gt): 431 curr_skip_smaller_shape = (np.array(per_z_im.shape) >= np.array(curr_tile_shape)).all() 432 433 image_crops, gt_crops = _get_crops_for_input( 434 image=per_z_im, gt=per_z_gt, ndim=2, 435 tile_shape=curr_tile_shape, 436 skip_smaller_shape=curr_skip_smaller_shape, 437 crops_per_input=crops_per_input, 438 ) 439 curr_image_crops.extend(image_crops) 440 curr_gt_crops.extend(gt_crops) 441 442 image_counter = _save_image_label_crops( 443 curr_image_crops, curr_gt_crops, dataset_choice, 2, image_counter, save_image_dir[1], save_gt_dir[1] 444 ) 445 446 return ndim 447 448 449def _get_crops_for_input(image, gt, ndim, tile_shape, skip_smaller_shape, crops_per_input): 450 tiling = Blocking([0] * ndim, gt.shape, tile_shape) 451 n_tiles = tiling.number_of_blocks 452 tiles = [tiling.get_block(tile_id) for tile_id in range(n_tiles)] 453 crop_boxes = [ 454 tuple(slice(beg, end) for beg, end in zip(tile.begin, tile.end)) for tile in tiles 455 ] 456 n_ids = [idx for idx in range(len(crop_boxes))] 457 n_instances = [len(np.unique(gt[crop])) for crop in crop_boxes] 458 459 # Extract the desired number of patches with higher number of instances. 460 image_crops, gt_crops = [], [] 461 for i, (per_n_instance, per_id) in enumerate(sorted(zip(n_instances, n_ids), reverse=True), start=1): 462 crop_box = crop_boxes[per_id] 463 crop_image, crop_gt = image[crop_box], gt[crop_box] 464 465 # NOTE: We avoid using the crops which do not match the desired tile shape. 466 _rtile_shape = (*tile_shape, 3) if image.ndim == 3 and image.shape[-1] == 3 else tile_shape # For RGB images. 467 if skip_smaller_shape and crop_image.shape != _rtile_shape: 468 continue 469 470 # NOTE: There could be a case where some later patches are invalid. 
471 if per_n_instance == 1: 472 break 473 474 image_crops.append(crop_image) 475 gt_crops.append(crop_gt) 476 477 # NOTE: If the number of patches extracted have been fulfiled, we stop sampling patches. 478 if len(image_crops) > 0 and i >= crops_per_input: 479 break 480 481 return image_crops, gt_crops 482 483 484def _save_image_label_crops(image_crops, gt_crops, dataset_choice, ndim, image_counter, save_image_dir, save_gt_dir): 485 for image_crop, gt_crop in tqdm( 486 zip(image_crops, gt_crops), total=len(image_crops), desc=f"Saving {ndim}d crops for {dataset_choice}" 487 ): 488 fname = f"{dataset_choice}_{image_counter:05}.tif" 489 490 if image_crop.ndim == 3 and image_crop.shape[-1] == 3: 491 assert image_crop.shape[:2] == gt_crop.shape 492 else: 493 assert image_crop.shape == gt_crop.shape 494 495 imageio.imwrite(os.path.join(save_image_dir, fname), image_crop, compression="zlib") 496 imageio.imwrite(os.path.join(save_gt_dir, fname), gt_crop, compression="zlib") 497 498 image_counter += 1 499 500 return image_counter 501 502 503def _get_image_label_paths(path, ndim): 504 image_paths = natsorted(glob(os.path.join(path, f"roi_{ndim}d", "inputs", "*"))) 505 gt_paths = natsorted(glob(os.path.join(path, f"roi_{ndim}d", "labels", "*"))) 506 return image_paths, gt_paths 507 508 509def _run_automatic_segmentation_per_dataset( 510 image_paths: List[Union[os.PathLike, str]], 511 gt_paths: List[Union[os.PathLike, str]], 512 model_type: str, 513 output_folder: Union[os.PathLike, str], 514 ndim: Optional[int] = None, 515 device: Optional[Union[torch.device, str]] = None, 516 checkpoint_path: Optional[Union[os.PathLike, str]] = None, 517 segmentation_mode: Optional[Literal["amg", "ais", "apg"]] = "ais", 518 **auto_seg_kwargs 519): 520 """Functionality to run automatic segmentation for multiple input files at once. 521 It stores the evaluated automatic segmentation results (quantitative). 522 523 Args: 524 image_paths: List of filepaths for the input image data. 
525 gt_paths: List of filepaths for the corresponding label data. 526 model_type: The choice of image encoder for the Segment Anything model. 527 output_folder: Filepath to the folder where we store all the results. 528 ndim: The number of input dimensions. 529 device: The torch device. 530 checkpoint_path: The filepath where the model checkpoints are stored. 531 segmentation_mode: The mode for automatic segmentation. 532 auto_seg_kwargs: Additional arguments for automatic segmentation parameters. 533 """ 534 if segmentation_mode is None: # The 2nd condition checks if you want AIS and if decoder state exists or not. 535 _, state = util.get_sam_model( 536 model_type=model_type, checkpoint_path=checkpoint_path, device=device, return_state=True 537 ) 538 segmentation_mode = DEFAULT_SEGMENTATION_MODE_WITH_DECODER if "decoder_state" in state else "amg" 539 540 fname = f"{segmentation_mode}_{ndim}d" 541 542 result_path = os.path.join(output_folder, "results", f"{fname}.csv") 543 if os.path.exists(result_path): 544 return 545 546 prediction_dir = os.path.join(output_folder, fname, "inference") 547 os.makedirs(prediction_dir, exist_ok=True) 548 549 # Get the predictor (and the additional instance segmentation decoder, if available). 
550 predictor, segmenter = get_predictor_and_segmenter( 551 model_type=model_type, checkpoint=checkpoint_path, device=device, 552 segmentation_mode=segmentation_mode, is_tiled=False, 553 ) 554 555 for image_path in tqdm(image_paths, desc=f"Run {segmentation_mode} in {ndim}d"): 556 output_path = os.path.join(prediction_dir, os.path.basename(image_path)) 557 if os.path.exists(output_path): 558 continue 559 560 # Run Automatic Segmentation (AMG and AIS) 561 automatic_instance_segmentation( 562 predictor=predictor, 563 segmenter=segmenter, 564 input_path=image_path, 565 output_path=output_path, 566 ndim=ndim, 567 verbose=False, 568 **auto_seg_kwargs 569 ) 570 571 prediction_paths = natsorted(glob(os.path.join(prediction_dir, "*"))) 572 run_evaluation(gt_paths=gt_paths, prediction_paths=prediction_paths, save_path=result_path) 573 574 575def _run_interactive_segmentation_per_dataset( 576 image_paths: List[Union[os.PathLike, str]], 577 gt_paths: List[Union[os.PathLike, str]], 578 output_folder: Union[os.PathLike, str], 579 model_type: str, 580 prompt_choice: Literal["box", "points"], 581 device: Optional[Union[torch.device, str]] = None, 582 ndim: Optional[int] = None, 583 checkpoint_path: Optional[Union[os.PathLike, str]] = None, 584 use_masks: bool = False, 585): 586 """Functionality to run interactive segmentation for multiple input files at once. 587 It stores the evaluated interactive segmentation results. 588 589 Args: 590 image_paths: List of filepaths for the input image data. 591 gt_paths: List of filepaths for the corresponding label data. 592 output_folder: Filepath to the folder where we store all the results. 593 model_type: The choice of model type for Segment Anything. 594 prompt_choice: The choice of initial prompts to begin the interactive segmentation. 595 device: The torch device. 596 ndim: The number of input dimensions. 597 checkpoint_path: The filepath for stored checkpoints. 598 use_masks: Whether to use masks for iterative prompting. 
599 """ 600 if ndim == 2: 601 # Get the Segment Anything predictor. 602 predictor = util.get_sam_model(model_type=model_type, device=device, checkpoint_path=checkpoint_path) 603 604 prediction_root = os.path.join( 605 output_folder, "interactive_segmentation_2d", f"start_with_{prompt_choice}", 606 "iterative_prompting_" + ("with_masks" if use_masks else "without_masks") 607 ) 608 609 # Run interactive instance segmentation 610 # (starting with box and points followed by iterative prompt-based correction) 611 run_inference_with_iterative_prompting( 612 predictor=predictor, 613 image_paths=image_paths, 614 gt_paths=gt_paths, 615 embedding_dir=None, # We set this to None to compute embeddings on-the-fly. 616 prediction_dir=prediction_root, 617 start_with_box_prompt=(prompt_choice == "box"), 618 use_masks=use_masks, 619 # TODO: add parameter for deform over box prompts (to simulate prompts in practice). 620 ) 621 622 # Evaluate the interactive instance segmentation. 623 run_evaluation_for_iterative_prompting( 624 gt_paths=gt_paths, 625 prediction_root=prediction_root, 626 experiment_folder=output_folder, 627 start_with_box_prompt=(prompt_choice == "box"), 628 use_masks=use_masks, 629 ) 630 631 else: 632 save_path = os.path.join(output_folder, "results", f"interactive_segmentation_3d_with_{prompt_choice}.csv") 633 if os.path.exists(save_path): 634 print( 635 f"Results for 3d interactive segmentation with '{prompt_choice}' are already stored at '{save_path}'." 
636 ) 637 return 638 639 results = [] 640 for image_path, gt_path in tqdm( 641 zip(image_paths, gt_paths), total=len(image_paths), 642 desc=f"Run interactive segmentation in 3d with '{prompt_choice}'" 643 ): 644 prediction_dir = os.path.join(output_folder, "interactive_segmentation_3d", f"{prompt_choice}") 645 os.makedirs(prediction_dir, exist_ok=True) 646 647 prediction_path = os.path.join(prediction_dir, os.path.basename(image_path)) 648 if os.path.exists(prediction_path): 649 continue 650 651 per_vol_result = segment_slices_from_ground_truth( 652 volume=imageio.imread(image_path), 653 ground_truth=imageio.imread(gt_path), 654 model_type=model_type, 655 checkpoint_path=checkpoint_path, 656 save_path=prediction_path, 657 device=device, 658 interactive_seg_mode=prompt_choice, 659 min_size=10, 660 ) 661 results.append(per_vol_result) 662 663 results = pd.concat(results) 664 results = results.groupby(results.index).mean() 665 results.to_csv(save_path) 666 667 668def _run_benchmark_evaluation_series( 669 image_paths, gt_paths, model_type, output_folder, ndim, device, checkpoint_path, 670 segmentation_mode, evaluation_methods, 671): 672 seg_kwargs = { 673 "image_paths": image_paths, 674 "gt_paths": gt_paths, 675 "output_folder": output_folder, 676 "ndim": ndim, 677 "model_type": model_type, 678 "device": device, 679 "checkpoint_path": checkpoint_path, 680 } 681 682 # Perform: 683 # a. automatic segmentation (supported in both 2d and 3d, wherever relevant) 684 # The automatic segmentation steps below are configured in a way that AIS has priority (if decoder is found) 685 # Otherwise, it runs for AMG. 686 # Next, we check if the user expects to run AMG as well (after the run for AIS). 687 688 if evaluation_methods != "interactive": # Avoid auto. seg. evaluation for 'interactive'-only run choice. 689 # i. Run automatic segmentation method supported with the SAM model (AMG or AIS). 
690 _run_automatic_segmentation_per_dataset(segmentation_mode=None, **seg_kwargs) 691 692 # ii. Run automatic mask generation (AMG). 693 # NOTE: This would only run if the user wants to. Else by default, it is set to 'False'. 694 _run_automatic_segmentation_per_dataset(segmentation_mode=segmentation_mode, **seg_kwargs) 695 696 if evaluation_methods != "automatic": # Avoid int. seg. evaluation for 'automatic'-only run choice. 697 # b. Run interactive segmentation (supported in both 2d and 3d, wherever relevant) 698 _run_interactive_segmentation_per_dataset(prompt_choice="box", **seg_kwargs) 699 _run_interactive_segmentation_per_dataset(prompt_choice="box", use_masks=True, **seg_kwargs) 700 _run_interactive_segmentation_per_dataset(prompt_choice="points", **seg_kwargs) 701 _run_interactive_segmentation_per_dataset(prompt_choice="points", use_masks=True, **seg_kwargs) 702 703 704def _clear_cached_items(retain, path, output_folder): 705 import shutil 706 from pathlib import Path 707 708 REMOVE_LIST = ["data", "crops", "automatic", "interactive"] 709 if retain is None: 710 remove_list = REMOVE_LIST 711 else: 712 assert isinstance(retain, list) 713 remove_list = set(REMOVE_LIST) - set(retain) 714 715 paths = [] 716 # Stage 1: Remove inputs. 717 if "data" in remove_list or "crops" in remove_list: 718 all_paths = glob(os.path.join(path, "*")) 719 720 # In case we want to remove both data and crops, we remove the data folder entirely. 721 if "data" in remove_list and "crops" in remove_list: 722 paths.extend(all_paths) 723 return 724 725 # Next, we verify whether the we only remove either of data or crops. 
726 for curr_path in all_paths: 727 if os.path.basename(curr_path).startswith("roi") and "crops" in remove_list: 728 paths.append(curr_path) 729 elif "data" in remove_list: 730 paths.append(curr_path) 731 732 # Stage 2: Remove predictions 733 if "automatic" in remove_list: 734 paths.extend(glob(os.path.join(output_folder, "amg_*"))) 735 paths.extend(glob(os.path.join(output_folder, "ais_*"))) 736 737 if "interactive" in remove_list: 738 paths.extend(glob(os.path.join(output_folder, "interactive_segmentation_*"))) 739 740 [shutil.rmtree(_path) if Path(_path).is_dir() else os.remove(_path) for _path in paths] 741 742 743def run_benchmark_evaluations( 744 input_folder: Union[os.PathLike, str], 745 dataset_choice: str, 746 model_type: str = util._DEFAULT_MODEL, 747 output_folder: Optional[Union[str, os.PathLike]] = None, 748 checkpoint_path: Optional[Union[str, os.PathLike]] = None, 749 segmentation_mode: Optional[Literal["amg", "ais", "apg"]] = None, 750 retain: Optional[List[str]] = None, 751 evaluation_methods: Literal["all", "automatic", "interactive"] = "all", 752 ignore_warnings: bool = False, 753): 754 """Run evaluation for benchmarking Segment Anything models on microscopy datasets. 755 756 Args: 757 input_folder: The path to directory where all inputs will be stored and preprocessed. 758 dataset_choice: The dataset choice. 759 model_type: The model choice for SAM. 760 output_folder: The path to directory where all outputs will be stored. 761 checkpoint_path: The checkpoint path 762 segmentation_mode: The segmentation mode. One of 'amg', 'ais', or 'apg'. 763 retain: Whether to retain certain parts of the benchmark runs. 764 By default, removes everything besides quantitative results. 765 There is the choice to retain 'data', 'crops', 'automatic', or 'interactive'. 766 evaluation_methods: The choice of evaluation methods. 767 By default, runs 'all' evaluation methods (i.e. both 'automatic' or 'interactive'). 
768 Otherwise, specify either 'automatic' / 'interactive' for specific evaluation runs. 769 ignore_warnings: Whether to ignore warnings. 770 """ 771 start = time.time() 772 773 with _filter_warnings(ignore_warnings): 774 device = util._get_default_device() 775 776 # Ensure if all the datasets have been installed by default. 777 dataset_choice = _download_benchmark_datasets(path=input_folder, dataset_choice=dataset_choice) 778 779 for choice in dataset_choice: 780 output_folder = os.path.join(output_folder, choice) 781 result_dir = os.path.join(output_folder, "results") 782 os.makedirs(result_dir, exist_ok=True) 783 784 data_path = os.path.join(input_folder, choice) 785 786 # Extrapolate desired set from the datasets: 787 # a. for 2d datasets - 2d patches with the most number of labels present 788 # (in case of volumetric data, choose 2d patches per slice). 789 # b. for 3d datasets - 3d regions of interest with the most number of labels present. 790 ndim = _extract_slices_from_dataset(path=data_path, dataset_choice=choice, crops_per_input=10) 791 792 # Run inference and evaluation scripts on benchmark datasets. 
793 image_paths, gt_paths = _get_image_label_paths(path=data_path, ndim=ndim) 794 _run_benchmark_evaluation_series( 795 image_paths=image_paths, 796 gt_paths=gt_paths, 797 model_type=model_type, 798 output_folder=output_folder, 799 ndim=ndim, 800 device=device, 801 checkpoint_path=checkpoint_path, 802 segmentation_mode=segmentation_mode, 803 evaluation_methods=evaluation_methods, 804 ) 805 806 # Run inference and evaluation scripts on '2d' crops for volumetric datasets 807 if ndim == 3: 808 image_paths, gt_paths = _get_image_label_paths(path=data_path, ndim=2) 809 _run_benchmark_evaluation_series( 810 image_paths=image_paths, 811 gt_paths=gt_paths, 812 model_type=model_type, 813 output_folder=output_folder, 814 ndim=2, 815 device=device, 816 checkpoint_path=checkpoint_path, 817 segmentation_mode=segmentation_mode, 818 evaluation_methods=evaluation_methods, 819 ) 820 821 _clear_cached_items(retain=retain, path=data_path, output_folder=output_folder) 822 823 diff = time.time() - start 824 hours, rest = divmod(diff, 3600) 825 minutes, seconds = divmod(rest, 60) 826 print("Time taken for running benchmarks: ", f"{int(hours)}h {int(minutes)}m {int(seconds)}s") 827 828 829def main(): 830 """@private""" 831 import argparse 832 833 available_models = list(util.get_model_names()) 834 available_models = ", ".join(available_models) 835 836 parser = argparse.ArgumentParser( 837 description="Run evaluation for benchmarking Segment Anything models on microscopy datasets." 838 ) 839 parser.add_argument( 840 "-i", "--input_folder", type=str, required=True, 841 help="The path to a directory where the microscopy datasets are and/or will be stored." 842 ) 843 parser.add_argument( 844 "-m", "--model_type", type=str, default=util._DEFAULT_MODEL, 845 help=f"The segment anything model that will be used, one of '{available_models}'. By default, " 846 f"it uses'{util._DEFAULT_MODEL}'." 
847 ) 848 parser.add_argument( 849 "-c", "--checkpoint_path", type=str, default=None, 850 help="The filepath to checkpoint from which the SAM model will be loaded." 851 ) 852 parser.add_argument( 853 "-d", "--dataset_choice", type=str, nargs='*', default=None, 854 help="The choice(s) of dataset for evaluating SAM models. Multiple datasets can be specified. " 855 "By default, it evaluates on all datasets." 856 ) 857 parser.add_argument( 858 "-o", "--output_folder", type=str, required=True, 859 help="The path where the results for (automatic and interactive) instance segmentation results " 860 "will be stored as 'csv' files." 861 ) 862 parser.add_argument( 863 "--amg", action="store_true", 864 help="Whether to run automatic segmentation in AMG mode (i.e. the default auto-seg approach for SAM)." 865 ) 866 parser.add_argument( 867 "--retain", nargs="*", default=None, 868 help="By default, the functionality removes all besides quantitative results required for running benchmarks. " 869 "In case you would like to retain parts of the benchmark evaluation for visualization / reproducibility, " 870 "you should choose one or multiple of 'data', 'crops', 'automatic', 'interactive'. " 871 "where they are responsible for either retaining original inputs / extracted crops / " 872 "predictions of automatic segmentation / predictions of interactive segmentation, respectively." 873 ) 874 parser.add_argument( 875 "--evaluate", type=str, default="all", choices=["all", "automatic", "interactive"], 876 help="The choice of methods for benchmarking evaluation for reproducibility. " 877 "By default, we run all evaluations with 'all'. If 'automatic' is chosen, it runs automatic segmentation only " 878 "/ 'interactive' runs interactive segmentation (starting from box and single point) with iterative prompting." 
879 ) 880 args = parser.parse_args() 881 882 run_benchmark_evaluations( 883 input_folder=args.input_folder, 884 dataset_choice=args.dataset_choice, 885 model_type=args.model_type, 886 output_folder=args.output_folder, 887 checkpoint_path=args.checkpoint_path, 888 segmentation_mode=args.segmentation_mode, 889 retain=args.retain, 890 evaluation_methods=args.evaluate, 891 ignore_warnings=True, 892 ) 893 894 895if __name__ == "__main__": 896 main()
LM_2D_DATASETS =
['livecell', 'deepbacs', 'tissuenet', 'neurips_cellseg', 'cellpose', 'dynamicnuclearnet', 'orgasegment', 'yeaz', 'arvidsson', 'bitdepth_nucseg', 'cellbindb', 'covid_if', 'deepseas', 'hpa', 'ifnuclei', 'lizard', 'organoidnet', 'toiam', 'vicar']
LM_3D_DATASETS =
['plantseg_root', 'plantseg_ovules', 'gonuclear', 'mouse_embryo', 'cellseg3d']
EM_2D_DATASETS =
['mitolab_tem']
EM_3D_DATASETS =
['lucchi', 'mitolab', 'uro_cell', 'sponge_em', 'vnc', 'nuc_mm_mouse', 'num_mm_zebrafish', 'platynereis_cilia', 'asem_mito']
DATASET_RETURNS_FOLDER =
{'deepbacs': '*.tif', 'mitolab_tem': '*.tiff'}
DATASET_CONTAINER_KEYS =
{'tissuenet': ['raw/rgb', 'labels/cell'], 'covid_if': ['raw/serum_IgG/s0', 'labels/cells/s0'], 'dynamicnuclearnet': ['raw', 'labels'], 'hpa': [['raw/protein', 'raw/microtubules', 'raw/er'], 'labels'], 'lizard': ['image', 'labels/segmentation'], 'plantseg_root': ['raw', 'label'], 'plantseg_ovules': ['raw', 'label'], 'gonuclear': ['raw/nuclei', 'labels/nuclei'], 'mouse_embryo': ['raw', 'label'], 'cellseg_3d': [None, None], 'lucchi': ['raw', 'labels'], 'uro_cell': ['raw', 'labels/mito'], 'mitolab_3d': [None, None], 'sponge_em': ['volumes/raw', 'volumes/labels/instances'], 'vnc': ['raw', 'labels/mitochondria']}
def
run_benchmark_evaluations( input_folder: Union[os.PathLike, str], dataset_choice: str, model_type: str = 'vit_b_lm', output_folder: Union[os.PathLike, str, NoneType] = None, checkpoint_path: Union[os.PathLike, str, NoneType] = None, segmentation_mode: Optional[Literal['amg', 'ais', 'apg']] = None, retain: Optional[List[str]] = None, evaluation_methods: Literal['all', 'automatic', 'interactive'] = 'all', ignore_warnings: bool = False):
def run_benchmark_evaluations(
    input_folder: Union[os.PathLike, str],
    dataset_choice: Optional[Union[str, List[str]]],
    model_type: str = util._DEFAULT_MODEL,
    output_folder: Optional[Union[str, os.PathLike]] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    segmentation_mode: Optional[Literal["amg", "ais", "apg"]] = None,
    retain: Optional[List[str]] = None,
    evaluation_methods: Literal["all", "automatic", "interactive"] = "all",
    ignore_warnings: bool = False,
):
    """Run evaluation for benchmarking Segment Anything models on microscopy datasets.

    Args:
        input_folder: The path to directory where all inputs will be stored and preprocessed.
        dataset_choice: The dataset choice(s): a single dataset name or a list of names.
            If None, all benchmark datasets are evaluated.
        model_type: The model choice for SAM.
        output_folder: The path to directory where all outputs will be stored.
            The outputs of each dataset are stored in a sub-folder named after the dataset.
        checkpoint_path: The filepath to the checkpoint from which the SAM model will be loaded.
        segmentation_mode: An additional automatic segmentation mode to evaluate. One of 'amg', 'ais', or 'apg'.
        retain: Which parts of the benchmark runs to retain.
            By default, removes everything besides quantitative results.
            There is the choice to retain 'data', 'crops', 'automatic', or 'interactive'.
        evaluation_methods: The choice of evaluation methods.
            By default, runs 'all' evaluation methods (i.e. both 'automatic' and 'interactive').
            Otherwise, specify either 'automatic' or 'interactive' to run only that evaluation.
        ignore_warnings: Whether to ignore warnings.

    Raises:
        ValueError: If `output_folder` is not given.
    """
    # Fail early with a clear message, instead of a TypeError in 'os.path.join' after downloading the data.
    if output_folder is None:
        raise ValueError("'output_folder' must be provided to store the benchmark results.")

    start = time.time()

    with _filter_warnings(ignore_warnings):
        device = util._get_default_device()

        # Ensure if all the datasets have been installed by default.
        dataset_choice = _download_benchmark_datasets(path=input_folder, dataset_choice=dataset_choice)

        for choice in dataset_choice:
            # Use a separate variable per dataset: re-assigning 'output_folder' would nest the outputs
            # of every subsequent dataset inside the previous dataset's folder.
            dataset_output_folder = os.path.join(output_folder, choice)
            result_dir = os.path.join(dataset_output_folder, "results")
            os.makedirs(result_dir, exist_ok=True)

            data_path = os.path.join(input_folder, choice)

            # Extrapolate desired set from the datasets:
            # a. for 2d datasets - 2d patches with the most number of labels present
            #    (in case of volumetric data, choose 2d patches per slice).
            # b. for 3d datasets - 3d regions of interest with the most number of labels present.
            ndim = _extract_slices_from_dataset(path=data_path, dataset_choice=choice, crops_per_input=10)

            # Run inference and evaluation scripts on benchmark datasets.
            # For volumetric datasets, additionally evaluate on the extracted '2d' crops.
            eval_ndims = [ndim, 2] if ndim == 3 else [ndim]
            for curr_ndim in eval_ndims:
                image_paths, gt_paths = _get_image_label_paths(path=data_path, ndim=curr_ndim)
                _run_benchmark_evaluation_series(
                    image_paths=image_paths,
                    gt_paths=gt_paths,
                    model_type=model_type,
                    output_folder=dataset_output_folder,
                    ndim=curr_ndim,
                    device=device,
                    checkpoint_path=checkpoint_path,
                    segmentation_mode=segmentation_mode,
                    evaluation_methods=evaluation_methods,
                )

            _clear_cached_items(retain=retain, path=data_path, output_folder=dataset_output_folder)

    diff = time.time() - start
    hours, rest = divmod(diff, 3600)
    minutes, seconds = divmod(rest, 60)
    print("Time taken for running benchmarks: ", f"{int(hours)}h {int(minutes)}m {int(seconds)}s")
Run evaluation for benchmarking Segment Anything models on microscopy datasets.
Arguments:
- input_folder: The path to directory where all inputs will be stored and preprocessed.
- dataset_choice: The dataset choice(s): a single dataset name or a list of names. If None, all benchmark datasets are evaluated.
- model_type: The model choice for SAM.
- output_folder: The path to directory where all outputs will be stored.
- checkpoint_path: The filepath to the checkpoint from which the SAM model will be loaded.
- segmentation_mode: The segmentation mode. One of 'amg', 'ais', or 'apg'.
- retain: Which parts of the benchmark runs to retain. By default, everything besides the quantitative results is removed. Any of 'data', 'crops', 'automatic', or 'interactive' can be retained.
- evaluation_methods: The choice of evaluation methods. By default, runs 'all' evaluation methods (i.e. both 'automatic' and 'interactive'). Otherwise, specify either 'automatic' or 'interactive' to run only that evaluation.
- ignore_warnings: Whether to ignore warnings.