micro_sam.evaluation.benchmark_datasets
import os
import time
from glob import glob
from tqdm import tqdm
from natsort import natsorted
from typing import Union, Optional, List, Literal

import numpy as np
import pandas as pd
import imageio.v3 as imageio
from skimage.measure import label as connected_components

from nifty.tools import blocking

import torch

from torch_em.data import datasets

from micro_sam import util

from . import run_evaluation
from ..training.training import _filter_warnings
from .inference import run_inference_with_iterative_prompting
from .evaluation import run_evaluation_for_iterative_prompting
from .multi_dimensional_segmentation import segment_slices_from_ground_truth
from ..automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter


LM_2D_DATASETS = [
    # in-domain
    "livecell",  # cell segmentation in PhC (has a TEST-set)
    "deepbacs",  # bacteria segmentation in label-free microscopy (has a TEST-set)
    "tissuenet",  # cell segmentation in tissue microscopy images (has a TEST-set)
    "neurips_cellseg",  # cell segmentation in various modalities (has a TEST-set)
    "cellpose",  # cell segmentation in FM (has 'cyto2', on which we can TEST)
    "dynamicnuclearnet",  # nuclei segmentation in FM (has a TEST-set)
    "orgasegment",  # organoid segmentation in BF (has a TEST-set)
    "yeaz",  # yeast segmentation in BF (has a TEST-set)

    # out-of-domain
    "arvidsson",  # nuclei segmentation in HCS FM
    "bitdepth_nucseg",  # nuclei segmentation in FM
    "cellbindb",  # cell segmentation in various microscopy
    "covid_if",  # cell segmentation in IF
    "deepseas",  # cell segmentation in PhC
    "hpa",  # cell segmentation in confocal
    "ifnuclei",  # nuclei segmentation in IFM
    "lizard",  # nuclei segmentation in H&E histopathology
    "organoidnet",  # organoid segmentation in BF
    "toiam",  # microbial cell segmentation in PhC
    "vicar",  # cell segmentation in label-free microscopy
]

LM_3D_DATASETS = [
    # in-domain
    "plantseg_root",  # cell segmentation in lightsheet (has a TEST-set)

    # out-of-domain
    "plantseg_ovules",  # cell segmentation in confocal
    "gonuclear",  # nuclei segmentation in FM
    "mouse_embryo",  # cell segmentation in lightsheet
    "cellseg_3d",  # nuclei segmentation in FM
]

EM_2D_DATASETS = ["mitolab_tem"]

EM_3D_DATASETS = [
    "mitoem_rat", "mitoem_human", "platynereis_nuclei", "lucchi", "mitolab_3d", "nuc_mm_mouse",
    "nuc_mm_zebrafish", "uro_cell", "sponge_em", "platynereis_cilia", "vnc", "asem_mito",
]

DATASET_RETURNS_FOLDER = {
    "deepbacs": "*.tif"
}

DATASET_CONTAINER_KEYS = {
    # 2d
    "tissuenet": ["raw/rgb", "labels/cell"],
    "covid_if": ["raw/serum_IgG/s0", "labels/cells/s0"],
    "dynamicnuclearnet": ["raw", "labels"],
    "hpa": [["raw/protein", "raw/microtubules", "raw/er"], "labels"],
    "lizard": ["image", "labels/segmentation"],

    # 3d
    "plantseg_root": ["raw", "label"],
    "plantseg_ovules": ["raw", "label"],
    "gonuclear": ["raw/nuclei", "labels/nuclei"],
    "mouse_embryo": ["raw", "label"],
    "cellseg_3d": [None, None],
    "lucchi": ["raw", "labels"],
}


def _download_benchmark_datasets(path, dataset_choice):
    """Ensures that all chosen datasets have been downloaded.

    Args:
        path: The path to the directory where the supported datasets will be downloaded
            for benchmarking Segment Anything models.
        dataset_choice: The choice of dataset; expects the lower-case name of the dataset.

    Returns:
        List of the chosen dataset(s).
104 """ 105 available_datasets = { 106 # Light Microscopy datasets 107 108 # 2d datasets: in-domain 109 "livecell": lambda: datasets.livecell.get_livecell_data( 110 path=os.path.join(path, "livecell"), download=True, 111 ), 112 "deepbacs": lambda: datasets.deepbacs.get_deepbacs_data( 113 path=os.path.join(path, "deepbacs"), bac_type="mixed", download=True, 114 ), 115 "tissuenet": lambda: datasets.tissuenet.get_tissuenet_data( 116 path=os.path.join(path, "tissuenet"), split="test", download=True, 117 ), 118 "neurips_cellseg": lambda: datasets.neurips_cell_seg.get_neurips_cellseg_data( 119 root=os.path.join(path, "neurips_cellseg"), split="test", download=True, 120 ), 121 "cellpose": lambda: datasets.cellpose.get_cellpose_data( 122 path=os.path.join(path, "cellpose"), split="train", choice="cyto2", download=True, 123 ), 124 "dynamicnuclearnet": lambda: datasets.dynamicnuclearnet.get_dynamicnuclearnet_data( 125 path=os.path.join(path, "dynamicnuclearnet"), split="test", download=True, 126 ), 127 "orgasegment": lambda: datasets.orgasegment.get_orgasegment_data( 128 path=os.path.join(path, "orgasegment"), split="eval", download=True, 129 ), 130 "yeaz": lambda: datasets.yeaz.get_yeaz_data( 131 path=os.path.join(path, "yeaz"), choice="bf", download=True, 132 ), 133 134 # 2d datasets: out-of-domain 135 "arvidsson": lambda: datasets.arvidsson.get_arvidsson_data( 136 path=os.path.join(path, "arvidsson"), split="test", download=True, 137 ), 138 "bitdepth_nucseg": lambda: datasets.bitdepth_nucseg.get_bitdepth_nucseg_data( 139 path=os.path.join(path, "bitdepth_nucseg"), download=True, 140 ), 141 "cellbindb": lambda: datasets.cellbindb.get_cellbindb_data( 142 path=os.path.join(path, "cellbindb"), download=True, 143 ), 144 "covid_if": lambda: datasets.covid_if.get_covid_if_data( 145 path=os.path.join(path, "covid_if"), download=True, 146 ), 147 "deepseas": lambda: datasets.deepseas.get_deepseas_data( 148 path=os.path.join(path, "deepseas"), split="test", 149 ), 150 "hpa": lambda: datasets.hpa.get_hpa_segmentation_data( 151 path=os.path.join(path, "hpa"), download=True, 152 ), 153 "ifnuclei": lambda: datasets.ifnuclei.get_ifnuclei_data( 154 path=os.path.join(path, "ifnuclei"), download=True, 155 ), 156 "lizard": lambda: datasets.lizard.get_lizard_data( 157 path=os.path.join(path, "lizard"), split="test", download=True, 158 ), 159 "organoidnet": lambda: datasets.organoidnet.get_organoidnet_data( 160 path=os.path.join(path, "organoidnet"), split="Test", download=True, 161 ), 162 "toiam": lambda: datasets.toiam.get_toiam_data( 163 path=os.path.join(path, "toiam"), download=True, 164 ), 165 "vicar": lambda: datasets.vicar.get_vicar_data( 166 path=os.path.join(path, "vicar"), download=True, 167 ), 168 169 # 3d datasets: in-domain 170 "plantseg_root": lambda: datasets.plantseg.get_plantseg_data( 171 path=os.path.join(path, "plantseg_root"), split="test", download=True, name="root", 172 ), 173 174 # 3d datasets: out-of-domain 175 "plantseg_ovules": lambda: datasets.plantseg.get_plantseg_data( 176 path=os.path.join(path, "plantseg_ovules"), split="test", download=True, name="ovules", 177 ), 178 "gonuclear": lambda: datasets.gonuclear.get_gonuclear_data( 179 path=os.path.join(path, "gonuclear"), download=True, 180 ), 181 "mouse_embryo": lambda: datasets.mouse_embryo.get_mouse_embryo_data( 182 path=os.path.join(path, "mouse_embryo"), download=True, 183 ), 184 "cellseg_3d": lambda: datasets.cellseg_3d.get_cellseg_3d_data( 185 path=os.path.join(path, "cellseg_3d"), download=True, 186 ), 187 188 # Electron Microscopy 
datasets 189 "mitoem_rat": lambda: datasets.mitoem.get_mitoem_data( 190 path=os.path.join(path, "mitoem"), samples="rat", split="test", download=True, 191 ), 192 "mitoem_human": lambda: datasets.mitoem.get_mitoem_data( 193 path=os.path.join(path, "mitoem"), samples="human", split="test", download=True, 194 ), 195 "platynereis_nuclei": lambda: datasets.platynereis.get_platy_data( 196 path=os.path.join(path, "platynereis"), name="nuclei", download=True, 197 ), 198 "platynereis_cilia": lambda: datasets.platynereis.get_platy_data( 199 path=os.path.join(path, "platynereis"), name="cilia", download=True, 200 ), 201 "lucchi": lambda: datasets.lucchi.get_lucchi_data( 202 path=os.path.join(path, "lucchi"), split="test", download=True, 203 ), 204 "mitolab_3d": lambda: [ 205 datasets.cem.get_benchmark_data( 206 path=os.path.join(path, "mitolab"), dataset_id=dataset_id, download=True, 207 ) for dataset_id in range(1, 7) 208 ], 209 "mitolab_tem": lambda: datasets.cem.get_benchmark_data( 210 path=os.path.join(path, "mitolab"), dataset_id=7, download=True 211 ), 212 "nuc_mm_mouse": lambda: datasets.nuc_mm.get_nuc_mm_data( 213 path=os.path.join(path, "nuc_mm"), sample="mouse", download=True, 214 ), 215 "nuc_mm_zebrafish": lambda: datasets.nuc_mm.get_nuc_mm_data( 216 path=os.path.join(path, "nuc_mm"), sample="zebrafish", download=True, 217 ), 218 "uro_cell": lambda: datasets.uro_cell.get_uro_cell_data( 219 path=os.path.join(path, "uro_cell"), download=True, 220 ), 221 "sponge_em": lambda: datasets.sponge_em.get_sponge_em_data( 222 path=os.path.join(path, "sponge_em"), download=True, 223 ), 224 "vnc": lambda: datasets.vnc.get_vnc_data( 225 path=os.path.join(path, "vnc"), download=True, 226 ), 227 "asem_mito": lambda: datasets.asem.get_asem_data( 228 path=os.path.join(path, "asem"), volume_ids=datasets.asem.ORGANELLES["mito"], download=True, 229 ) 230 } 231 232 if dataset_choice is None: 233 dataset_choice = available_datasets.keys() 234 else: 235 if not isinstance(dataset_choice, list): 236 dataset_choice = [dataset_choice] 237 238 for choice in dataset_choice: 239 if choice in available_datasets: 240 available_datasets[choice]() 241 else: 242 raise ValueError(f"'{choice}' is not a supported choice of dataset.") 243 244 return dataset_choice 245 246 247def _extract_slices_from_dataset(path, dataset_choice, crops_per_input=10): 248 """Extracts crops of desired shapes for performing evaluation in both 2d and 3d using `micro-sam`. 249 250 Args: 251 path: The path to directory where the supported datasets have be downloaded 252 for benchmarking Segment Anything models. 253 dataset_choice: The name of the dataset of choice to extract crops. 254 crops_per_input: The maximum number of crops to extract per inputs. 255 extract_2d: Whether to extract 2d crops from 3d patches. 256 257 Returns: 258 Filepath to the folder where extracted images are stored. 259 Filepath to the folder where corresponding extracted labels are stored. 260 The number of dimensions supported by the input. 261 """ 262 ndim = 2 if dataset_choice in [*LM_2D_DATASETS, *EM_2D_DATASETS] else 3 263 tile_shape = (512, 512) if ndim == 2 else (32, 512, 512) 264 265 # For 3d inputs, we extract both 2d and 3d crops. 
    extract_2d_crops_from_volumes = (ndim == 3)

    available_datasets = {
        # Light Microscopy datasets

        # 2d: in-domain
        "livecell": lambda: datasets.livecell.get_livecell_paths(path=path, split="test"),
        "deepbacs": lambda: datasets.deepbacs.get_deepbacs_paths(path=path, split="test", bac_type="mixed"),
        "tissuenet": lambda: datasets.tissuenet.get_tissuenet_paths(path=path, split="test"),
        "neurips_cellseg": lambda: datasets.neurips_cell_seg.get_neurips_cellseg_paths(root=path, split="test"),
        "cellpose": lambda: datasets.cellpose.get_cellpose_paths(path=path, split="train", choice="cyto2"),
        "dynamicnuclearnet": lambda: datasets.dynamicnuclearnet.get_dynamicnuclearnet_paths(path=path, split="test"),
        "orgasegment": lambda: datasets.orgasegment.get_orgasegment_paths(path=path, split="eval"),
        "yeaz": lambda: datasets.yeaz.get_yeaz_paths(path=path, choice="bf", split="test"),

        # 2d: out-of-domain
        "arvidsson": lambda: datasets.arvidsson.get_arvidsson_paths(path=path, split="test"),
        "bitdepth_nucseg": lambda: datasets.bitdepth_nucseg.get_bitdepth_nucseg_paths(path=path, magnification="20x"),
        "cellbindb": lambda: datasets.cellbindb.get_cellbindb_paths(
            path=path, data_choice=["10×Genomics_DAPI", "DAPI", "mIF"]
        ),
        "covid_if": lambda: datasets.covid_if.get_covid_if_paths(path=path),
        "deepseas": lambda: datasets.deepseas.get_deepseas_paths(path=path, split="test"),
        "hpa": lambda: datasets.hpa.get_hpa_segmentation_paths(path=path, split="val"),
        "ifnuclei": lambda: datasets.ifnuclei.get_ifnuclei_paths(path=path),
        "lizard": lambda: datasets.lizard.get_lizard_paths(path=path, split="test"),
        "organoidnet": lambda: datasets.organoidnet.get_organoidnet_paths(path=path, split="Test"),
        "toiam": lambda: datasets.toiam.get_toiam_paths(path=path),
        "vicar": lambda: datasets.vicar.get_vicar_paths(path=path),

        # 3d: in-domain
        "plantseg_root": lambda: datasets.plantseg.get_plantseg_paths(path=path, name="root", split="test"),

        # 3d: out-of-domain
        "plantseg_ovules": lambda: datasets.plantseg.get_plantseg_paths(path=path, name="ovules", split="test"),
        "gonuclear": lambda: datasets.gonuclear.get_gonuclear_paths(path=path),
        "mouse_embryo": lambda: datasets.mouse_embryo.get_mouse_embryo_paths(path=path, name="nuclei", split="val"),
        "cellseg_3d": lambda: datasets.cellseg_3d.get_cellseg_3d_paths(path=path),

        # Electron Microscopy datasets
        "mitoem_rat": lambda: datasets.mitoem.get_mitoem_paths(path=path, splits="test", samples="rat"),
        "mitoem_human": lambda: datasets.mitoem.get_mitoem_paths(path=path, splits="test", samples="human"),
        "platynereis_nuclei": lambda: datasets.platynereis.get_platynereis_paths(path, sample_ids=None, name="nuclei"),
        "platynereis_cilia": lambda: datasets.platynereis.get_platynereis_paths(path, sample_ids=None, name="cilia"),
        "lucchi": lambda: datasets.lucchi.get_lucchi_paths(path=path, split="test"),
        "mitolab_3d": lambda: (
            [rpath for i in range(1, 7) for rpath in datasets.cem.get_benchmark_paths(path=path, dataset_id=i)[0]],
            [lpath for i in range(1, 7) for lpath in datasets.cem.get_benchmark_paths(path=path, dataset_id=i)[1]]
        ),
        "mitolab_tem": lambda: datasets.cem.get_benchmark_paths(path=path, dataset_id=7),
        "nuc_mm_mouse": lambda: datasets.nuc_mm.get_nuc_mm_paths(path=path, sample="mouse", split="val"),
        "nuc_mm_zebrafish": lambda: datasets.nuc_mm.get_nuc_mm_paths(path=path, sample="zebrafish", split="val"),
318 "uro_cell": lambda: datasets.uro_cell.get_uro_cell_paths(path=path, target="mito"), 319 "sponge_em": lambda: datasets.sponge_em.get_sponge_em_paths(path=path, sample_ids=None), 320 "vnc": lambda: datasets.vnc.get_vnc_mito_paths(path=path), 321 "asem_mito": lambda: datasets.asem.get_asem_paths(path=path, volume_ids=datasets.asem.ORGANELLES["mito"]) 322 } 323 324 if (ndim == 2 and dataset_choice not in DATASET_CONTAINER_KEYS) or dataset_choice == "cellseg_3d": 325 image_paths, gt_paths = available_datasets[dataset_choice]() 326 327 if dataset_choice in DATASET_RETURNS_FOLDER: 328 image_paths = glob(os.path.join(image_paths, DATASET_RETURNS_FOLDER[dataset_choice])) 329 gt_paths = glob(os.path.join(gt_paths, DATASET_RETURNS_FOLDER[dataset_choice])) 330 331 image_paths, gt_paths = natsorted(image_paths), natsorted(gt_paths) 332 assert len(image_paths) == len(gt_paths) 333 334 paths_set = zip(image_paths, gt_paths) 335 336 else: 337 image_paths = available_datasets[dataset_choice]() 338 if isinstance(image_paths, str): 339 paths_set = [image_paths] 340 else: 341 paths_set = natsorted(image_paths) 342 343 # Directory where we store the extracted ROIs. 344 save_image_dir = [os.path.join(path, f"roi_{ndim}d", "inputs")] 345 save_gt_dir = [os.path.join(path, f"roi_{ndim}d", "labels")] 346 if extract_2d_crops_from_volumes: 347 save_image_dir.append(os.path.join(path, "roi_2d", "inputs")) 348 save_gt_dir.append(os.path.join(path, "roi_2d", "labels")) 349 350 _dir_exists = [os.path.exists(idir) and os.path.exists(gdir) for idir, gdir in zip(save_image_dir, save_gt_dir)] 351 if all(_dir_exists): 352 return ndim 353 354 [os.makedirs(idir, exist_ok=True) for idir in save_image_dir] 355 [os.makedirs(gdir, exist_ok=True) for gdir in save_gt_dir] 356 357 # Logic to extract relevant patches for inference 358 image_counter = 1 359 for per_paths in tqdm(paths_set, desc=f"Extracting {ndim}d patches for {dataset_choice}"): 360 if (ndim == 2 and dataset_choice not in DATASET_CONTAINER_KEYS) or dataset_choice == "cellseg_3d": 361 image_path, gt_path = per_paths 362 image, gt = util.load_image_data(image_path), util.load_image_data(gt_path) 363 else: 364 image_path = per_paths 365 gt = util.load_image_data(image_path, DATASET_CONTAINER_KEYS[dataset_choice][1]) 366 if dataset_choice == "hpa": 367 # Get inputs per channel and stack them together to make the desired 3 channel image. 368 image = np.stack( 369 [util.load_image_data(image_path, k) for k in DATASET_CONTAINER_KEYS[dataset_choice][0]], axis=0, 370 ) 371 # Resize inputs to desired tile shape, in favor of working with the shape of foreground. 372 from torch_em.transform.generic import ResizeLongestSideInputs 373 raw_transform = ResizeLongestSideInputs(target_shape=tile_shape, is_rgb=True) 374 label_transform = ResizeLongestSideInputs(target_shape=tile_shape, is_label=True) 375 image, gt = raw_transform(image).transpose(1, 2, 0), label_transform(gt) 376 377 else: 378 image = util.load_image_data(image_path, DATASET_CONTAINER_KEYS[dataset_choice][0]) 379 380 if dataset_choice in ["tissuenet", "lizard"]: 381 if image.ndim == 3 and image.shape[0] == 3: # Make channels last for tissuenet RGB-style images. 
                image = image.transpose(1, 2, 0)

        # Allow RGB images to stay as they are, with channels last.
        if image.ndim == 3 and image.shape[-1] == 3:
            skip_smaller_shape = (np.array(image.shape) >= np.array((*tile_shape, 3))).all()
        else:
            skip_smaller_shape = (np.array(image.shape) >= np.array(tile_shape)).all()

        # Ensure that the ground truth has instance labels.
        gt = connected_components(gt)

        if len(np.unique(gt)) == 1:  # There could be label images without any annotated foreground.
            continue

        # Let's extract and save all the crops.
        # The first round of extraction always matches the desired input dimensions.
        image_crops, gt_crops = _get_crops_for_input(image, gt, ndim, tile_shape, skip_smaller_shape, crops_per_input)
        image_counter = _save_image_label_crops(
            image_crops, gt_crops, dataset_choice, ndim, image_counter, save_image_dir[0], save_gt_dir[0]
        )

        # The next round of extraction gets 2d crops from 3d inputs.
        if extract_2d_crops_from_volumes:
            curr_tile_shape = tile_shape[1:]  # We expect a 2d tile shape for this stage.

            curr_image_crops, curr_gt_crops = [], []
            for per_z_im, per_z_gt in zip(image, gt):
                curr_skip_smaller_shape = (np.array(per_z_im.shape) >= np.array(curr_tile_shape)).all()

                image_crops, gt_crops = _get_crops_for_input(
                    image=per_z_im, gt=per_z_gt, ndim=2,
                    tile_shape=curr_tile_shape,
                    skip_smaller_shape=curr_skip_smaller_shape,
                    crops_per_input=crops_per_input,
                )
                curr_image_crops.extend(image_crops)
                curr_gt_crops.extend(gt_crops)

            image_counter = _save_image_label_crops(
                curr_image_crops, curr_gt_crops, dataset_choice, 2, image_counter, save_image_dir[1], save_gt_dir[1]
            )

    return ndim


def _get_crops_for_input(image, gt, ndim, tile_shape, skip_smaller_shape, crops_per_input):
    tiling = blocking([0] * ndim, gt.shape, tile_shape)
    n_tiles = tiling.numberOfBlocks
    tiles = [tiling.getBlock(tile_id) for tile_id in range(n_tiles)]
    crop_boxes = [
        tuple(slice(beg, end) for beg, end in zip(tile.begin, tile.end)) for tile in tiles
    ]
    n_ids = list(range(len(crop_boxes)))
    n_instances = [len(np.unique(gt[crop])) for crop in crop_boxes]

    # Extract the desired number of patches, preferring crops with a higher number of instances.
    image_crops, gt_crops = [], []
    for i, (per_n_instance, per_id) in enumerate(sorted(zip(n_instances, n_ids), reverse=True), start=1):
        crop_box = crop_boxes[per_id]
        crop_image, crop_gt = image[crop_box], gt[crop_box]

        # NOTE: We avoid using crops which do not match the desired tile shape.
        _rtile_shape = (*tile_shape, 3) if image.ndim == 3 and image.shape[-1] == 3 else tile_shape  # For RGB images.
        if skip_smaller_shape and crop_image.shape != _rtile_shape:
            continue

        # NOTE: Since the crops are sorted by instance count, all remaining crops are background-only
        # once we reach a crop with a single unique label.
        if per_n_instance == 1:
            break

        image_crops.append(crop_image)
        gt_crops.append(crop_gt)

        # NOTE: Once the desired number of patches has been extracted, we stop sampling.
        if len(image_crops) > 0 and i >= crops_per_input:
            break

    return image_crops, gt_crops


def _save_image_label_crops(image_crops, gt_crops, dataset_choice, ndim, image_counter, save_image_dir, save_gt_dir):
    for image_crop, gt_crop in tqdm(
        zip(image_crops, gt_crops), total=len(image_crops), desc=f"Saving {ndim}d crops for {dataset_choice}"
    ):
        fname = f"{dataset_choice}_{image_counter:05}.tif"

        if image_crop.ndim == 3 and image_crop.shape[-1] == 3:
            assert image_crop.shape[:2] == gt_crop.shape
        else:
            assert image_crop.shape == gt_crop.shape

        imageio.imwrite(os.path.join(save_image_dir, fname), image_crop, compression="zlib")
        imageio.imwrite(os.path.join(save_gt_dir, fname), gt_crop, compression="zlib")

        image_counter += 1

    return image_counter


def _get_image_label_paths(path, ndim):
    image_paths = natsorted(glob(os.path.join(path, f"roi_{ndim}d", "inputs", "*")))
    gt_paths = natsorted(glob(os.path.join(path, f"roi_{ndim}d", "labels", "*")))
    return image_paths, gt_paths


def _run_automatic_segmentation_per_dataset(
    image_paths: List[Union[os.PathLike, str]],
    gt_paths: List[Union[os.PathLike, str]],
    model_type: str,
    output_folder: Union[os.PathLike, str],
    ndim: Optional[int] = None,
    device: Optional[Union[torch.device, str]] = None,
    checkpoint_path: Optional[Union[os.PathLike, str]] = None,
    run_amg: Optional[bool] = None,
    **auto_seg_kwargs
):
    """Functionality to run automatic segmentation for multiple input files at once.
    It stores the evaluated (quantitative) automatic segmentation results.

    Args:
        image_paths: List of filepaths for the input image data.
        gt_paths: List of filepaths for the corresponding label data.
        model_type: The choice of image encoder for the Segment Anything model.
        output_folder: Filepath to the folder where we store all the results.
        ndim: The number of input dimensions.
        device: The torch device.
        checkpoint_path: The filepath where the model checkpoints are stored.
        run_amg: Whether to run automatic segmentation in AMG mode.
        auto_seg_kwargs: Additional arguments for the automatic segmentation parameters.
    """
    # If 'run_amg' is not set (or set to False), we check whether a decoder state is available,
    # and default to the best automatic segmentation approach accordingly (i.e. AIS over AMG).
    if run_amg is None or (not run_amg):
        _, state = util.get_sam_model(
            model_type=model_type, checkpoint_path=checkpoint_path, device=device, return_state=True
        )
        run_amg = ("decoder_state" not in state)

    experiment_name = "AMG" if run_amg else "AIS"
    fname = f"{experiment_name.lower()}_{ndim}d"

    result_path = os.path.join(output_folder, "results", f"{fname}.csv")
    if os.path.exists(result_path):
        return

    prediction_dir = os.path.join(output_folder, fname, "inference")
    os.makedirs(prediction_dir, exist_ok=True)

    # Get the predictor (and the additional instance segmentation decoder, if available).
    predictor, segmenter = get_predictor_and_segmenter(
        model_type=model_type, checkpoint=checkpoint_path, device=device, amg=run_amg, is_tiled=False,
    )

    for image_path in tqdm(image_paths, desc=f"Run {experiment_name} in {ndim}d"):
        output_path = os.path.join(prediction_dir, os.path.basename(image_path))
        if os.path.exists(output_path):
            continue

        # Run automatic segmentation (AMG or AIS).
        automatic_instance_segmentation(
            predictor=predictor,
            segmenter=segmenter,
            input_path=image_path,
            output_path=output_path,
            ndim=ndim,
            verbose=False,
            **auto_seg_kwargs
        )

    prediction_paths = natsorted(glob(os.path.join(prediction_dir, "*")))
    run_evaluation(gt_paths=gt_paths, prediction_paths=prediction_paths, save_path=result_path)


def _run_interactive_segmentation_per_dataset(
    image_paths: List[Union[os.PathLike, str]],
    gt_paths: List[Union[os.PathLike, str]],
    output_folder: Union[os.PathLike, str],
    model_type: str,
    prompt_choice: Literal["box", "points"],
    device: Optional[Union[torch.device, str]] = None,
    ndim: Optional[int] = None,
    checkpoint_path: Optional[Union[os.PathLike, str]] = None,
    use_masks: bool = False,
):
    """Functionality to run interactive segmentation for multiple input files at once.
    It stores the evaluated interactive segmentation results.

    Args:
        image_paths: List of filepaths for the input image data.
        gt_paths: List of filepaths for the corresponding label data.
        output_folder: Filepath to the folder where we store all the results.
        model_type: The choice of model type for Segment Anything.
        prompt_choice: The choice of initial prompts with which to begin the interactive segmentation.
        device: The torch device.
        ndim: The number of input dimensions.
        checkpoint_path: The filepath for the stored checkpoints.
        use_masks: Whether to use masks for iterative prompting.
    """
    if ndim == 2:
        # Get the Segment Anything predictor.
        predictor = util.get_sam_model(model_type=model_type, device=device, checkpoint_path=checkpoint_path)

        prediction_root = os.path.join(
            output_folder, "interactive_segmentation_2d", f"start_with_{prompt_choice}",
            "iterative_prompting_" + ("with_masks" if use_masks else "without_masks")
        )

        # Run interactive instance segmentation
        # (starting with box or point prompts, followed by iterative prompt-based correction).
        run_inference_with_iterative_prompting(
            predictor=predictor,
            image_paths=image_paths,
            gt_paths=gt_paths,
            embedding_dir=None,  # We set this to None to compute embeddings on-the-fly.
            prediction_dir=prediction_root,
            start_with_box_prompt=(prompt_choice == "box"),
            use_masks=use_masks,
            # TODO: add a parameter to deform box prompts (to better simulate prompts in practice).
        )

        # Evaluate the interactive instance segmentation.
        run_evaluation_for_iterative_prompting(
            gt_paths=gt_paths,
            prediction_root=prediction_root,
            experiment_folder=output_folder,
            start_with_box_prompt=(prompt_choice == "box"),
            use_masks=use_masks,
        )

    else:
        save_path = os.path.join(output_folder, "results", f"interactive_segmentation_3d_with_{prompt_choice}.csv")
        if os.path.exists(save_path):
            print(
                f"Results for 3d interactive segmentation with '{prompt_choice}' are already stored at '{save_path}'."
            )
            return

        results = []
        for image_path, gt_path in tqdm(
            zip(image_paths, gt_paths), total=len(image_paths),
            desc=f"Run interactive segmentation in 3d with '{prompt_choice}'"
        ):
            prediction_dir = os.path.join(output_folder, "interactive_segmentation_3d", prompt_choice)
            os.makedirs(prediction_dir, exist_ok=True)

            prediction_path = os.path.join(prediction_dir, os.path.basename(image_path))
            if os.path.exists(prediction_path):
                continue

            per_vol_result = segment_slices_from_ground_truth(
                volume=imageio.imread(image_path),
                ground_truth=imageio.imread(gt_path),
                model_type=model_type,
                checkpoint_path=checkpoint_path,
                save_path=prediction_path,
                device=device,
                interactive_seg_mode=prompt_choice,
                min_size=10,
            )
            results.append(per_vol_result)

        results = pd.concat(results)
        results = results.groupby(results.index).mean()
        results.to_csv(save_path)


def _run_benchmark_evaluation_series(
    image_paths, gt_paths, model_type, output_folder, ndim, device, checkpoint_path, run_amg, evaluation_methods,
):
    seg_kwargs = {
        "image_paths": image_paths,
        "gt_paths": gt_paths,
        "output_folder": output_folder,
        "ndim": ndim,
        "model_type": model_type,
        "device": device,
        "checkpoint_path": checkpoint_path,
    }

    # Perform:
    # a. automatic segmentation (supported in both 2d and 3d, wherever relevant).
    #    The automatic segmentation steps below are configured such that AIS has priority (if a decoder is found);
    #    otherwise, AMG is run. Afterwards, we check whether the user expects to run AMG as well.
    if evaluation_methods != "interactive":  # Skip automatic segmentation for the 'interactive'-only run choice.
        # i. Run the automatic segmentation method supported by the SAM model (AMG or AIS).
        _run_automatic_segmentation_per_dataset(run_amg=None, **seg_kwargs)

        # ii. Run automatic mask generation (AMG).
        # NOTE: This only runs if the user requests it. By default, it is set to 'False'.
        _run_automatic_segmentation_per_dataset(run_amg=run_amg, **seg_kwargs)

    if evaluation_methods != "automatic":  # Skip interactive segmentation for the 'automatic'-only run choice.
        # b. Run interactive segmentation (supported in both 2d and 3d, wherever relevant).
        _run_interactive_segmentation_per_dataset(prompt_choice="box", **seg_kwargs)
        _run_interactive_segmentation_per_dataset(prompt_choice="box", use_masks=True, **seg_kwargs)
        _run_interactive_segmentation_per_dataset(prompt_choice="points", **seg_kwargs)
        _run_interactive_segmentation_per_dataset(prompt_choice="points", use_masks=True, **seg_kwargs)


def _clear_cached_items(retain, path, output_folder):
    import shutil
    from pathlib import Path

    REMOVE_LIST = ["data", "crops", "automatic", "interactive"]
    if retain is None:
        remove_list = REMOVE_LIST
    else:
        assert isinstance(retain, list)
        remove_list = set(REMOVE_LIST) - set(retain)

    paths = []
    # Stage 1: Remove the inputs.
    if "data" in remove_list or "crops" in remove_list:
        all_paths = glob(os.path.join(path, "*"))

        # In case we want to remove both the data and the crops, we remove the data folder entirely.
        if "data" in remove_list and "crops" in remove_list:
            paths.extend(all_paths)

        # Otherwise, we remove only one of the data or the crops.
        else:
            for curr_path in all_paths:
                if os.path.basename(curr_path).startswith("roi") and "crops" in remove_list:
                    paths.append(curr_path)
                elif "data" in remove_list:
                    paths.append(curr_path)

    # Stage 2: Remove the predictions.
    if "automatic" in remove_list:
        paths.extend(glob(os.path.join(output_folder, "amg_*")))
        paths.extend(glob(os.path.join(output_folder, "ais_*")))

    if "interactive" in remove_list:
        paths.extend(glob(os.path.join(output_folder, "interactive_segmentation_*")))

    # Remove all cached directories and files which the user does not want to retain.
    for _path in paths:
        if Path(_path).is_dir():
            shutil.rmtree(_path)
        else:
            os.remove(_path)


def run_benchmark_evaluations(
    input_folder: Union[os.PathLike, str],
    dataset_choice: str,
    model_type: str = util._DEFAULT_MODEL,
    output_folder: Optional[Union[str, os.PathLike]] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    run_amg: bool = False,
    retain: Optional[List[str]] = None,
    evaluation_methods: Literal["all", "automatic", "interactive"] = "all",
    ignore_warnings: bool = False,
):
    """Run evaluation for benchmarking Segment Anything models on microscopy datasets.

    Args:
        input_folder: The path to the directory where all inputs will be stored and preprocessed.
        dataset_choice: The choice of dataset(s) to run the benchmarks on.
        model_type: The model choice for SAM.
        output_folder: The path to the directory where all outputs will be stored.
        checkpoint_path: The filepath to the model checkpoint.
        run_amg: Whether to run automatic segmentation in AMG mode.
        retain: Which parts of the benchmark runs to retain.
            By default, removes everything besides quantitative results.
            There is the choice to retain 'data', 'crops', 'automatic', or 'interactive'.
        evaluation_methods: The choice of evaluation methods.
            By default, runs 'all' evaluation methods (i.e. both 'automatic' and 'interactive').
            Otherwise, specify either 'automatic' or 'interactive' for a specific evaluation run.
        ignore_warnings: Whether to ignore warnings.
    """
    start = time.time()

    with _filter_warnings(ignore_warnings):
        device = util._get_default_device()

        # Ensure that all chosen datasets have been downloaded.
        dataset_choice = _download_benchmark_datasets(path=input_folder, dataset_choice=dataset_choice)

        for choice in dataset_choice:
            # Use a per-dataset output folder, so that multiple dataset choices do not nest into each other.
            choice_output_folder = os.path.join(output_folder, choice)
            result_dir = os.path.join(choice_output_folder, "results")
            os.makedirs(result_dir, exist_ok=True)

            data_path = os.path.join(input_folder, choice)

            # Extract the desired crops from the datasets:
            # a. for 2d datasets - 2d patches with the highest number of labels present
            #    (in case of volumetric data, choose 2d patches per slice).
            # b. for 3d datasets - 3d regions of interest with the highest number of labels present.
            ndim = _extract_slices_from_dataset(path=data_path, dataset_choice=choice, crops_per_input=10)

            # Run the inference and evaluation scripts on the benchmark datasets.
            image_paths, gt_paths = _get_image_label_paths(path=data_path, ndim=ndim)
            _run_benchmark_evaluation_series(
                image_paths=image_paths,
                gt_paths=gt_paths,
                model_type=model_type,
                output_folder=choice_output_folder,
                ndim=ndim,
                device=device,
                checkpoint_path=checkpoint_path,
                run_amg=run_amg,
                evaluation_methods=evaluation_methods,
            )

            # Run the inference and evaluation scripts on '2d' crops for volumetric datasets.
            if ndim == 3:
                image_paths, gt_paths = _get_image_label_paths(path=data_path, ndim=2)
                _run_benchmark_evaluation_series(
                    image_paths=image_paths,
                    gt_paths=gt_paths,
                    model_type=model_type,
                    output_folder=choice_output_folder,
                    ndim=2,
                    device=device,
                    checkpoint_path=checkpoint_path,
                    run_amg=run_amg,
                    evaluation_methods=evaluation_methods,
                )

            _clear_cached_items(retain=retain, path=data_path, output_folder=choice_output_folder)

    diff = time.time() - start
    hours, rest = divmod(diff, 3600)
    minutes, seconds = divmod(rest, 60)
    print("Time taken for running benchmarks: ", f"{int(hours)}h {int(minutes)}m {int(seconds)}s")


def main():
    """@private"""
    import argparse

    available_models = list(util.get_model_names())
    available_models = ", ".join(available_models)

    parser = argparse.ArgumentParser(
        description="Run evaluation for benchmarking Segment Anything models on microscopy datasets."
    )
    parser.add_argument(
        "-i", "--input_folder", type=str, required=True,
        help="The path to a directory where the microscopy datasets are / will be stored."
    )
    parser.add_argument(
        "-m", "--model_type", type=str, default=util._DEFAULT_MODEL,
        help=f"The Segment Anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint_path", type=str, default=None,
        help="The checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "-d", "--dataset_choice", type=str, nargs='*', default=None,
        help="The choice(s) of dataset for evaluating SAM models. Multiple datasets can be specified."
    )
    parser.add_argument(
        "-o", "--output_folder", type=str, required=True,
        help="The path where the results for automatic and interactive instance segmentation will be stored as 'csv'."
    )
    parser.add_argument(
        "--amg", action="store_true",
        help="Whether to run automatic segmentation in AMG mode (i.e. the default automatic segmentation for SAM)."
    )
    parser.add_argument(
        "--retain", nargs="*", default=None,
        help="By default, everything besides the quantitative results required for the benchmarks is removed. "
             "In case you would like to retain parts of the benchmark evaluation for visualization / reproducibility, "
             "choose one or multiple of 'data', 'crops', 'automatic', 'interactive', "
             "which retain the original inputs / extracted crops / predictions of automatic segmentation / "
             "predictions of interactive segmentation, respectively."
    )
    parser.add_argument(
        "--evaluate", type=str, default="all", choices=["all", "automatic", "interactive"],
        help="The choice of evaluation methods. By default, 'all' runs all evaluations. "
             "'automatic' runs automatic segmentation only, and 'interactive' runs interactive segmentation "
             "(starting from box and single point) with iterative prompting."
    )
    args = parser.parse_args()

    run_benchmark_evaluations(
        input_folder=args.input_folder,
        dataset_choice=args.dataset_choice,
        model_type=args.model_type,
        output_folder=args.output_folder,
        checkpoint_path=args.checkpoint_path,
        run_amg=args.amg,
        retain=args.retain,
        evaluation_methods=args.evaluate,
        ignore_warnings=True,
    )


if __name__ == "__main__":
    main()
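
The crops extracted by `_extract_slices_from_dataset` end up under `roi_{ndim}d/inputs` and `roi_{ndim}d/labels` inside each dataset's download folder, and `_get_image_label_paths` simply globs these directories. A minimal sketch of the resulting lookup, assuming a hypothetical download root:

    import os
    from glob import glob
    from natsort import natsorted

    data_path = "/data/benchmarks/livecell"  # hypothetical dataset folder
    image_paths = natsorted(glob(os.path.join(data_path, "roi_2d", "inputs", "*")))
    gt_paths = natsorted(glob(os.path.join(data_path, "roi_2d", "labels", "*")))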
LM_2D_DATASETS =
['livecell', 'deepbacs', 'tissuenet', 'neurips_cellseg', 'cellpose', 'dynamicnuclearnet', 'orgasegment', 'yeaz', 'arvidsson', 'bitdepth_nucseg', 'cellbindb', 'covid_if', 'deepseas', 'hpa', 'ifnuclei', 'lizard', 'organoidnet', 'toiam', 'vicar']
LM_3D_DATASETS =
['plantseg_root', 'plantseg_ovules', 'gonuclear', 'mouse_embryo', 'cellseg_3d']
EM_2D_DATASETS =
['mitolab_tem']
EM_3D_DATASETS =
['mitoem_rat', 'mitoem_human', 'platynereis_nuclei', 'lucchi', 'mitolab_3d', 'nuc_mm_mouse', 'nuc_mm_zebrafish', 'uro_cell', 'sponge_em', 'platynereis_cilia', 'vnc', 'asem_mito']
DATASET_RETURNS_FOLDER =
{'deepbacs': '*.tif'}
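
For the datasets listed here, the underlying `get_*_paths` function returns a folder rather than individual filepaths; `_extract_slices_from_dataset` expands the folder with the stored glob pattern. A short sketch under that assumption (the folders below are hypothetical):

    import os
    from glob import glob

    pattern = DATASET_RETURNS_FOLDER["deepbacs"]  # "*.tif"
    image_paths = glob(os.path.join("/data/deepbacs/images", pattern))  # hypothetical folder
    gt_paths = glob(os.path.join("/data/deepbacs/labels", pattern))  # hypothetical folder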
DATASET_CONTAINER_KEYS =
{'tissuenet': ['raw/rgb', 'labels/cell'], 'covid_if': ['raw/serum_IgG/s0', 'labels/cells/s0'], 'dynamicnuclearnet': ['raw', 'labels'], 'hpa': [['raw/protein', 'raw/microtubules', 'raw/er'], 'labels'], 'lizard': ['image', 'labels/segmentation'], 'plantseg_root': ['raw', 'label'], 'plantseg_ovules': ['raw', 'label'], 'gonuclear': ['raw/nuclei', 'labels/nuclei'], 'mouse_embryo': ['raw', 'label'], 'cellseg_3d': [None, None], 'lucchi': ['raw', 'labels']}
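
Each entry maps a dataset name to the internal [image_key, label_key] pair that is passed to `micro_sam.util.load_image_data` when the inputs live in a container file (for 'hpa', the image entry is a list of channel keys that get stacked). A minimal sketch of this lookup; the container filename below is a hypothetical example:

    from micro_sam import util

    image_key, label_key = DATASET_CONTAINER_KEYS["covid_if"]
    image = util.load_image_data("covid_if_image_01.h5", image_key)  # hypothetical file
    labels = util.load_image_data("covid_if_image_01.h5", label_key)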
def run_benchmark_evaluations(
    input_folder: Union[os.PathLike, str],
    dataset_choice: str,
    model_type: str = 'vit_b_lm',
    output_folder: Optional[Union[os.PathLike, str]] = None,
    checkpoint_path: Optional[Union[os.PathLike, str]] = None,
    run_amg: bool = False,
    retain: Optional[List[str]] = None,
    evaluation_methods: Literal['all', 'automatic', 'interactive'] = 'all',
    ignore_warnings: bool = False
):
Run evaluation for benchmarking Segment Anything models on microscopy datasets.
Arguments:
- input_folder: The path to the directory where all inputs will be stored and preprocessed.
- dataset_choice: The choice of dataset(s) to run the benchmarks on.
- model_type: The model choice for SAM.
- output_folder: The path to the directory where all outputs will be stored.
- checkpoint_path: The filepath to the model checkpoint.
- run_amg: Whether to run automatic segmentation in AMG mode.
- retain: Which parts of the benchmark runs to retain. By default, removes everything besides quantitative results. There is the choice to retain 'data', 'crops', 'automatic', or 'interactive'.
- evaluation_methods: The choice of evaluation methods. By default, runs 'all' evaluation methods (i.e. both 'automatic' and 'interactive'). Otherwise, specify either 'automatic' or 'interactive' for a specific evaluation run.
- ignore_warnings: Whether to ignore warnings.
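
A minimal usage sketch, assuming the input and output directories below exist (all paths are hypothetical):

    from micro_sam.evaluation.benchmark_datasets import run_benchmark_evaluations

    run_benchmark_evaluations(
        input_folder="/data/benchmarks",       # hypothetical download root
        dataset_choice="livecell",
        model_type="vit_b_lm",
        output_folder="/results/benchmarks",   # hypothetical output root
        retain=["crops"],                      # keep the extracted crops for inspection
    )

Since the module provides a `main` entry point, an equivalent run can be started from the shell:

    python -m micro_sam.evaluation.benchmark_datasets -i /data/benchmarks -o /results/benchmarks -d livecell -m vit_b_lm --retain crops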