micro_sam.evaluation.benchmark_datasets

  1import os
  2import time
  3from glob import glob
  4from tqdm import tqdm
  5from natsort import natsorted
  6from typing import Union, Optional, List, Literal
  7
  8import numpy as np
  9import pandas as pd
 10import imageio.v3 as imageio
 11from skimage.measure import label as connected_components
 12
 13from nifty.tools import blocking
 14
 15import torch
 16
 17from torch_em.data import datasets
 18
 19from micro_sam import util
 20
 21from . import run_evaluation
 22from ..training.training import _filter_warnings
 23from .inference import run_inference_with_iterative_prompting
 24from .evaluation import run_evaluation_for_iterative_prompting
 25from .multi_dimensional_segmentation import segment_slices_from_ground_truth
 26from ..automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter
 27
 28
 29LM_2D_DATASETS = [
 30    # in-domain
 31    "livecell",  # cell segmentation in PhC (has a TEST-set)
 32    "deepbacs",  # bacteria segmentation in label-free microscopy (has a TEST-set),
 33    "tissuenet",  # cell segmentation in tissue microscopy images (has a TEST-set),
 34    "neurips_cellseg",  # cell segmentation in various (has a TEST-set),
 35    "cellpose",  # cell segmentation in FM (has 'cyto2' on which we can TEST on),
 36    "dynamicnuclearnet",  # nuclei segmentation in FM (has a TEST-set)
 37    "orgasegment",  # organoid segmentation in BF (has a TEST-set)
 38    "yeaz",  # yeast segmentation in BF (has a TEST-set)
 39
 40    # out-of-domain
 41    "arvidsson",  # nuclei segmentation in HCS FM
 42    "bitdepth_nucseg",  # nuclei segmentation in FM
 43    "cellbindb",  # cell segmentation in various microscopy
 44    "covid_if",  # cell segmentation in IF
 45    "deepseas",  # cell segmentation in PhC,
 46    "hpa",  # cell segmentation in confocal,
 47    "ifnuclei",  # nuclei segmentation in IFM
 48    "lizard",  # nuclei segmentation in H&E histopathology,
 49    "organoidnet",  # organoid segmentation in BF
 50    "toiam",  # microbial cell segmentation in PhC
 51    "vicar",  # cell segmentation in label-free
 52]
 53
 54LM_3D_DATASETS = [
 55    # in-domain
 56    "plantseg_root",  # cell segmentation in lightsheet (has a TEST-set)
 57
 58    # out-of-domain
 59    "plantseg_ovules",  # cell segmentation in confocal
 60    "gonuclear",  # nuclei segmentation in FM
 61    "mouse_embryo",  # cell segmentation in lightsheet
 62    "cellseg3d",  # nuclei segmentation in FM
 63]
 64
 65EM_2D_DATASETS = ["mitolab_tem"]
 66
 67EM_3D_DATASETS = [
 68    # out-of-domain
 69    "lucchi",  # mitochondria segmentation in vEM
 70    "mitolab",  # mitochondria segmentation in various
 71    "uro_cell",  # mitochondria segmentation (and other organelles) in FIB-SEM
 72    "sponge_em",  # microvili segmentation (and other organelles) in sponge chamber vEM
 73    "vnc",  # mitochondria segmentation in drosophila brain TEM
 74    "nuc_mm_mouse",  # nuclei segmentation in microCT
 75    "num_mm_zebrafish",  # nuclei segmentation in EM
 76    "platynereis_cilia",  # cilia segmentation (and other structures) in platynereis larvae vEM
 77    "asem_mito",  # mitochondria segmentation (and other organelles) in FIB-SEM
 78]
 79
 80DATASET_RETURNS_FOLDER = {
 81    "deepbacs": "*.tif",
 82    "mitolab_tem": "*.tiff"
 83}
 84
 85DATASET_CONTAINER_KEYS = {
 86    # 2d (LM)
 87    "tissuenet": ["raw/rgb", "labels/cell"],
 88    "covid_if": ["raw/serum_IgG/s0", "labels/cells/s0"],
 89    "dynamicnuclearnet": ["raw", "labels"],
 90    "hpa": [["raw/protein", "raw/microtubules", "raw/er"], "labels"],
 91    "lizard": ["image", "labels/segmentation"],
 92
 93    # 3d (LM)
 94    "plantseg_root": ["raw", "label"],
 95    "plantseg_ovules": ["raw", "label"],
 96    "gonuclear": ["raw/nuclei", "labels/nuclei"],
 97    "mouse_embryo": ["raw", "label"],
 98    "cellseg_3d": [None, None],
 99
100    # 3d (EM)
101    "lucchi": ["raw", "labels"],
102    "uro_cell": ["raw", "labels/mito"],
103    "mitolab_3d": [None, None],
104    "sponge_em": ["volumes/raw", "volumes/labels/instances"],
105    "vnc": ["raw", "labels/mitochondria"]
106}
107
108
def _download_benchmark_datasets(path, dataset_choice):
    """Ensures that the requested benchmark datasets have been downloaded.

    Args:
        path: The path to directory where the supported datasets will be downloaded
            for benchmarking Segment Anything models.
        dataset_choice: The choice of dataset(s), expects the lower case name for the dataset.
            Either a single name, a list / tuple of names, or None to select all supported datasets.

    Returns:
        List of choice of dataset(s).

    Raises:
        ValueError: If a requested dataset is not supported. All choices are validated
            before the first download is triggered.
    """
    # Each entry defers the download call, so that only the selected datasets are fetched.
    available_datasets = {
        # Light Microscopy datasets

        # 2d datasets: in-domain
        "livecell": lambda: datasets.livecell.get_livecell_data(
            path=os.path.join(path, "livecell"), download=True,
        ),
        "deepbacs": lambda: datasets.deepbacs.get_deepbacs_data(
            path=os.path.join(path, "deepbacs"), bac_type="mixed", download=True,
        ),
        "tissuenet": lambda: datasets.tissuenet.get_tissuenet_data(
            path=os.path.join(path, "tissuenet"), split="test", download=True,
        ),
        "neurips_cellseg": lambda: datasets.neurips_cell_seg.get_neurips_cellseg_data(
            root=os.path.join(path, "neurips_cellseg"), split="test", download=True,
        ),
        "cellpose": lambda: datasets.cellpose.get_cellpose_data(
            path=os.path.join(path, "cellpose"), split="train", choice="cyto2", download=True,
        ),
        "dynamicnuclearnet": lambda: datasets.dynamicnuclearnet.get_dynamicnuclearnet_data(
            path=os.path.join(path, "dynamicnuclearnet"), split="test", download=True,
        ),
        "orgasegment": lambda: datasets.orgasegment.get_orgasegment_data(
            path=os.path.join(path, "orgasegment"), split="eval", download=True,
        ),
        "yeaz": lambda: datasets.yeaz.get_yeaz_data(
            path=os.path.join(path, "yeaz"), choice="bf", download=True,
        ),

        # 2d datasets: out-of-domain
        "arvidsson": lambda: datasets.arvidsson.get_arvidsson_data(
            path=os.path.join(path, "arvidsson"), split="test", download=True,
        ),
        "bitdepth_nucseg": lambda: datasets.bitdepth_nucseg.get_bitdepth_nucseg_data(
            path=os.path.join(path, "bitdepth_nucseg"), download=True,
        ),
        "cellbindb": lambda: datasets.cellbindb.get_cellbindb_data(
            path=os.path.join(path, "cellbindb"), download=True,
        ),
        "covid_if": lambda: datasets.covid_if.get_covid_if_data(
            path=os.path.join(path, "covid_if"), download=True,
        ),
        # NOTE(review): 'deepseas' is the only entry called without 'download=True' - confirm this is intended.
        "deepseas": lambda: datasets.deepseas.get_deepseas_data(
            path=os.path.join(path, "deepseas"), split="test",
        ),
        "hpa": lambda: datasets.hpa.get_hpa_segmentation_data(
            path=os.path.join(path, "hpa"), download=True,
        ),
        "ifnuclei": lambda: datasets.ifnuclei.get_ifnuclei_data(
            path=os.path.join(path, "ifnuclei"), download=True,
        ),
        "lizard": lambda: datasets.lizard.get_lizard_data(
            path=os.path.join(path, "lizard"), split="test", download=True,
        ),
        "organoidnet": lambda: datasets.organoidnet.get_organoidnet_data(
            path=os.path.join(path, "organoidnet"), split="Test", download=True,
        ),
        "toiam": lambda: datasets.toiam.get_toiam_data(
            path=os.path.join(path, "toiam"), download=True,
        ),
        "vicar": lambda: datasets.vicar.get_vicar_data(
            path=os.path.join(path, "vicar"), download=True,
        ),

        # 3d datasets: in-domain
        "plantseg_root": lambda: datasets.plantseg.get_plantseg_data(
            path=os.path.join(path, "plantseg_root"), split="test", download=True, name="root",
        ),

        # 3d datasets: out-of-domain
        "plantseg_ovules": lambda: datasets.plantseg.get_plantseg_data(
            path=os.path.join(path, "plantseg_ovules"), split="test", download=True, name="ovules",
        ),
        "gonuclear": lambda: datasets.gonuclear.get_gonuclear_data(
            path=os.path.join(path, "gonuclear"), download=True,
        ),
        "mouse_embryo": lambda: datasets.mouse_embryo.get_mouse_embryo_data(
            path=os.path.join(path, "mouse_embryo"), download=True,
        ),
        "cellseg_3d": lambda: datasets.cellseg_3d.get_cellseg_3d_data(
            path=os.path.join(path, "cellseg_3d"), download=True,
        ),

        # Electron Microscopy datasets

        # 2d datasets: out-of-domain
        # NOTE: 'mitolab_tem' and 'mitolab_3d' share the same 'mitolab' folder.
        "mitolab_tem": lambda: datasets.cem.get_benchmark_data(
            path=os.path.join(path, "mitolab"), dataset_id=7, download=True
        ),

        # 3d datasets: out-of-domain
        "lucchi": lambda: datasets.lucchi.get_lucchi_data(
            path=os.path.join(path, "lucchi"), split="test", download=True,
        ),
        "mitolab_3d": lambda: [
            datasets.cem.get_benchmark_data(
                path=os.path.join(path, "mitolab"), dataset_id=dataset_id, download=True,
            ) for dataset_id in range(1, 7)
        ],
        "uro_cell": lambda: datasets.uro_cell.get_uro_cell_data(
            path=os.path.join(path, "uro_cell"), download=True,
        ),
        "vnc": lambda: datasets.vnc.get_vnc_data(
            path=os.path.join(path, "vnc"), download=True,
        ),
        "sponge_em": lambda: datasets.sponge_em.get_sponge_em_data(
            path=os.path.join(path, "sponge_em"), download=True,
        ),
        "nuc_mm_mouse": lambda: datasets.nuc_mm.get_nuc_mm_data(
            path=os.path.join(path, "nuc_mm"), sample="mouse", download=True,
        ),
        "nuc_mm_zebrafish": lambda: datasets.nuc_mm.get_nuc_mm_data(
            path=os.path.join(path, "nuc_mm"), sample="zebrafish", download=True,
        ),
        "asem_mito": lambda: datasets.asem.get_asem_data(
            path=os.path.join(path, "asem"), volume_ids=datasets.asem.ORGANELLES["mito"], download=True,
        ),
        "platynereis_cilia": lambda: datasets.platynereis.get_platynereis_data(
            path=os.path.join(path, "platynereis"), name="cilia", download=True,
        ),
    }

    # Normalize the choice(s) to a list, so that the return value matches the documented type.
    if dataset_choice is None:
        dataset_choice = list(available_datasets)
    elif isinstance(dataset_choice, (list, tuple)):
        dataset_choice = list(dataset_choice)
    else:
        dataset_choice = [dataset_choice]

    # Validate all choices up-front: a typo in a later name should not cost the download of earlier datasets.
    for choice in dataset_choice:
        if choice not in available_datasets:
            raise ValueError(f"'{choice}' is not a supported choice of dataset.")

    for choice in dataset_choice:
        available_datasets[choice]()

    return dataset_choice
255
256
def _extract_slices_from_dataset(path, dataset_choice, crops_per_input=10):
    """Extracts crops of desired shapes for performing evaluation in both 2d and 3d using `micro-sam`.

    The crops are stored as tif files in `<path>/roi_<ndim>d/inputs` and `<path>/roi_<ndim>d/labels`.
    For 3d datasets, 2d crops are additionally extracted per slice and stored in
    `<path>/roi_2d/inputs` and `<path>/roi_2d/labels`.
    The extraction is skipped if all of these folders exist already.

    Args:
        path: The path to directory where the supported datasets have be downloaded
            for benchmarking Segment Anything models.
        dataset_choice: The name of the dataset of choice to extract crops.
        crops_per_input: The maximum number of crops to extract per inputs.

    Returns:
        The number of dimensions supported by the input.
    """
    ndim = 2 if dataset_choice in [*LM_2D_DATASETS, *EM_2D_DATASETS] else 3
    tile_shape = (512, 512) if ndim == 2 else (32, 512, 512)

    # For 3d inputs, we extract both 2d and 3d crops.
    extract_2d_crops_from_volumes = (ndim == 3)

    # Whether the dataset provides separate image and label files,
    # otherwise image and labels are loaded from the same file via the keys in 'DATASET_CONTAINER_KEYS'.
    has_paired_paths = (
        (ndim == 2 and dataset_choice not in DATASET_CONTAINER_KEYS) or dataset_choice in ["cellseg_3d", "mitolab_3d"]
    )

    available_datasets = {
        # Light Microscopy datasets

        # 2d: in-domain
        "livecell": lambda: datasets.livecell.get_livecell_paths(path=path, split="test"),
        "deepbacs": lambda: datasets.deepbacs.get_deepbacs_paths(path=path, split="test", bac_type="mixed"),
        "tissuenet": lambda: datasets.tissuenet.get_tissuenet_paths(path=path, split="test"),
        "neurips_cellseg": lambda: datasets.neurips_cell_seg.get_neurips_cellseg_paths(root=path, split="test"),
        "cellpose": lambda: datasets.cellpose.get_cellpose_paths(path=path, split="train", choice="cyto2"),
        "dynamicnuclearnet": lambda: datasets.dynamicnuclearnet.get_dynamicnuclearnet_paths(path=path, split="test"),
        "orgasegment": lambda: datasets.orgasegment.get_orgasegment_paths(path=path, split="eval"),
        "yeaz": lambda: datasets.yeaz.get_yeaz_paths(path=path, choice="bf", split="test"),

        # 2d: out-of-domain
        "arvidsson": lambda: datasets.arvidsson.get_arvidsson_paths(path=path, split="test"),
        "bitdepth_nucseg": lambda: datasets.bitdepth_nucseg.get_bitdepth_nucseg_paths(path=path, magnification="20x"),
        "cellbindb": lambda: datasets.cellbindb.get_cellbindb_paths(
            path=path, data_choice=["10×Genomics_DAPI", "DAPI", "mIF"]
        ),
        "covid_if": lambda: datasets.covid_if.get_covid_if_paths(path=path),
        "deepseas": lambda: datasets.deepseas.get_deepseas_paths(path=path, split="test"),
        "hpa": lambda: datasets.hpa.get_hpa_segmentation_paths(path=path, split="val"),
        "ifnuclei": lambda: datasets.ifnuclei.get_ifnuclei_paths(path=path),
        "lizard": lambda: datasets.lizard.get_lizard_paths(path=path, split="test"),
        "organoidnet": lambda: datasets.organoidnet.get_organoidnet_paths(path=path, split="Test"),
        "toiam": lambda: datasets.toiam.get_toiam_paths(path=path),
        "vicar": lambda: datasets.vicar.get_vicar_paths(path=path),

        # 3d: in-domain
        "plantseg_root": lambda: datasets.plantseg.get_plantseg_paths(path=path, name="root", split="test"),

        # 3d: out-of-domain
        "plantseg_ovules": lambda: datasets.plantseg.get_plantseg_paths(path=path, name="ovules", split="test"),
        "gonuclear": lambda: datasets.gonuclear.get_gonuclear_paths(path=path),
        "mouse_embryo": lambda: datasets.mouse_embryo.get_mouse_embryo_paths(path=path, name="nuclei", split="val"),
        "cellseg_3d": lambda: datasets.cellseg_3d.get_cellseg_3d_paths(path=path),

        # Electron Microscopy datasets

        # 2d: out-of-domain
        # NOTE: The mitolab data is looked up in the 'mitolab' folder next to 'path'.
        "mitolab_tem": lambda: datasets.cem.get_benchmark_paths(
            path=os.path.join(os.path.dirname(path), "mitolab"), dataset_id=7
        )[:2],

        # 3d: out-of-domain
        "lucchi": lambda: datasets.lucchi.get_lucchi_paths(path=path, split="test"),
        "platynereis_cilia": lambda: datasets.platynereis.get_platynereis_paths(path, sample_ids=None, name="cilia"),
        "uro_cell": lambda: datasets.uro_cell.get_uro_cell_paths(path=path, target="mito"),
        "vnc": lambda: datasets.vnc.get_vnc_mito_paths(path=path),
        "sponge_em": lambda: datasets.sponge_em.get_sponge_em_paths(path=path, sample_ids=None),
        "mitolab_3d": lambda: (
            [
                datasets.cem.get_benchmark_paths(
                    path=os.path.join(os.path.dirname(path), "mitolab"), dataset_id=i
                )[0] for i in range(1, 7)
            ],
            [
                datasets.cem.get_benchmark_paths(
                    path=os.path.join(os.path.dirname(path), "mitolab"), dataset_id=i
                )[1] for i in range(1, 7)
            ]
        ),
        "nuc_mm_mouse": lambda: datasets.nuc_mm.get_nuc_mm_paths(path=path, sample="mouse", split="val"),
        "nuc_mm_zebrafish": lambda: datasets.nuc_mm.get_nuc_mm_paths(path=path, sample="zebrafish", split="val"),
        "asem_mito": lambda: datasets.asem.get_asem_paths(path=path, volume_ids=datasets.asem.ORGANELLES["mito"])
    }

    if has_paired_paths:
        image_paths, gt_paths = available_datasets[dataset_choice]()

        if dataset_choice in DATASET_RETURNS_FOLDER:
            image_paths = glob(os.path.join(image_paths, DATASET_RETURNS_FOLDER[dataset_choice]))
            gt_paths = glob(os.path.join(gt_paths, DATASET_RETURNS_FOLDER[dataset_choice]))

        image_paths, gt_paths = natsorted(image_paths), natsorted(gt_paths)
        assert len(image_paths) == len(gt_paths), f"Found {len(image_paths)} images and {len(gt_paths)} labels."

        # Materialize the pairs, so that the progress bar below knows the total number of inputs.
        paths_set = list(zip(image_paths, gt_paths))

    else:
        image_paths = available_datasets[dataset_choice]()
        if isinstance(image_paths, str):
            paths_set = [image_paths]
        else:
            paths_set = natsorted(image_paths)

    # Directory where we store the extracted ROIs.
    save_image_dir = [os.path.join(path, f"roi_{ndim}d", "inputs")]
    save_gt_dir = [os.path.join(path, f"roi_{ndim}d", "labels")]
    if extract_2d_crops_from_volumes:
        save_image_dir.append(os.path.join(path, "roi_2d", "inputs"))
        save_gt_dir.append(os.path.join(path, "roi_2d", "labels"))

    # The crops are treated as extracted if all output folders exist already.
    _dir_exists = [os.path.exists(idir) and os.path.exists(gdir) for idir, gdir in zip(save_image_dir, save_gt_dir)]
    if all(_dir_exists):
        return ndim

    for save_dir in (*save_image_dir, *save_gt_dir):
        os.makedirs(save_dir, exist_ok=True)

    # Logic to extract relevant patches for inference
    image_counter = 1
    for per_paths in tqdm(paths_set, desc=f"Extracting {ndim}d patches for {dataset_choice}"):
        if has_paired_paths:
            image_path, gt_path = per_paths
            image, gt = util.load_image_data(image_path), util.load_image_data(gt_path)

        else:
            image_path = per_paths
            gt = util.load_image_data(image_path, DATASET_CONTAINER_KEYS[dataset_choice][1])
            if dataset_choice == "hpa":
                # Get inputs per channel and stack them together to make the desired 3 channel image.
                image = np.stack(
                    [util.load_image_data(image_path, k) for k in DATASET_CONTAINER_KEYS[dataset_choice][0]], axis=0,
                )
                # Resize inputs to desired tile shape, in favor of working with the shape of foreground.
                from torch_em.transform.generic import ResizeLongestSideInputs
                raw_transform = ResizeLongestSideInputs(target_shape=tile_shape, is_rgb=True)
                label_transform = ResizeLongestSideInputs(target_shape=tile_shape, is_label=True)
                image, gt = raw_transform(image).transpose(1, 2, 0), label_transform(gt)

            else:
                image = util.load_image_data(image_path, DATASET_CONTAINER_KEYS[dataset_choice][0])

        if dataset_choice in ["tissuenet", "lizard"]:
            if image.ndim == 3 and image.shape[0] == 3:  # Make channels last for RGB-style images.
                image = image.transpose(1, 2, 0)

        # Allow RGBs to stay as it is with channels last
        # NOTE: Crops smaller than the tile shape are only skipped if the input itself covers at least one tile.
        if image.ndim == 3 and image.shape[-1] == 3:
            skip_smaller_shape = (np.array(image.shape) >= np.array((*tile_shape, 3))).all()
        else:
            skip_smaller_shape = (np.array(image.shape) >= np.array(tile_shape)).all()

        # Ensure ground truth has instance labels.
        gt = connected_components(gt)

        if len(np.unique(gt)) == 1:  # There could be labels which does not have any annotated foreground.
            continue

        # Let's extract and save all the crops.
        # The first round of extraction is always to match the desired input dimensions.
        image_crops, gt_crops = _get_crops_for_input(image, gt, ndim, tile_shape, skip_smaller_shape, crops_per_input)
        image_counter = _save_image_label_crops(
            image_crops, gt_crops, dataset_choice, ndim, image_counter, save_image_dir[0], save_gt_dir[0]
        )

        # The next round of extraction is to get 2d crops from 3d inputs.
        if extract_2d_crops_from_volumes:
            curr_tile_shape = tile_shape[1:]  # We expect 2d tile shape for this stage.

            curr_image_crops, curr_gt_crops = [], []
            for per_z_im, per_z_gt in zip(image, gt):
                curr_skip_smaller_shape = (np.array(per_z_im.shape) >= np.array(curr_tile_shape)).all()

                image_crops, gt_crops = _get_crops_for_input(
                    image=per_z_im, gt=per_z_gt, ndim=2,
                    tile_shape=curr_tile_shape,
                    skip_smaller_shape=curr_skip_smaller_shape,
                    crops_per_input=crops_per_input,
                )
                curr_image_crops.extend(image_crops)
                curr_gt_crops.extend(gt_crops)

            # NOTE: The same running counter is used for the file names of the 3d and the 2d crops.
            image_counter = _save_image_label_crops(
                curr_image_crops, curr_gt_crops, dataset_choice, 2, image_counter, save_image_dir[1], save_gt_dir[1]
            )

    return ndim
446
447
def _get_crops_for_input(image, gt, ndim, tile_shape, skip_smaller_shape, crops_per_input):
    """Selects the tiles of the input with the highest number of distinct label values.

    Args:
        image: The input image data.
        gt: The corresponding label data. Its shape is used for the tiling.
        ndim: The number of dimensions used for the tiling.
        tile_shape: The shape of the tiles.
        skip_smaller_shape: Whether to skip crops which do not match the tile shape.
        crops_per_input: The number of top-ranked tiles to consider per input.

    Returns:
        The list of image crops.
        The list of corresponding label crops.
    """
    # Tile the label volume and convert each tile to a tuple of slices.
    tiling = blocking([0] * ndim, gt.shape, tile_shape)
    crop_boxes = []
    for tile_id in range(tiling.numberOfBlocks):
        tile = tiling.getBlock(tile_id)
        crop_boxes.append(tuple(slice(beg, end) for beg, end in zip(tile.begin, tile.end)))

    # Rank the tiles by their number of distinct label values (highest first, ties resolved by tile id).
    ranked_tiles = sorted(
        ((len(np.unique(gt[box])), box_id) for box_id, box in enumerate(crop_boxes)), reverse=True
    )

    # The expected crop shape, which has an additional channel axis for RGB images.
    is_rgb = image.ndim == 3 and image.shape[-1] == 3
    expected_shape = (*tile_shape, 3) if is_rgb else tile_shape

    image_crops, gt_crops = [], []
    for rank, (n_labels, box_id) in enumerate(ranked_tiles, start=1):
        box = crop_boxes[box_id]
        crop_image, crop_gt = image[box], gt[box]

        # NOTE: We avoid using the crops which do not match the desired tile shape.
        if skip_smaller_shape and crop_image.shape != expected_shape:
            continue

        # NOTE: A tile with a single label value ends the sampling, since all remaining tiles rank at most as high.
        if n_labels == 1:
            break

        image_crops.append(crop_image)
        gt_crops.append(crop_gt)

        # NOTE: The rank counts skipped tiles as well. At least one crop is stored at this point,
        # so we stop as soon as the desired number of tiles has been visited.
        if rank >= crops_per_input:
            break

    return image_crops, gt_crops
481
482
def _save_image_label_crops(image_crops, gt_crops, dataset_choice, ndim, image_counter, save_image_dir, save_gt_dir):
    """Writes the image and label crops as zlib-compressed tif files.

    The files are named `<dataset_choice>_<counter>.tif`, where the counter starts at `image_counter`.
    Image and label crops are stored under the same file name in their respective folders.

    Args:
        image_crops: The list of image crops.
        gt_crops: The list of corresponding label crops.
        dataset_choice: The name of the dataset, used as prefix for the file names.
        ndim: The number of dimensions of the crops, only used for the progress bar description.
        image_counter: The counter value used for the first file name.
        save_image_dir: The folder where the image crops are stored.
        save_gt_dir: The folder where the label crops are stored.

    Returns:
        The counter value to use for the next crop.
    """
    next_counter = image_counter
    crop_pairs = zip(image_crops, gt_crops)

    for image_crop, gt_crop in tqdm(
        crop_pairs, total=len(image_crops), desc=f"Saving {ndim}d crops for {dataset_choice}"
    ):
        # For RGB images (channels last) only the spatial axes are compared with the labels.
        has_rgb_channels = image_crop.ndim == 3 and image_crop.shape[-1] == 3
        spatial_shape = image_crop.shape[:2] if has_rgb_channels else image_crop.shape
        assert spatial_shape == gt_crop.shape

        fname = f"{dataset_choice}_{next_counter:05}.tif"
        for save_dir, crop in ((save_image_dir, image_crop), (save_gt_dir, gt_crop)):
            imageio.imwrite(os.path.join(save_dir, fname), crop, compression="zlib")

        next_counter += 1

    return next_counter
500
501
def _get_image_label_paths(path, ndim):
    """Collects the naturally sorted filepaths in `<path>/roi_<ndim>d/inputs` and `<path>/roi_<ndim>d/labels`.

    Args:
        path: The folder which contains the `roi_<ndim>d` folder.
        ndim: The number of dimensions, which selects the `roi_<ndim>d` folder.

    Returns:
        The list of filepaths found in the `inputs` folder.
        The list of filepaths found in the `labels` folder.
    """
    roi_dir = os.path.join(path, f"roi_{ndim}d")
    image_paths, gt_paths = (
        natsorted(glob(os.path.join(roi_dir, subfolder, "*"))) for subfolder in ("inputs", "labels")
    )
    return image_paths, gt_paths
506
507
def _run_automatic_segmentation_per_dataset(
    image_paths: List[Union[os.PathLike, str]],
    gt_paths: List[Union[os.PathLike, str]],
    model_type: str,
    output_folder: Union[os.PathLike, str],
    ndim: Optional[int] = None,
    device: Optional[Union[torch.device, str]] = None,
    checkpoint_path: Optional[Union[os.PathLike, str]] = None,
    run_amg: Optional[bool] = None,
    **auto_seg_kwargs
):
    """Runs and evaluates automatic segmentation for multiple input files.

    The predictions are stored in `<output_folder>/<mode>_<ndim>d/inference` and the evaluation
    results in `<output_folder>/results/<mode>_<ndim>d.csv`, where the mode is either `amg` or `ais`.
    Nothing is computed if the result file exists already, and inputs with an existing prediction are skipped.

    Args:
        image_paths: List of filepaths for the input image data.
        gt_paths: List of filepaths for the corresponding label data.
        model_type: The choice of image encoder for the Segment Anything model.
        output_folder: Filepath to the folder where we store all the results.
        ndim: The number of input dimensions.
        device: The torch device.
        checkpoint_path: The filepath where the model checkpoints are stored.
        run_amg: Whether to run automatic segmentation in AMG mode. If it is None or False,
            the mode is derived from the model state: AIS if it contains a decoder state, otherwise AMG.
        auto_seg_kwargs: Additional arguments for automatic segmentation parameters.
    """
    # Unless AMG is requested explicitly, we check whether the decoder state is available for AIS.
    # In that case AIS has priority, otherwise we fall back to AMG.
    if not run_amg:
        _, model_state = util.get_sam_model(
            model_type=model_type, checkpoint_path=checkpoint_path, device=device, return_state=True
        )
        run_amg = ("decoder_state" not in model_state)

    if run_amg:
        experiment_name = "AMG"
    else:
        experiment_name = "AIS"

    fname = f"{experiment_name.lower()}_{ndim}d"

    # The stored results mark this experiment as done.
    result_path = os.path.join(output_folder, "results", f"{fname}.csv")
    if os.path.exists(result_path):
        return

    prediction_dir = os.path.join(output_folder, fname, "inference")
    os.makedirs(prediction_dir, exist_ok=True)

    # Get the predictor (and the additional instance segmentation decoder, if available).
    predictor, segmenter = get_predictor_and_segmenter(
        model_type=model_type, checkpoint=checkpoint_path, device=device, amg=run_amg, is_tiled=False,
    )

    for input_path in tqdm(image_paths, desc=f"Run {experiment_name} in {ndim}d"):
        # The prediction is stored under the file name of the input.
        output_path = os.path.join(prediction_dir, os.path.basename(input_path))
        if not os.path.exists(output_path):
            # Run Automatic Segmentation (AMG and AIS)
            automatic_instance_segmentation(
                predictor=predictor,
                segmenter=segmenter,
                input_path=input_path,
                output_path=output_path,
                ndim=ndim,
                verbose=False,
                **auto_seg_kwargs
            )

    # Evaluate all predictions found in the prediction folder against the labels.
    prediction_paths = natsorted(glob(os.path.join(prediction_dir, "*")))
    run_evaluation(gt_paths=gt_paths, prediction_paths=prediction_paths, save_path=result_path)
574
575
def _run_interactive_segmentation_per_dataset(
    image_paths: List[Union[os.PathLike, str]],
    gt_paths: List[Union[os.PathLike, str]],
    output_folder: Union[os.PathLike, str],
    model_type: str,
    prompt_choice: Literal["box", "points"],
    device: Optional[Union[torch.device, str]] = None,
    ndim: Optional[int] = None,
    checkpoint_path: Optional[Union[os.PathLike, str]] = None,
    use_masks: bool = False,
):
    """Functionality to run interactive segmentation for multiple input files at once.
    It stores the evaluated interactive segmentation results.

    For `ndim=2`, iterative prompting is run and evaluated for all inputs.
    Otherwise, each volume is segmented from prompts derived from its ground-truth and the per-volume
    results are averaged and stored in `<output_folder>/results/interactive_segmentation_3d_with_<prompt_choice>.csv`.

    Args:
        image_paths: List of filepaths for the input image data.
        gt_paths: List of filepaths for the corresponding label data.
        output_folder: Filepath to the folder where we store all the results.
        model_type: The choice of model type for Segment Anything.
        prompt_choice: The choice of initial prompts to begin the interactive segmentation.
        device: The torch device.
        ndim: The number of input dimensions.
        checkpoint_path: The filepath for stored checkpoints.
        use_masks: Whether to use masks for iterative prompting. This is only used for `ndim=2`.
    """
    if ndim == 2:
        # Get the Segment Anything predictor.
        predictor = util.get_sam_model(model_type=model_type, device=device, checkpoint_path=checkpoint_path)

        prediction_root = os.path.join(
            output_folder, "interactive_segmentation_2d", f"start_with_{prompt_choice}",
            "iterative_prompting_" + ("with_masks" if use_masks else "without_masks")
        )

        # Run interactive instance segmentation
        # (starting with box and points followed by iterative prompt-based correction)
        run_inference_with_iterative_prompting(
            predictor=predictor,
            image_paths=image_paths,
            gt_paths=gt_paths,
            embedding_dir=None,  # We set this to None to compute embeddings on-the-fly.
            prediction_dir=prediction_root,
            start_with_box_prompt=(prompt_choice == "box"),
            use_masks=use_masks,
            # TODO: add parameter for deform over box prompts (to simulate prompts in practice).
        )

        # Evaluate the interactive instance segmentation.
        run_evaluation_for_iterative_prompting(
            gt_paths=gt_paths,
            prediction_root=prediction_root,
            experiment_folder=output_folder,
            start_with_box_prompt=(prompt_choice == "box"),
            use_masks=use_masks,
        )

    else:
        save_path = os.path.join(output_folder, "results", f"interactive_segmentation_3d_with_{prompt_choice}.csv")
        if os.path.exists(save_path):
            print(
                f"Results for 3d interactive segmentation with '{prompt_choice}' are already stored at '{save_path}'."
            )
            return

        # The prediction folder is the same for all volumes, so we create it once.
        prediction_dir = os.path.join(output_folder, "interactive_segmentation_3d", f"{prompt_choice}")
        os.makedirs(prediction_dir, exist_ok=True)

        results = []
        for image_path, gt_path in tqdm(
            zip(image_paths, gt_paths), total=len(image_paths),
            desc=f"Run interactive segmentation in 3d with '{prompt_choice}'"
        ):
            # NOTE: Volumes with an existing prediction are skipped and hence do not contribute to the results.
            prediction_path = os.path.join(prediction_dir, os.path.basename(image_path))
            if os.path.exists(prediction_path):
                continue

            per_vol_result = segment_slices_from_ground_truth(
                volume=imageio.imread(image_path),
                ground_truth=imageio.imread(gt_path),
                model_type=model_type,
                checkpoint_path=checkpoint_path,
                save_path=prediction_path,
                device=device,
                interactive_seg_mode=prompt_choice,
                min_size=10,
            )
            results.append(per_vol_result)

        # 'pd.concat' fails for an empty list, which happens if there are no inputs
        # or if the predictions for all volumes exist already.
        if not results:
            print(
                f"No new results for 3d interactive segmentation with '{prompt_choice}' were computed. "
                f"Skipping to store results at '{save_path}'."
            )
            return

        # Make sure that the result folder exists, e.g. if no other evaluation has created it before.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        # Average the per-volume results over all rows which share the same index value.
        results = pd.concat(results)
        results = results.groupby(results.index).mean()
        results.to_csv(save_path)
667
668
def _run_benchmark_evaluation_series(
    image_paths, gt_paths, model_type, output_folder, ndim, device, checkpoint_path, run_amg, evaluation_methods,
):
    """Run the automatic and / or interactive segmentation benchmarks for one set of inputs.

    Args:
        image_paths: The paths to the input images.
        gt_paths: The paths to the corresponding ground-truth labels.
        model_type: The model choice for SAM.
        output_folder: The folder where predictions and results are stored.
        ndim: The dimensionality of the inputs, forwarded to the per-dataset runners.
        device: The device to run inference on.
        checkpoint_path: The checkpoint path.
        run_amg: Forwarded as `run_amg` to the second automatic segmentation run.
        evaluation_methods: Either 'all', 'automatic' or 'interactive'.
            'interactive' skips the automatic runs and 'automatic' skips the interactive runs.
    """
    # Arguments shared by every automatic and interactive run below.
    shared_kwargs = dict(
        image_paths=image_paths,
        gt_paths=gt_paths,
        output_folder=output_folder,
        ndim=ndim,
        model_type=model_type,
        device=device,
        checkpoint_path=checkpoint_path,
    )

    run_automatic = evaluation_methods != "interactive"
    run_interactive = evaluation_methods != "automatic"

    # a. Automatic segmentation (supported in both 2d and 3d, wherever relevant).
    #    The first pass (run_amg=None) lets the runner choose the method supported by the model:
    #    AIS has priority (if a decoder is found), otherwise it runs AMG.
    #    The second pass forwards the user's 'run_amg' choice (by default 'False').
    if run_automatic:
        for amg_choice in (None, run_amg):
            _run_automatic_segmentation_per_dataset(run_amg=amg_choice, **shared_kwargs)

    # b. Interactive segmentation (supported in both 2d and 3d, wherever relevant).
    #    For each starting prompt we run once with the runner's default for 'use_masks'
    #    and once with 'use_masks=True'.
    if run_interactive:
        for prompt in ("box", "points"):
            for extra_kwargs in ({}, {"use_masks": True}):
                _run_interactive_segmentation_per_dataset(prompt_choice=prompt, **extra_kwargs, **shared_kwargs)
702
703
def _clear_cached_items(retain, path, output_folder):
    """Remove cached inputs and predictions of a benchmark run, except for the parts listed in `retain`.

    Args:
        retain: A list with the parts to keep, any of 'data', 'crops', 'automatic' or 'interactive'.
            If None, all four parts are removed.
        path: The dataset folder. Top-level entries whose name starts with 'roi' are treated
            as crops, all other top-level entries as data.
        output_folder: The folder with the predictions. Entries matching 'amg_*' and 'ais_*' are
            treated as automatic segmentation outputs and 'interactive_segmentation_*' as
            interactive segmentation outputs. Everything else (e.g. 'results') is left untouched.
    """
    import shutil
    from pathlib import Path

    REMOVE_LIST = ["data", "crops", "automatic", "interactive"]
    if retain is None:
        remove_list = set(REMOVE_LIST)
    else:
        assert isinstance(retain, list)
        remove_list = set(REMOVE_LIST) - set(retain)

    paths = []
    # Stage 1: Remove inputs.
    if "data" in remove_list or "crops" in remove_list:
        for curr_path in glob(os.path.join(path, "*")):
            # Classify each entry first, so that retaining 'crops' is honored when only 'data'
            # is removed (and vice versa). If both are removed, every entry gets collected.
            category = "crops" if os.path.basename(curr_path).startswith("roi") else "data"
            if category in remove_list:
                paths.append(curr_path)

    # Stage 2: Remove predictions.
    if "automatic" in remove_list:
        paths.extend(glob(os.path.join(output_folder, "amg_*")))
        paths.extend(glob(os.path.join(output_folder, "ais_*")))

    if "interactive" in remove_list:
        paths.extend(glob(os.path.join(output_folder, "interactive_segmentation_*")))

    # NOTE: No early return above: the deletion has to run for all the collected paths.
    for _path in paths:
        if Path(_path).is_dir():
            shutil.rmtree(_path)
        else:
            os.remove(_path)
741
742
def run_benchmark_evaluations(
    input_folder: Union[os.PathLike, str],
    dataset_choice: str,
    model_type: str = util._DEFAULT_MODEL,
    output_folder: Optional[Union[str, os.PathLike]] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    run_amg: bool = False,
    retain: Optional[List[str]] = None,
    evaluation_methods: Literal["all", "automatic", "interactive"] = "all",
    ignore_warnings: bool = False,
):
    """Run evaluation for benchmarking Segment Anything models on microscopy datasets.

    Args:
        input_folder: The path to directory where all inputs will be stored and preprocessed.
        dataset_choice: The dataset choice.
        model_type: The model choice for SAM.
        output_folder: The path to directory where all outputs will be stored.
            The outputs of each dataset are stored in a sub-folder named after the dataset.
        checkpoint_path: The checkpoint path
        run_amg: Whether to run automatic segmentation in AMG mode.
        retain: Which parts of the benchmark runs to retain.
            By default, removes everything besides quantitative results.
            There is the choice to retain 'data', 'crops', 'automatic', or 'interactive'.
        evaluation_methods: The choice of evaluation methods.
            By default, runs 'all' evaluation methods (i.e. both 'automatic' and 'interactive').
            Otherwise, specify either 'automatic' / 'interactive' for specific evaluation runs.
        ignore_warnings: Whether to ignore warnings.
    """
    start = time.time()

    with _filter_warnings(ignore_warnings):
        device = util._get_default_device()

        # Ensure if all the datasets have been installed by default.
        dataset_choice = _download_benchmark_datasets(path=input_folder, dataset_choice=dataset_choice)

        for choice in dataset_choice:
            # NOTE: Keep 'output_folder' untouched and derive a per-dataset folder from it.
            # Reassigning 'output_folder' here would nest the outputs of each dataset
            # inside the folder of the previous one.
            dataset_output_folder = os.path.join(output_folder, choice)
            os.makedirs(os.path.join(dataset_output_folder, "results"), exist_ok=True)

            data_path = os.path.join(input_folder, choice)

            # Extrapolate desired set from the datasets:
            # a. for 2d datasets - 2d patches with the most number of labels present
            #    (in case of volumetric data, choose 2d patches per slice).
            # b. for 3d datasets - 3d regions of interest with the most number of labels present.
            ndim = _extract_slices_from_dataset(path=data_path, dataset_choice=choice, crops_per_input=10)

            # Run inference and evaluation scripts on benchmark datasets.
            # For volumetric datasets, we additionally run them on the '2d' crops.
            for curr_ndim in ((ndim, 2) if ndim == 3 else (ndim,)):
                image_paths, gt_paths = _get_image_label_paths(path=data_path, ndim=curr_ndim)
                _run_benchmark_evaluation_series(
                    image_paths=image_paths,
                    gt_paths=gt_paths,
                    model_type=model_type,
                    output_folder=dataset_output_folder,
                    ndim=curr_ndim,
                    device=device,
                    checkpoint_path=checkpoint_path,
                    run_amg=run_amg,
                    evaluation_methods=evaluation_methods,
                )

            _clear_cached_items(retain=retain, path=data_path, output_folder=dataset_output_folder)

    diff = time.time() - start
    hours, rest = divmod(diff, 3600)
    minutes, seconds = divmod(rest, 60)
    print("Time taken for running benchmarks: ", f"{int(hours)}h {int(minutes)}m {int(seconds)}s")
827
828
def main():
    """@private"""
    import argparse

    model_names = ", ".join(list(util.get_model_names()))

    parser = argparse.ArgumentParser(
        description="Run evaluation for benchmarking Segment Anything models on microscopy datasets."
    )
    add_arg = parser.add_argument

    # Where the datasets live and which model / checkpoint is evaluated.
    add_arg(
        "-i", "--input_folder", type=str, required=True,
        help="The path to a directory where the microscopy datasets are and/or will be stored."
    )
    add_arg(
        "-m", "--model_type", type=str, default=util._DEFAULT_MODEL,
        help=f"The segment anything model that will be used, one of '{model_names}'. By default, "
        f"it uses'{util._DEFAULT_MODEL}'."
    )
    add_arg(
        "-c", "--checkpoint_path", type=str, default=None,
        help="The filepath to checkpoint from which the SAM model will be loaded."
    )

    # Which datasets are evaluated and where the results end up.
    add_arg(
        "-d", "--dataset_choice", type=str, nargs='*', default=None,
        help="The choice(s) of dataset for evaluating SAM models. Multiple datasets can be specified. "
        "By default, it evaluates on all datasets."
    )
    add_arg(
        "-o", "--output_folder", type=str, required=True,
        help="The path where the results for (automatic and interactive) instance segmentation results "
        "will be stored as 'csv' files."
    )

    # Switches that control what is run and what is kept afterwards.
    add_arg(
        "--amg", action="store_true",
        help="Whether to run automatic segmentation in AMG mode (i.e. the default auto-seg approach for SAM)."
    )
    add_arg(
        "--retain", nargs="*", default=None,
        help="By default, the functionality removes all besides quantitative results required for running benchmarks. "
        "In case you would like to retain parts of the benchmark evaluation for visualization / reproducibility, "
        "you should choose one or multiple of 'data', 'crops', 'automatic', 'interactive'. "
        "where they are responsible for either retaining original inputs / extracted crops / "
        "predictions of automatic segmentation / predictions of interactive segmentation, respectively."
    )
    add_arg(
        "--evaluate", type=str, default="all", choices=["all", "automatic", "interactive"],
        help="The choice of methods for benchmarking evaluation for reproducibility. "
        "By default, we run all evaluations with 'all'. If 'automatic' is chosen, it runs automatic segmentation only "
        "/ 'interactive' runs interactive segmentation (starting from box and single point) with iterative prompting."
    )

    cli = parser.parse_args()

    # Warnings are always ignored when running from the command line.
    run_benchmark_evaluations(
        input_folder=cli.input_folder,
        dataset_choice=cli.dataset_choice,
        model_type=cli.model_type,
        output_folder=cli.output_folder,
        checkpoint_path=cli.checkpoint_path,
        run_amg=cli.amg,
        retain=cli.retain,
        evaluation_methods=cli.evaluate,
        ignore_warnings=True,
    )
893
894
# Run the command line interface only when this module is executed as a script.
if __name__ == "__main__":
    main()
LM_2D_DATASETS = ['livecell', 'deepbacs', 'tissuenet', 'neurips_cellseg', 'cellpose', 'dynamicnuclearnet', 'orgasegment', 'yeaz', 'arvidsson', 'bitdepth_nucseg', 'cellbindb', 'covid_if', 'deepseas', 'hpa', 'ifnuclei', 'lizard', 'organoidnet', 'toiam', 'vicar']
LM_3D_DATASETS = ['plantseg_root', 'plantseg_ovules', 'gonuclear', 'mouse_embryo', 'cellseg3d']
EM_2D_DATASETS = ['mitolab_tem']
EM_3D_DATASETS = ['lucchi', 'mitolab', 'uro_cell', 'sponge_em', 'vnc', 'nuc_mm_mouse', 'num_mm_zebrafish', 'platynereis_cilia', 'asem_mito']
DATASET_RETURNS_FOLDER = {'deepbacs': '*.tif', 'mitolab_tem': '*.tiff'}
DATASET_CONTAINER_KEYS = {'tissuenet': ['raw/rgb', 'labels/cell'], 'covid_if': ['raw/serum_IgG/s0', 'labels/cells/s0'], 'dynamicnuclearnet': ['raw', 'labels'], 'hpa': [['raw/protein', 'raw/microtubules', 'raw/er'], 'labels'], 'lizard': ['image', 'labels/segmentation'], 'plantseg_root': ['raw', 'label'], 'plantseg_ovules': ['raw', 'label'], 'gonuclear': ['raw/nuclei', 'labels/nuclei'], 'mouse_embryo': ['raw', 'label'], 'cellseg_3d': [None, None], 'lucchi': ['raw', 'labels'], 'uro_cell': ['raw', 'labels/mito'], 'mitolab_3d': [None, None], 'sponge_em': ['volumes/raw', 'volumes/labels/instances'], 'vnc': ['raw', 'labels/mitochondria']}
def run_benchmark_evaluations( input_folder: Union[os.PathLike, str], dataset_choice: str, model_type: str = 'vit_b_lm', output_folder: Union[os.PathLike, str, NoneType] = None, checkpoint_path: Union[os.PathLike, str, NoneType] = None, run_amg: bool = False, retain: Optional[List[str]] = None, evaluation_methods: Literal['all', 'automatic', 'interactive'] = 'all', ignore_warnings: bool = False):
def run_benchmark_evaluations(
    input_folder: Union[os.PathLike, str],
    dataset_choice: str,
    model_type: str = util._DEFAULT_MODEL,
    output_folder: Optional[Union[str, os.PathLike]] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    run_amg: bool = False,
    retain: Optional[List[str]] = None,
    evaluation_methods: Literal["all", "automatic", "interactive"] = "all",
    ignore_warnings: bool = False,
):
    """Run evaluation for benchmarking Segment Anything models on microscopy datasets.

    Args:
        input_folder: The path to directory where all inputs will be stored and preprocessed.
        dataset_choice: The dataset choice.
        model_type: The model choice for SAM.
        output_folder: The path to directory where all outputs will be stored.
            The outputs of each dataset are stored in a sub-folder named after the dataset.
        checkpoint_path: The checkpoint path
        run_amg: Whether to run automatic segmentation in AMG mode.
        retain: Which parts of the benchmark runs to retain.
            By default, removes everything besides quantitative results.
            There is the choice to retain 'data', 'crops', 'automatic', or 'interactive'.
        evaluation_methods: The choice of evaluation methods.
            By default, runs 'all' evaluation methods (i.e. both 'automatic' and 'interactive').
            Otherwise, specify either 'automatic' / 'interactive' for specific evaluation runs.
        ignore_warnings: Whether to ignore warnings.
    """
    start = time.time()

    with _filter_warnings(ignore_warnings):
        device = util._get_default_device()

        # Ensure if all the datasets have been installed by default.
        dataset_choice = _download_benchmark_datasets(path=input_folder, dataset_choice=dataset_choice)

        for choice in dataset_choice:
            # NOTE: Keep 'output_folder' untouched and derive a per-dataset folder from it.
            # Reassigning 'output_folder' here would nest the outputs of each dataset
            # inside the folder of the previous one.
            dataset_output_folder = os.path.join(output_folder, choice)
            os.makedirs(os.path.join(dataset_output_folder, "results"), exist_ok=True)

            data_path = os.path.join(input_folder, choice)

            # Extrapolate desired set from the datasets:
            # a. for 2d datasets - 2d patches with the most number of labels present
            #    (in case of volumetric data, choose 2d patches per slice).
            # b. for 3d datasets - 3d regions of interest with the most number of labels present.
            ndim = _extract_slices_from_dataset(path=data_path, dataset_choice=choice, crops_per_input=10)

            # Run inference and evaluation scripts on benchmark datasets.
            # For volumetric datasets, we additionally run them on the '2d' crops.
            for curr_ndim in ((ndim, 2) if ndim == 3 else (ndim,)):
                image_paths, gt_paths = _get_image_label_paths(path=data_path, ndim=curr_ndim)
                _run_benchmark_evaluation_series(
                    image_paths=image_paths,
                    gt_paths=gt_paths,
                    model_type=model_type,
                    output_folder=dataset_output_folder,
                    ndim=curr_ndim,
                    device=device,
                    checkpoint_path=checkpoint_path,
                    run_amg=run_amg,
                    evaluation_methods=evaluation_methods,
                )

            _clear_cached_items(retain=retain, path=data_path, output_folder=dataset_output_folder)

    diff = time.time() - start
    hours, rest = divmod(diff, 3600)
    minutes, seconds = divmod(rest, 60)
    print("Time taken for running benchmarks: ", f"{int(hours)}h {int(minutes)}m {int(seconds)}s")

Run evaluation for benchmarking Segment Anything models on microscopy datasets.

Arguments:
  • input_folder: The path to directory where all inputs will be stored and preprocessed.
  • dataset_choice: The dataset choice.
  • model_type: The model choice for SAM.
  • output_folder: The path to directory where all outputs will be stored.
  • checkpoint_path: The checkpoint path
  • run_amg: Whether to run automatic segmentation in AMG mode.
  • retain: Which parts of the benchmark runs to retain. By default, everything besides the quantitative results is removed. The parts that can be retained are 'data', 'crops', 'automatic' and 'interactive'.
  • evaluation_methods: The choice of evaluation methods. By default, runs 'all' evaluation methods (i.e. both 'automatic' and 'interactive'). Otherwise, specify either 'automatic' or 'interactive' to run only that evaluation.
  • ignore_warnings: Whether to ignore warnings.