micro_sam.evaluation.benchmark_datasets

  1import os
  2import time
  3from glob import glob
  4from tqdm import tqdm
  5from natsort import natsorted
  6from typing import Union, Optional, List, Literal
  7
  8import numpy as np
  9import pandas as pd
 10import imageio.v3 as imageio
 11from bioimage_cpp.segmentation import label as connected_components
 12from bioimage_cpp.utils import Blocking
 13
 14import torch
 15
 16from torch_em.data import datasets
 17
 18from micro_sam import util
 19
 20from . import run_evaluation
 21from ..training.training import _filter_warnings
 22from .inference import run_inference_with_iterative_prompting
 23from .evaluation import run_evaluation_for_iterative_prompting
 24from .multi_dimensional_segmentation import segment_slices_from_ground_truth
 25from ..automatic_segmentation import (
 26    automatic_instance_segmentation, get_predictor_and_segmenter, DEFAULT_SEGMENTATION_MODE_WITH_DECODER,
 27)
 28
 29
 30LM_2D_DATASETS = [
 31    # in-domain
 32    "livecell",  # cell segmentation in PhC (has a TEST-set)
 33    "deepbacs",  # bacteria segmentation in label-free microscopy (has a TEST-set),
 34    "tissuenet",  # cell segmentation in tissue microscopy images (has a TEST-set),
 35    "neurips_cellseg",  # cell segmentation in various (has a TEST-set),
 36    "cellpose",  # cell segmentation in FM (has 'cyto2' on which we can TEST on),
 37    "dynamicnuclearnet",  # nuclei segmentation in FM (has a TEST-set)
 38    "orgasegment",  # organoid segmentation in BF (has a TEST-set)
 39    "yeaz",  # yeast segmentation in BF (has a TEST-set)
 40
 41    # out-of-domain
 42    "arvidsson",  # nuclei segmentation in HCS FM
 43    "bitdepth_nucseg",  # nuclei segmentation in FM
 44    "cellbindb",  # cell segmentation in various microscopy
 45    "covid_if",  # cell segmentation in IF
 46    "deepseas",  # cell segmentation in PhC,
 47    "hpa",  # cell segmentation in confocal,
 48    "ifnuclei",  # nuclei segmentation in IFM
 49    "lizard",  # nuclei segmentation in H&E histopathology,
 50    "organoidnet",  # organoid segmentation in BF
 51    "toiam",  # microbial cell segmentation in PhC
 52    "vicar",  # cell segmentation in label-free
 53]
 54
 55LM_3D_DATASETS = [
 56    # in-domain
 57    "plantseg_root",  # cell segmentation in lightsheet (has a TEST-set)
 58
 59    # out-of-domain
 60    "plantseg_ovules",  # cell segmentation in confocal
 61    "gonuclear",  # nuclei segmentation in FM
 62    "mouse_embryo",  # cell segmentation in lightsheet
 63    "cellseg3d",  # nuclei segmentation in FM
 64]
 65
 66EM_2D_DATASETS = ["mitolab_tem"]
 67
 68EM_3D_DATASETS = [
 69    # out-of-domain
 70    "lucchi",  # mitochondria segmentation in vEM
 71    "mitolab",  # mitochondria segmentation in various
 72    "uro_cell",  # mitochondria segmentation (and other organelles) in FIB-SEM
 73    "sponge_em",  # microvili segmentation (and other organelles) in sponge chamber vEM
 74    "vnc",  # mitochondria segmentation in drosophila brain TEM
 75    "nuc_mm_mouse",  # nuclei segmentation in microCT
 76    "num_mm_zebrafish",  # nuclei segmentation in EM
 77    "platynereis_cilia",  # cilia segmentation (and other structures) in platynereis larvae vEM
 78    "asem_mito",  # mitochondria segmentation (and other organelles) in FIB-SEM
 79]
 80
 81DATASET_RETURNS_FOLDER = {
 82    "deepbacs": "*.tif",
 83    "mitolab_tem": "*.tiff"
 84}
 85
 86DATASET_CONTAINER_KEYS = {
 87    # 2d (LM)
 88    "tissuenet": ["raw/rgb", "labels/cell"],
 89    "covid_if": ["raw/serum_IgG/s0", "labels/cells/s0"],
 90    "dynamicnuclearnet": ["raw", "labels"],
 91    "hpa": [["raw/protein", "raw/microtubules", "raw/er"], "labels"],
 92    "lizard": ["image", "labels/segmentation"],
 93
 94    # 3d (LM)
 95    "plantseg_root": ["raw", "label"],
 96    "plantseg_ovules": ["raw", "label"],
 97    "gonuclear": ["raw/nuclei", "labels/nuclei"],
 98    "mouse_embryo": ["raw", "label"],
 99    "cellseg_3d": [None, None],
100
101    # 3d (EM)
102    "lucchi": ["raw", "labels"],
103    "uro_cell": ["raw", "labels/mito"],
104    "mitolab_3d": [None, None],
105    "sponge_em": ["volumes/raw", "volumes/labels/instances"],
106    "vnc": ["raw", "labels/mitochondria"]
107}
108
109
def _download_benchmark_datasets(path, dataset_choice):
    """Ensures that the requested benchmark datasets have been downloaded.

    All requested names are validated before the first download is triggered,
    so that a misspelled name does not fail only after other datasets were fetched.

    Args:
        path: The path to directory where the supported datasets will be downloaded
            for benchmarking Segment Anything models.
        dataset_choice: The choice of dataset(s), expects the lower case name for the dataset.
            Either a single name, a list / tuple of names, or None to select all supported datasets.

    Returns:
        List of choice of dataset(s).

    Raises:
        ValueError: If a requested dataset name is not supported.
    """
    # Each entry is a zero-argument callable, so that nothing is downloaded until it is invoked.
    available_datasets = {
        # Light Microscopy datasets

        # 2d datasets: in-domain
        "livecell": lambda: datasets.livecell.get_livecell_data(
            path=os.path.join(path, "livecell"), download=True,
        ),
        "deepbacs": lambda: datasets.deepbacs.get_deepbacs_data(
            path=os.path.join(path, "deepbacs"), bac_type="mixed", download=True,
        ),
        "tissuenet": lambda: datasets.tissuenet.get_tissuenet_data(
            path=os.path.join(path, "tissuenet"), split="test", download=True,
        ),
        "neurips_cellseg": lambda: datasets.neurips_cell_seg.get_neurips_cellseg_data(
            root=os.path.join(path, "neurips_cellseg"), split="test", download=True,
        ),
        "cellpose": lambda: datasets.cellpose.get_cellpose_data(
            path=os.path.join(path, "cellpose"), split="train", choice="cyto2", download=True,
        ),
        "dynamicnuclearnet": lambda: datasets.dynamicnuclearnet.get_dynamicnuclearnet_data(
            path=os.path.join(path, "dynamicnuclearnet"), split="test", download=True,
        ),
        "orgasegment": lambda: datasets.orgasegment.get_orgasegment_data(
            path=os.path.join(path, "orgasegment"), split="eval", download=True,
        ),
        "yeaz": lambda: datasets.yeaz.get_yeaz_data(
            path=os.path.join(path, "yeaz"), choice="bf", download=True,
        ),

        # 2d datasets: out-of-domain
        "arvidsson": lambda: datasets.arvidsson.get_arvidsson_data(
            path=os.path.join(path, "arvidsson"), split="test", download=True,
        ),
        "bitdepth_nucseg": lambda: datasets.bitdepth_nucseg.get_bitdepth_nucseg_data(
            path=os.path.join(path, "bitdepth_nucseg"), download=True,
        ),
        "cellbindb": lambda: datasets.cellbindb.get_cellbindb_data(
            path=os.path.join(path, "cellbindb"), download=True,
        ),
        "covid_if": lambda: datasets.covid_if.get_covid_if_data(
            path=os.path.join(path, "covid_if"), download=True,
        ),
        # NOTE(review): this is the only entry without 'download=True' - confirm that this is intended.
        "deepseas": lambda: datasets.deepseas.get_deepseas_data(
            path=os.path.join(path, "deepseas"), split="test",
        ),
        "hpa": lambda: datasets.hpa.get_hpa_segmentation_data(
            path=os.path.join(path, "hpa"), download=True,
        ),
        "ifnuclei": lambda: datasets.ifnuclei.get_ifnuclei_data(
            path=os.path.join(path, "ifnuclei"), download=True,
        ),
        "lizard": lambda: datasets.lizard.get_lizard_data(
            path=os.path.join(path, "lizard"), split="test", download=True,
        ),
        "organoidnet": lambda: datasets.organoidnet.get_organoidnet_data(
            path=os.path.join(path, "organoidnet"), split="Test", download=True,
        ),
        "toiam": lambda: datasets.toiam.get_toiam_data(
            path=os.path.join(path, "toiam"), download=True,
        ),
        "vicar": lambda: datasets.vicar.get_vicar_data(
            path=os.path.join(path, "vicar"), download=True,
        ),

        # 3d datasets: in-domain
        "plantseg_root": lambda: datasets.plantseg.get_plantseg_data(
            path=os.path.join(path, "plantseg_root"), split="test", download=True, name="root",
        ),

        # 3d datasets: out-of-domain
        "plantseg_ovules": lambda: datasets.plantseg.get_plantseg_data(
            path=os.path.join(path, "plantseg_ovules"), split="test", download=True, name="ovules",
        ),
        "gonuclear": lambda: datasets.gonuclear.get_gonuclear_data(
            path=os.path.join(path, "gonuclear"), download=True,
        ),
        "mouse_embryo": lambda: datasets.mouse_embryo.get_mouse_embryo_data(
            path=os.path.join(path, "mouse_embryo"), download=True,
        ),
        "cellseg_3d": lambda: datasets.cellseg_3d.get_cellseg_3d_data(
            path=os.path.join(path, "cellseg_3d"), download=True,
        ),

        # Electron Microscopy datasets

        # 2d datasets: out-of-domain
        # NOTE: 'mitolab_tem' and 'mitolab_3d' share the same 'mitolab' folder.
        "mitolab_tem": lambda: datasets.cem.get_benchmark_data(
            path=os.path.join(path, "mitolab"), dataset_id=7, download=True
        ),

        # 3d datasets: out-of-domain
        "lucchi": lambda: datasets.lucchi.get_lucchi_data(
            path=os.path.join(path, "lucchi"), split="test", download=True,
        ),
        "mitolab_3d": lambda: [
            datasets.cem.get_benchmark_data(
                path=os.path.join(path, "mitolab"), dataset_id=dataset_id, download=True,
            ) for dataset_id in range(1, 7)
        ],
        "uro_cell": lambda: datasets.uro_cell.get_uro_cell_data(
            path=os.path.join(path, "uro_cell"), download=True,
        ),
        "vnc": lambda: datasets.vnc.get_vnc_data(
            path=os.path.join(path, "vnc"), download=True,
        ),
        "sponge_em": lambda: datasets.sponge_em.get_sponge_em_data(
            path=os.path.join(path, "sponge_em"), download=True,
        ),
        "nuc_mm_mouse": lambda: datasets.nuc_mm.get_nuc_mm_data(
            path=os.path.join(path, "nuc_mm"), sample="mouse", download=True,
        ),
        "nuc_mm_zebrafish": lambda: datasets.nuc_mm.get_nuc_mm_data(
            path=os.path.join(path, "nuc_mm"), sample="zebrafish", download=True,
        ),
        "asem_mito": lambda: datasets.asem.get_asem_data(
            path=os.path.join(path, "asem"), volume_ids=datasets.asem.ORGANELLES["mito"], download=True,
        ),
        "platynereis_cilia": lambda: datasets.platynereis.get_platynereis_data(
            path=os.path.join(path, "platynereis"), name="cilia", download=True,
        ),
    }

    # Normalize the choice(s) to a list of dataset names.
    if dataset_choice is None:
        dataset_choice = list(available_datasets)
    elif isinstance(dataset_choice, tuple):
        dataset_choice = list(dataset_choice)
    elif not isinstance(dataset_choice, list):
        dataset_choice = [dataset_choice]

    # Reject unsupported names up-front, before any (potentially long) download starts.
    for choice in dataset_choice:
        if choice not in available_datasets:
            raise ValueError(f"'{choice}' is not a supported choice of dataset.")

    for choice in dataset_choice:
        available_datasets[choice]()

    return dataset_choice
256
257
def _extract_slices_from_dataset(path, dataset_choice, crops_per_input=10):
    """Extracts crops of desired shapes for performing evaluation in both 2d and 3d using `micro-sam`.

    The crops are stored as tif files in 'roi_<ndim>d/inputs' and 'roi_<ndim>d/labels' under `path`.
    For 3d datasets, 2d crops of the individual slices are stored in addition in 'roi_2d'.
    If all target folders exist already, the extraction is skipped.

    Args:
        path: The path to directory where the supported datasets have be downloaded
            for benchmarking Segment Anything models.
        dataset_choice: The name of the dataset of choice to extract crops.
        crops_per_input: The maximum number of crops to extract per inputs.

    Returns:
        The number of dimensions supported by the input.
    """
    ndim = 2 if dataset_choice in [*LM_2D_DATASETS, *EM_2D_DATASETS] else 3
    tile_shape = (512, 512) if ndim == 2 else (32, 512, 512)

    # For 3d inputs, we extract both 2d and 3d crops.
    extract_2d_crops_from_volumes = (ndim == 3)

    available_datasets = {
        # Light Microscopy datasets

        # 2d: in-domain
        "livecell": lambda: datasets.livecell.get_livecell_paths(path=path, split="test"),
        "deepbacs": lambda: datasets.deepbacs.get_deepbacs_paths(path=path, split="test", bac_type="mixed"),
        "tissuenet": lambda: datasets.tissuenet.get_tissuenet_paths(path=path, split="test"),
        "neurips_cellseg": lambda: datasets.neurips_cell_seg.get_neurips_cellseg_paths(root=path, split="test"),
        "cellpose": lambda: datasets.cellpose.get_cellpose_paths(path=path, split="train", choice="cyto2"),
        "dynamicnuclearnet": lambda: datasets.dynamicnuclearnet.get_dynamicnuclearnet_paths(path=path, split="test"),
        "orgasegment": lambda: datasets.orgasegment.get_orgasegment_paths(path=path, split="eval"),
        "yeaz": lambda: datasets.yeaz.get_yeaz_paths(path=path, choice="bf", split="test"),

        # 2d: out-of-domain
        "arvidsson": lambda: datasets.arvidsson.get_arvidsson_paths(path=path, split="test"),
        "bitdepth_nucseg": lambda: datasets.bitdepth_nucseg.get_bitdepth_nucseg_paths(path=path, magnification="20x"),
        "cellbindb": lambda: datasets.cellbindb.get_cellbindb_paths(
            path=path, data_choice=["10×Genomics_DAPI", "DAPI", "mIF"]
        ),
        "covid_if": lambda: datasets.covid_if.get_covid_if_paths(path=path),
        "deepseas": lambda: datasets.deepseas.get_deepseas_paths(path=path, split="test"),
        "hpa": lambda: datasets.hpa.get_hpa_segmentation_paths(path=path, split="val"),
        "ifnuclei": lambda: datasets.ifnuclei.get_ifnuclei_paths(path=path),
        "lizard": lambda: datasets.lizard.get_lizard_paths(path=path, split="test"),
        "organoidnet": lambda: datasets.organoidnet.get_organoidnet_paths(path=path, split="Test"),
        "toiam": lambda: datasets.toiam.get_toiam_paths(path=path),
        "vicar": lambda: datasets.vicar.get_vicar_paths(path=path),

        # 3d: in-domain
        "plantseg_root": lambda: datasets.plantseg.get_plantseg_paths(path=path, name="root", split="test"),

        # 3d: out-of-domain
        "plantseg_ovules": lambda: datasets.plantseg.get_plantseg_paths(path=path, name="ovules", split="test"),
        "gonuclear": lambda: datasets.gonuclear.get_gonuclear_paths(path=path),
        "mouse_embryo": lambda: datasets.mouse_embryo.get_mouse_embryo_paths(path=path, name="nuclei", split="val"),
        "cellseg_3d": lambda: datasets.cellseg_3d.get_cellseg_3d_paths(path=path),

        # Electron Microscopy datasets

        # 2d: out-of-domain
        "mitolab_tem": lambda: datasets.cem.get_benchmark_paths(
            path=os.path.join(os.path.dirname(path), "mitolab"), dataset_id=7
        )[:2],

        # 3d: out-of-domain
        "lucchi": lambda: datasets.lucchi.get_lucchi_paths(path=path, split="test"),
        "platynereis_cilia": lambda: datasets.platynereis.get_platynereis_paths(path, sample_ids=None, name="cilia"),
        "uro_cell": lambda: datasets.uro_cell.get_uro_cell_paths(path=path, target="mito"),
        "vnc": lambda: datasets.vnc.get_vnc_mito_paths(path=path),
        "sponge_em": lambda: datasets.sponge_em.get_sponge_em_paths(path=path, sample_ids=None),
        "mitolab_3d": lambda: (
            [
                datasets.cem.get_benchmark_paths(
                    path=os.path.join(os.path.dirname(path), "mitolab"), dataset_id=i
                )[0] for i in range(1, 7)
            ],
            [
                datasets.cem.get_benchmark_paths(
                    path=os.path.join(os.path.dirname(path), "mitolab"), dataset_id=i
                )[1] for i in range(1, 7)
            ]
        ),
        "nuc_mm_mouse": lambda: datasets.nuc_mm.get_nuc_mm_paths(path=path, sample="mouse", split="val"),
        "nuc_mm_zebrafish": lambda: datasets.nuc_mm.get_nuc_mm_paths(path=path, sample="zebrafish", split="val"),
        "asem_mito": lambda: datasets.asem.get_asem_paths(path=path, volume_ids=datasets.asem.ORGANELLES["mito"])
    }

    # Datasets either come as separate image and label files, or as container files which hold both
    # (and are then read with the keys in 'DATASET_CONTAINER_KEYS').
    has_separate_label_files = (
        (ndim == 2 and dataset_choice not in DATASET_CONTAINER_KEYS) or dataset_choice in ["cellseg_3d", "mitolab_3d"]
    )

    if has_separate_label_files:
        image_paths, gt_paths = available_datasets[dataset_choice]()

        if dataset_choice in DATASET_RETURNS_FOLDER:
            image_paths = glob(os.path.join(image_paths, DATASET_RETURNS_FOLDER[dataset_choice]))
            gt_paths = glob(os.path.join(gt_paths, DATASET_RETURNS_FOLDER[dataset_choice]))

        image_paths, gt_paths = natsorted(image_paths), natsorted(gt_paths)
        assert len(image_paths) == len(gt_paths)

        paths_set = zip(image_paths, gt_paths)

    else:
        image_paths = available_datasets[dataset_choice]()
        if isinstance(image_paths, str):
            paths_set = [image_paths]
        else:
            paths_set = natsorted(image_paths)

    # Directory where we store the extracted ROIs.
    save_image_dir = [os.path.join(path, f"roi_{ndim}d", "inputs")]
    save_gt_dir = [os.path.join(path, f"roi_{ndim}d", "labels")]
    if extract_2d_crops_from_volumes:
        save_image_dir.append(os.path.join(path, "roi_2d", "inputs"))
        save_gt_dir.append(os.path.join(path, "roi_2d", "labels"))

    # The existence of all target folders is treated as the marker for a finished extraction.
    _dir_exists = [os.path.exists(idir) and os.path.exists(gdir) for idir, gdir in zip(save_image_dir, save_gt_dir)]
    if all(_dir_exists):
        return ndim

    for target_dir in (*save_image_dir, *save_gt_dir):
        os.makedirs(target_dir, exist_ok=True)

    # Logic to extract relevant patches for inference
    image_counter = 1
    for per_paths in tqdm(paths_set, desc=f"Extracting {ndim}d patches for {dataset_choice}"):
        if has_separate_label_files:
            image_path, gt_path = per_paths
            image, gt = util.load_image_data(image_path), util.load_image_data(gt_path)

        else:
            # NOTE(review): datasets which end up here need an entry in 'DATASET_CONTAINER_KEYS',
            # otherwise the lookups below raise a KeyError.
            image_path = per_paths
            gt = util.load_image_data(image_path, DATASET_CONTAINER_KEYS[dataset_choice][1])
            if dataset_choice == "hpa":
                # Get inputs per channel and stack them together to make the desired 3 channel image.
                image = np.stack(
                    [util.load_image_data(image_path, k) for k in DATASET_CONTAINER_KEYS[dataset_choice][0]], axis=0,
                )
                # Resize inputs to desired tile shape, in favor of working with the shape of foreground.
                from torch_em.transform.generic import ResizeLongestSideInputs
                raw_transform = ResizeLongestSideInputs(target_shape=tile_shape, is_rgb=True)
                label_transform = ResizeLongestSideInputs(target_shape=tile_shape, is_label=True)
                image, gt = raw_transform(image).transpose(1, 2, 0), label_transform(gt)

            else:
                image = util.load_image_data(image_path, DATASET_CONTAINER_KEYS[dataset_choice][0])

        if dataset_choice in ["tissuenet", "lizard"]:
            if image.ndim == 3 and image.shape[0] == 3:  # Make channels last for RGB-style images.
                image = image.transpose(1, 2, 0)

        # Allow RGBs to stay as it is with channels last
        if image.ndim == 3 and image.shape[-1] == 3:
            skip_smaller_shape = (np.array(image.shape) >= np.array((*tile_shape, 3))).all()
        else:
            skip_smaller_shape = (np.array(image.shape) >= np.array(tile_shape)).all()

        # Ensure ground truth has instance labels.
        gt = connected_components(gt)

        if len(np.unique(gt)) == 1:  # There could be labels which does not have any annotated foreground.
            continue

        # Let's extract and save all the crops.
        # The first round of extraction is always to match the desired input dimensions.
        image_crops, gt_crops = _get_crops_for_input(image, gt, ndim, tile_shape, skip_smaller_shape, crops_per_input)
        image_counter = _save_image_label_crops(
            image_crops, gt_crops, dataset_choice, ndim, image_counter, save_image_dir[0], save_gt_dir[0]
        )

        # The next round of extraction is to get 2d crops from 3d inputs.
        if extract_2d_crops_from_volumes:
            curr_tile_shape = tile_shape[1:]  # We expect 2d tile shape for this stage.

            curr_image_crops, curr_gt_crops = [], []
            for per_z_im, per_z_gt in zip(image, gt):
                curr_skip_smaller_shape = (np.array(per_z_im.shape) >= np.array(curr_tile_shape)).all()

                per_z_image_crops, per_z_gt_crops = _get_crops_for_input(
                    image=per_z_im, gt=per_z_gt, ndim=2,
                    tile_shape=curr_tile_shape,
                    skip_smaller_shape=curr_skip_smaller_shape,
                    crops_per_input=crops_per_input,
                )
                curr_image_crops.extend(per_z_image_crops)
                curr_gt_crops.extend(per_z_gt_crops)

            image_counter = _save_image_label_crops(
                curr_image_crops, curr_gt_crops, dataset_choice, 2, image_counter, save_image_dir[1], save_gt_dir[1]
            )

    return ndim
447
448
def _get_crops_for_input(image, gt, ndim, tile_shape, skip_smaller_shape, crops_per_input):
    """Selects the tiles of an input which contain the highest number of distinct label values.

    The labels are partitioned into blocks of `tile_shape`. The blocks are then visited in
    descending order of the number of unique values in their label crop.

    Args:
        image: The input image, which is cropped with the same slices as `gt`.
        gt: The labels. Their shape defines the tiling.
        ndim: The number of dimensions of the tiling.
        tile_shape: The shape of the individual tiles.
        skip_smaller_shape: Whether to drop crops which do not match the tile shape
            (extended by a trailing channel axis of size 3 for channels-last RGB images).
        crops_per_input: The number of ranked candidates to visit before the selection stops,
            given that at least one crop was selected. Dropped candidates count as visited.

    Returns:
        The list of selected image crops.
        The list of corresponding label crops.
    """
    blocking = Blocking([0] * ndim, gt.shape, tile_shape)
    boxes = []
    for block_id in range(blocking.number_of_blocks):
        block = blocking.get_block(block_id)
        boxes.append(tuple(slice(start, stop) for start, stop in zip(block.begin, block.end)))

    # Rank the tiles by their number of unique label values (ties: higher tile index first).
    ranked_boxes = sorted(
        ((len(np.unique(gt[box])), box_id) for box_id, box in enumerate(boxes)), reverse=True
    )

    selected_images, selected_labels = [], []
    for n_visited, (n_label_values, box_id) in enumerate(ranked_boxes, start=1):
        box = boxes[box_id]
        image_patch, label_patch = image[box], gt[box]

        # Crops at the border of the input can be smaller than the tile shape.
        is_rgb = image.ndim == 3 and image.shape[-1] == 3
        expected_shape = (*tile_shape, 3) if is_rgb else tile_shape
        if skip_smaller_shape and image_patch.shape != expected_shape:
            continue

        # A single label value means that this crop (and all lower ranked ones) holds no annotation.
        if n_label_values == 1:
            break

        selected_images.append(image_patch)
        selected_labels.append(label_patch)

        # Stop as soon as the desired number of candidates was visited.
        if selected_images and n_visited >= crops_per_input:
            break

    return selected_images, selected_labels
482
483
def _save_image_label_crops(image_crops, gt_crops, dataset_choice, ndim, image_counter, save_image_dir, save_gt_dir):
    """Writes pairs of image and label crops as zlib-compressed tif files.

    Image and label crop of a pair share the same file name, '<dataset_choice>_<counter>.tif',
    and are stored in `save_image_dir` and `save_gt_dir`, respectively.

    Args:
        image_crops: The image crops.
        gt_crops: The corresponding label crops.
        dataset_choice: The name of the dataset, used as prefix of the file names.
        ndim: The number of dimensions of the crops, only used for the progress bar description.
        image_counter: The running index used for the file name of the first crop.
        save_image_dir: The folder where the image crops are stored.
        save_gt_dir: The folder where the label crops are stored.

    Returns:
        The running index to use for the next crop.
    """
    crop_pairs = tqdm(
        zip(image_crops, gt_crops), total=len(image_crops), desc=f"Saving {ndim}d crops for {dataset_choice}"
    )
    for image_crop, gt_crop in crop_pairs:
        fname = f"{dataset_choice}_{image_counter:05}.tif"

        # For channels-last RGB images only the spatial axes have to match the labels.
        is_rgb = image_crop.ndim == 3 and image_crop.shape[-1] == 3
        spatial_shape = image_crop.shape[:2] if is_rgb else image_crop.shape
        assert spatial_shape == gt_crop.shape

        for folder, crop in ((save_image_dir, image_crop), (save_gt_dir, gt_crop)):
            imageio.imwrite(os.path.join(folder, fname), crop, compression="zlib")

        image_counter += 1

    return image_counter
501
502
def _get_image_label_paths(path, ndim):
    """Returns the naturally sorted filepaths to the extracted image and label crops in 'roi_<ndim>d'."""
    roi_dir = os.path.join(path, f"roi_{ndim}d")
    image_paths, gt_paths = (
        natsorted(glob(os.path.join(roi_dir, subfolder, "*"))) for subfolder in ("inputs", "labels")
    )
    return image_paths, gt_paths
507
508
def _run_automatic_segmentation_per_dataset(
    image_paths: List[Union[os.PathLike, str]],
    gt_paths: List[Union[os.PathLike, str]],
    model_type: str,
    output_folder: Union[os.PathLike, str],
    ndim: Optional[int] = None,
    device: Optional[Union[torch.device, str]] = None,
    checkpoint_path: Optional[Union[os.PathLike, str]] = None,
    segmentation_mode: Optional[Literal["amg", "ais", "apg"]] = "ais",
    **auto_seg_kwargs
):
    """Runs and evaluates automatic segmentation for a list of input files.

    The predictions are stored in '<output_folder>/<segmentation_mode>_<ndim>d/inference' and the
    evaluation results in '<output_folder>/results/<segmentation_mode>_<ndim>d.csv'.
    Nothing is computed if the result file exists already, and inputs with an existing prediction are skipped.

    Args:
        image_paths: List of filepaths for the input image data.
        gt_paths: List of filepaths for the corresponding label data.
        model_type: The choice of image encoder for the Segment Anything model.
        output_folder: Filepath to the folder where we store all the results.
        ndim: The number of input dimensions.
        device: The torch device.
        checkpoint_path: The filepath where the model checkpoints are stored.
        segmentation_mode: The mode for automatic segmentation. If None, the mode is derived from the model state:
            the default mode for models with decoder if a decoder state is found, otherwise 'amg'.
        auto_seg_kwargs: Additional arguments for automatic segmentation parameters.
    """
    if segmentation_mode is None:
        # The presence of a decoder state in the checkpoint decides which mode is supported.
        _, model_state = util.get_sam_model(
            model_type=model_type, checkpoint_path=checkpoint_path, device=device, return_state=True
        )
        if "decoder_state" in model_state:
            segmentation_mode = DEFAULT_SEGMENTATION_MODE_WITH_DECODER
        else:
            segmentation_mode = "amg"

    run_name = f"{segmentation_mode}_{ndim}d"

    # The result file marks a finished run.
    result_path = os.path.join(output_folder, "results", f"{run_name}.csv")
    if os.path.exists(result_path):
        return

    prediction_dir = os.path.join(output_folder, run_name, "inference")
    os.makedirs(prediction_dir, exist_ok=True)

    # The predictor and the segmenter which matches the segmentation mode.
    predictor, segmenter = get_predictor_and_segmenter(
        model_type=model_type, checkpoint=checkpoint_path, device=device,
        segmentation_mode=segmentation_mode, is_tiled=False,
    )

    for image_path in tqdm(image_paths, desc=f"Run {segmentation_mode} in {ndim}d"):
        output_path = os.path.join(prediction_dir, os.path.basename(image_path))

        # Only segment the inputs which have no stored prediction yet.
        if not os.path.exists(output_path):
            automatic_instance_segmentation(
                predictor=predictor,
                segmenter=segmenter,
                input_path=image_path,
                output_path=output_path,
                ndim=ndim,
                verbose=False,
                **auto_seg_kwargs
            )

    # Evaluate everything which is stored in the prediction folder against the labels.
    prediction_paths = natsorted(glob(os.path.join(prediction_dir, "*")))
    run_evaluation(gt_paths=gt_paths, prediction_paths=prediction_paths, save_path=result_path)
573
574
def _run_interactive_segmentation_per_dataset(
    image_paths: List[Union[os.PathLike, str]],
    gt_paths: List[Union[os.PathLike, str]],
    output_folder: Union[os.PathLike, str],
    model_type: str,
    prompt_choice: Literal["box", "points"],
    device: Optional[Union[torch.device, str]] = None,
    ndim: Optional[int] = None,
    checkpoint_path: Optional[Union[os.PathLike, str]] = None,
    use_masks: bool = False,
):
    """Functionality to run interactive segmentation for multiple input files at once.
    It stores the evaluated interactive segmentation results.

    For `ndim=2`, iterative prompting is run and evaluated for all inputs.
    For any other value of `ndim`, the inputs are segmented slice-by-slice based on prompts derived
    from the ground-truth, and the per-volume results are averaged and stored in
    '<output_folder>/results/interactive_segmentation_3d_with_<prompt_choice>.csv'.

    Args:
        image_paths: List of filepaths for the input image data.
        gt_paths: List of filepaths for the corresponding label data.
        output_folder: Filepath to the folder where we store all the results.
        model_type: The choice of model type for Segment Anything.
        prompt_choice: The choice of initial prompts to begin the interactive segmentation.
        device: The torch device.
        ndim: The number of input dimensions.
        checkpoint_path: The filepath for stored checkpoints.
        use_masks: Whether to use masks for iterative prompting. Only used for `ndim=2`.
    """
    if ndim == 2:
        # Get the Segment Anything predictor.
        predictor = util.get_sam_model(model_type=model_type, device=device, checkpoint_path=checkpoint_path)

        prediction_root = os.path.join(
            output_folder, "interactive_segmentation_2d", f"start_with_{prompt_choice}",
            "iterative_prompting_" + ("with_masks" if use_masks else "without_masks")
        )
        start_with_box_prompt = (prompt_choice == "box")

        # Run interactive instance segmentation
        # (starting with box and points followed by iterative prompt-based correction)
        run_inference_with_iterative_prompting(
            predictor=predictor,
            image_paths=image_paths,
            gt_paths=gt_paths,
            embedding_dir=None,  # We set this to None to compute embeddings on-the-fly.
            prediction_dir=prediction_root,
            start_with_box_prompt=start_with_box_prompt,
            use_masks=use_masks,
            # TODO: add parameter for deform over box prompts (to simulate prompts in practice).
        )

        # Evaluate the interactive instance segmentation.
        run_evaluation_for_iterative_prompting(
            gt_paths=gt_paths,
            prediction_root=prediction_root,
            experiment_folder=output_folder,
            start_with_box_prompt=start_with_box_prompt,
            use_masks=use_masks,
        )

    else:
        save_path = os.path.join(output_folder, "results", f"interactive_segmentation_3d_with_{prompt_choice}.csv")
        if os.path.exists(save_path):
            print(
                f"Results for 3d interactive segmentation with '{prompt_choice}' are already stored at '{save_path}'."
            )
            return

        # The prediction folder is the same for all volumes, so we only need to set it up once.
        prediction_dir = os.path.join(output_folder, "interactive_segmentation_3d", f"{prompt_choice}")
        os.makedirs(prediction_dir, exist_ok=True)

        results = []
        for image_path, gt_path in tqdm(
            zip(image_paths, gt_paths), total=len(image_paths),
            desc=f"Run interactive segmentation in 3d with '{prompt_choice}'"
        ):
            # NOTE: Volumes with an existing prediction are skipped and hence do not contribute to the results.
            prediction_path = os.path.join(prediction_dir, os.path.basename(image_path))
            if os.path.exists(prediction_path):
                continue

            per_vol_result = segment_slices_from_ground_truth(
                volume=imageio.imread(image_path),
                ground_truth=imageio.imread(gt_path),
                model_type=model_type,
                checkpoint_path=checkpoint_path,
                save_path=prediction_path,
                device=device,
                interactive_seg_mode=prompt_choice,
                min_size=10,
            )
            results.append(per_vol_result)

        # 'pd.concat' fails for an empty list, which happens if there are no inputs
        # or if all predictions exist already (e.g. for a resumed run).
        if not results:
            print(
                f"No volumes were segmented for 3d interactive segmentation with '{prompt_choice}', "
                "because there are no inputs or all predictions exist already. No results are stored."
            )
            return

        results = pd.concat(results)
        results = results.groupby(results.index).mean()

        # Make sure that the folder for the results exists before writing to it.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        results.to_csv(save_path)
666
667
def _run_benchmark_evaluation_series(
    image_paths, gt_paths, model_type, output_folder, ndim, device, checkpoint_path,
    segmentation_mode, evaluation_methods,
):
    """Runs the automatic and / or interactive segmentation evaluation for one set of inputs.

    Args:
        image_paths: List of filepaths for the input image data.
        gt_paths: List of filepaths for the corresponding label data.
        model_type: The choice of model type for Segment Anything.
        output_folder: Filepath to the folder where we store all the results.
        ndim: The number of input dimensions.
        device: The torch device.
        checkpoint_path: The filepath for stored checkpoints.
        segmentation_mode: An explicitly requested automatic segmentation mode, which is evaluated
            in addition to the mode supported by the model. If None, only the latter is evaluated.
        evaluation_methods: The evaluations to run. The automatic segmentation is skipped for 'interactive',
            the interactive segmentation is skipped for 'automatic'. Any other value runs both.
    """
    seg_kwargs = {
        "image_paths": image_paths,
        "gt_paths": gt_paths,
        "output_folder": output_folder,
        "ndim": ndim,
        "model_type": model_type,
        "device": device,
        "checkpoint_path": checkpoint_path,
    }

    # Perform:
    # a. automatic segmentation (supported in both 2d and 3d, wherever relevant)
    #    The automatic segmentation steps below are configured in a way that the mode supported by
    #    the model has priority (the mode with decoder if a decoder is found, otherwise AMG).
    #    Next, we check if the user requested a specific segmentation mode as well.

    if evaluation_methods != "interactive":  # Avoid auto. seg. evaluation for 'interactive'-only run choice.
        # i. Run automatic segmentation method supported with the SAM model.
        _run_automatic_segmentation_per_dataset(segmentation_mode=None, **seg_kwargs)

        # ii. Run the explicitly requested segmentation mode.
        #     NOTE: Without a requested mode this would only repeat step i. (including loading the model state
        #     to derive the mode), and stop at the results stored by it. Hence, we skip it in this case.
        if segmentation_mode is not None:
            _run_automatic_segmentation_per_dataset(segmentation_mode=segmentation_mode, **seg_kwargs)

    if evaluation_methods != "automatic":  # Avoid int. seg. evaluation for 'automatic'-only run choice.
        # b. Run interactive segmentation (supported in both 2d and 3d, wherever relevant)
        _run_interactive_segmentation_per_dataset(prompt_choice="box", **seg_kwargs)
        _run_interactive_segmentation_per_dataset(prompt_choice="box", use_masks=True, **seg_kwargs)
        _run_interactive_segmentation_per_dataset(prompt_choice="points", **seg_kwargs)
        _run_interactive_segmentation_per_dataset(prompt_choice="points", use_masks=True, **seg_kwargs)
702
703
704def _clear_cached_items(retain, path, output_folder):
705    import shutil
706    from pathlib import Path
707
708    REMOVE_LIST = ["data", "crops", "automatic", "interactive"]
709    if retain is None:
710        remove_list = REMOVE_LIST
711    else:
712        assert isinstance(retain, list)
713        remove_list = set(REMOVE_LIST) - set(retain)
714
715    paths = []
716    # Stage 1: Remove inputs.
717    if "data" in remove_list or "crops" in remove_list:
718        all_paths = glob(os.path.join(path, "*"))
719
720        # In case we want to remove both data and crops, we remove the data folder entirely.
721        if "data" in remove_list and "crops" in remove_list:
722            paths.extend(all_paths)
723            return
724
725        # Next, we verify whether the we only remove either of data or crops.
726        for curr_path in all_paths:
727            if os.path.basename(curr_path).startswith("roi") and "crops" in remove_list:
728                paths.append(curr_path)
729            elif "data" in remove_list:
730                paths.append(curr_path)
731
732    # Stage 2: Remove predictions
733    if "automatic" in remove_list:
734        paths.extend(glob(os.path.join(output_folder, "amg_*")))
735        paths.extend(glob(os.path.join(output_folder, "ais_*")))
736
737    if "interactive" in remove_list:
738        paths.extend(glob(os.path.join(output_folder, "interactive_segmentation_*")))
739
740    [shutil.rmtree(_path) if Path(_path).is_dir() else os.remove(_path) for _path in paths]
741
742
def run_benchmark_evaluations(
    input_folder: Union[os.PathLike, str],
    dataset_choice: str,
    model_type: str = util._DEFAULT_MODEL,
    output_folder: Optional[Union[str, os.PathLike]] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    segmentation_mode: Optional[Literal["amg", "ais", "apg"]] = None,
    retain: Optional[List[str]] = None,
    evaluation_methods: Literal["all", "automatic", "interactive"] = "all",
    ignore_warnings: bool = False,
):
    """Run evaluation for benchmarking Segment Anything models on microscopy datasets.

    Args:
        input_folder: The path to directory where all inputs will be stored and preprocessed.
        dataset_choice: The dataset choice.
        model_type: The model choice for SAM.
        output_folder: The path to directory where all outputs will be stored.
            Each dataset gets its own sub-folder '<output_folder>/<dataset>'.
        checkpoint_path: The checkpoint path.
        segmentation_mode: The segmentation mode. One of 'amg', 'ais', or 'apg'.
        retain: The parts of the benchmark runs to retain.
            By default, removes everything besides quantitative results.
            There is the choice to retain 'data', 'crops', 'automatic', or 'interactive'.
        evaluation_methods: The choice of evaluation methods.
            By default, runs 'all' evaluation methods (i.e. both 'automatic' and 'interactive').
            Otherwise, specify either 'automatic' / 'interactive' for specific evaluation runs.
        ignore_warnings: Whether to ignore warnings.

    Raises:
        TypeError: If `output_folder` is None.
    """
    if output_folder is None:
        # Joining None with the dataset name fails with a TypeError further below anyway.
        # Raise it up-front, before any dataset is downloaded or preprocessed.
        raise TypeError("'output_folder' must be the path to a directory for storing the outputs, got None.")

    start = time.time()

    with _filter_warnings(ignore_warnings):
        device = util._get_default_device()

        # Ensure if all the datasets have been installed by default.
        dataset_choices = _download_benchmark_datasets(path=input_folder, dataset_choice=dataset_choice)

        for choice in dataset_choices:
            # Always derive the per-dataset folder from the root 'output_folder'. Rebinding 'output_folder'
            # itself would nest the folder of each dataset inside the folder of the previous one.
            dataset_output_folder = os.path.join(output_folder, choice)
            os.makedirs(os.path.join(dataset_output_folder, "results"), exist_ok=True)

            data_path = os.path.join(input_folder, choice)

            # Extrapolate desired set from the datasets:
            # a. for 2d datasets - 2d patches with the most number of labels present
            #    (in case of volumetric data, choose 2d patches per slice).
            # b. for 3d datasets - 3d regions of interest with the most number of labels present.
            ndim = _extract_slices_from_dataset(path=data_path, dataset_choice=choice, crops_per_input=10)

            # Run inference and evaluation scripts on benchmark datasets.
            # Volumetric datasets are additionally evaluated on their '2d' crops.
            for eval_ndim in ((ndim, 2) if ndim == 3 else (ndim,)):
                image_paths, gt_paths = _get_image_label_paths(path=data_path, ndim=eval_ndim)
                _run_benchmark_evaluation_series(
                    image_paths=image_paths,
                    gt_paths=gt_paths,
                    model_type=model_type,
                    output_folder=dataset_output_folder,
                    ndim=eval_ndim,
                    device=device,
                    checkpoint_path=checkpoint_path,
                    segmentation_mode=segmentation_mode,
                    evaluation_methods=evaluation_methods,
                )

            _clear_cached_items(retain=retain, path=data_path, output_folder=dataset_output_folder)

    diff = time.time() - start
    hours, rest = divmod(diff, 3600)
    minutes, seconds = divmod(rest, 60)
    print("Time taken for running benchmarks: ", f"{int(hours)}h {int(minutes)}m {int(seconds)}s")
827
828
def main():
    """@private"""
    # Command line entry point: parses the arguments and forwards them to 'run_benchmark_evaluations'.
    import argparse

    available_models = ", ".join(util.get_model_names())

    parser = argparse.ArgumentParser(
        description="Run evaluation for benchmarking Segment Anything models on microscopy datasets."
    )
    parser.add_argument(
        "-i", "--input_folder", type=str, required=True,
        help="The path to a directory where the microscopy datasets are and/or will be stored."
    )
    parser.add_argument(
        "-m", "--model_type", type=str, default=util._DEFAULT_MODEL,
        help=f"The segment anything model that will be used, one of '{available_models}'. By default, "
        f"it uses'{util._DEFAULT_MODEL}'."
    )
    parser.add_argument(
        "-c", "--checkpoint_path", type=str, default=None,
        help="The filepath to checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "-d", "--dataset_choice", type=str, nargs='*', default=None,
        help="The choice(s) of dataset for evaluating SAM models. Multiple datasets can be specified. "
        "By default, it evaluates on all datasets."
    )
    parser.add_argument(
        "-o", "--output_folder", type=str, required=True,
        help="The path where the results for (automatic and interactive) instance segmentation results "
        "will be stored as 'csv' files."
    )
    parser.add_argument(
        "--amg", action="store_true",
        help="Whether to run automatic segmentation in AMG mode (i.e. the default auto-seg approach for SAM)."
    )
    parser.add_argument(
        "--retain", nargs="*", default=None,
        help="By default, the functionality removes all besides quantitative results required for running benchmarks. "
        "In case you would like to retain parts of the benchmark evaluation for visualization / reproducibility, "
        "you should choose one or multiple of 'data', 'crops', 'automatic', 'interactive'. "
        "where they are responsible for either retaining original inputs / extracted crops / "
        "predictions of automatic segmentation / predictions of interactive segmentation, respectively."
    )
    parser.add_argument(
        "--evaluate", type=str, default="all", choices=["all", "automatic", "interactive"],
        help="The choice of methods for benchmarking evaluation for reproducibility. "
        "By default, we run all evaluations with 'all'. If 'automatic' is chosen, it runs automatic segmentation only "
        "/ 'interactive' runs interactive segmentation (starting from box and single point) with iterative prompting."
    )
    args = parser.parse_args()

    # The parser only defines the '--amg' flag; there is no 'segmentation_mode' attribute on 'args'.
    # The flag requests the additional AMG run, without it no explicit mode is passed on.
    segmentation_mode = "amg" if args.amg else None

    run_benchmark_evaluations(
        input_folder=args.input_folder,
        dataset_choice=args.dataset_choice,
        model_type=args.model_type,
        output_folder=args.output_folder,
        checkpoint_path=args.checkpoint_path,
        segmentation_mode=segmentation_mode,
        retain=args.retain,
        evaluation_methods=args.evaluate,
        ignore_warnings=True,
    )


if __name__ == "__main__":
    main()
# Dataset names, grouped by the 'LM' / 'EM' prefix and the '2D' / '3D' suffix of the list name.
LM_2D_DATASETS = [
    "livecell",
    "deepbacs",
    "tissuenet",
    "neurips_cellseg",
    "cellpose",
    "dynamicnuclearnet",
    "orgasegment",
    "yeaz",
    "arvidsson",
    "bitdepth_nucseg",
    "cellbindb",
    "covid_if",
    "deepseas",
    "hpa",
    "ifnuclei",
    "lizard",
    "organoidnet",
    "toiam",
    "vicar",
]
LM_3D_DATASETS = [
    "plantseg_root",
    "plantseg_ovules",
    "gonuclear",
    "mouse_embryo",
    "cellseg3d",
]
EM_2D_DATASETS = [
    "mitolab_tem",
]
# NOTE(review): 'num_mm_zebrafish' next to 'nuc_mm_mouse' looks like a typo - confirm before relying on the name.
EM_3D_DATASETS = [
    "lucchi",
    "mitolab",
    "uro_cell",
    "sponge_em",
    "vnc",
    "nuc_mm_mouse",
    "num_mm_zebrafish",
    "platynereis_cilia",
    "asem_mito",
]
# Dataset name -> file pattern (presumably the glob pattern of the files in the returned folder - TODO confirm).
DATASET_RETURNS_FOLDER = {
    "deepbacs": "*.tif",
    "mitolab_tem": "*.tiff",
}
# Dataset name -> two entries (presumably the container keys of the raw data and of the labels - TODO confirm).
# NOTE(review): 'cellseg_3d' and 'mitolab_3d' differ from 'cellseg3d' and 'mitolab' in the lists above -
# confirm that lookups use matching names.
DATASET_CONTAINER_KEYS = {
    "tissuenet": ["raw/rgb", "labels/cell"],
    "covid_if": ["raw/serum_IgG/s0", "labels/cells/s0"],
    "dynamicnuclearnet": ["raw", "labels"],
    "hpa": [["raw/protein", "raw/microtubules", "raw/er"], "labels"],
    "lizard": ["image", "labels/segmentation"],
    "plantseg_root": ["raw", "label"],
    "plantseg_ovules": ["raw", "label"],
    "gonuclear": ["raw/nuclei", "labels/nuclei"],
    "mouse_embryo": ["raw", "label"],
    "cellseg_3d": [None, None],
    "lucchi": ["raw", "labels"],
    "uro_cell": ["raw", "labels/mito"],
    "mitolab_3d": [None, None],
    "sponge_em": ["volumes/raw", "volumes/labels/instances"],
    "vnc": ["raw", "labels/mitochondria"],
}
def run_benchmark_evaluations( input_folder: Union[os.PathLike, str], dataset_choice: str, model_type: str = 'vit_b_lm', output_folder: Union[os.PathLike, str, NoneType] = None, checkpoint_path: Union[os.PathLike, str, NoneType] = None, segmentation_mode: Optional[Literal['amg', 'ais', 'apg']] = None, retain: Optional[List[str]] = None, evaluation_methods: Literal['all', 'automatic', 'interactive'] = 'all', ignore_warnings: bool = False):
def run_benchmark_evaluations(
    input_folder: Union[os.PathLike, str],
    dataset_choice: str,
    model_type: str = util._DEFAULT_MODEL,
    output_folder: Optional[Union[str, os.PathLike]] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    segmentation_mode: Optional[Literal["amg", "ais", "apg"]] = None,
    retain: Optional[List[str]] = None,
    evaluation_methods: Literal["all", "automatic", "interactive"] = "all",
    ignore_warnings: bool = False,
):
    """Run evaluation for benchmarking Segment Anything models on microscopy datasets.

    Args:
        input_folder: The path to directory where all inputs will be stored and preprocessed.
        dataset_choice: The dataset choice.
        model_type: The model choice for SAM.
        output_folder: The path to directory where all outputs will be stored.
            Each dataset gets its own sub-folder '<output_folder>/<dataset>'.
        checkpoint_path: The checkpoint path.
        segmentation_mode: The segmentation mode. One of 'amg', 'ais', or 'apg'.
        retain: The parts of the benchmark runs to retain.
            By default, removes everything besides quantitative results.
            There is the choice to retain 'data', 'crops', 'automatic', or 'interactive'.
        evaluation_methods: The choice of evaluation methods.
            By default, runs 'all' evaluation methods (i.e. both 'automatic' and 'interactive').
            Otherwise, specify either 'automatic' / 'interactive' for specific evaluation runs.
        ignore_warnings: Whether to ignore warnings.

    Raises:
        TypeError: If `output_folder` is None.
    """
    if output_folder is None:
        # Joining None with the dataset name fails with a TypeError further below anyway.
        # Raise it up-front, before any dataset is downloaded or preprocessed.
        raise TypeError("'output_folder' must be the path to a directory for storing the outputs, got None.")

    start = time.time()

    with _filter_warnings(ignore_warnings):
        device = util._get_default_device()

        # Ensure if all the datasets have been installed by default.
        dataset_choices = _download_benchmark_datasets(path=input_folder, dataset_choice=dataset_choice)

        for choice in dataset_choices:
            # Always derive the per-dataset folder from the root 'output_folder'. Rebinding 'output_folder'
            # itself would nest the folder of each dataset inside the folder of the previous one.
            dataset_output_folder = os.path.join(output_folder, choice)
            os.makedirs(os.path.join(dataset_output_folder, "results"), exist_ok=True)

            data_path = os.path.join(input_folder, choice)

            # Extrapolate desired set from the datasets:
            # a. for 2d datasets - 2d patches with the most number of labels present
            #    (in case of volumetric data, choose 2d patches per slice).
            # b. for 3d datasets - 3d regions of interest with the most number of labels present.
            ndim = _extract_slices_from_dataset(path=data_path, dataset_choice=choice, crops_per_input=10)

            # Run inference and evaluation scripts on benchmark datasets.
            # Volumetric datasets are additionally evaluated on their '2d' crops.
            for eval_ndim in ((ndim, 2) if ndim == 3 else (ndim,)):
                image_paths, gt_paths = _get_image_label_paths(path=data_path, ndim=eval_ndim)
                _run_benchmark_evaluation_series(
                    image_paths=image_paths,
                    gt_paths=gt_paths,
                    model_type=model_type,
                    output_folder=dataset_output_folder,
                    ndim=eval_ndim,
                    device=device,
                    checkpoint_path=checkpoint_path,
                    segmentation_mode=segmentation_mode,
                    evaluation_methods=evaluation_methods,
                )

            _clear_cached_items(retain=retain, path=data_path, output_folder=dataset_output_folder)

    diff = time.time() - start
    hours, rest = divmod(diff, 3600)
    minutes, seconds = divmod(rest, 60)
    print("Time taken for running benchmarks: ", f"{int(hours)}h {int(minutes)}m {int(seconds)}s")

Run evaluation for benchmarking Segment Anything models on microscopy datasets.

Arguments:
  • input_folder: The path to directory where all inputs will be stored and preprocessed.
  • dataset_choice: The dataset choice.
  • model_type: The model choice for SAM.
  • output_folder: The path to directory where all outputs will be stored.
  • checkpoint_path: The checkpoint path
  • segmentation_mode: The segmentation mode. One of 'amg', 'ais', or 'apg'.
  • retain: The parts of the benchmark runs to retain. By default, everything besides the quantitative results is removed. Choose one or more of 'data', 'crops', 'automatic', or 'interactive'.
  • evaluation_methods: The choice of evaluation methods. By default, runs 'all' evaluation methods (i.e. both 'automatic' and 'interactive'). Otherwise, specify either 'automatic' / 'interactive' for specific evaluation runs.
  • ignore_warnings: Whether to ignore warnings.