micro_sam.evaluation.benchmark_datasets

  1import os
  2import time
  3from glob import glob
  4from tqdm import tqdm
  5from natsort import natsorted
  6from typing import Union, Optional, List, Literal
  7
  8import numpy as np
  9import pandas as pd
 10import imageio.v3 as imageio
 11from skimage.measure import label as connected_components
 12
 13from nifty.tools import blocking
 14
 15import torch
 16
 17from torch_em.data import datasets
 18
 19from micro_sam import util
 20
 21from . import run_evaluation
 22from ..training.training import _filter_warnings
 23from .inference import run_inference_with_iterative_prompting
 24from .evaluation import run_evaluation_for_iterative_prompting
 25from .multi_dimensional_segmentation import segment_slices_from_ground_truth
 26from ..automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter
 27
 28
 29LM_2D_DATASETS = [
 30    # in-domain
 31    "livecell",  # cell segmentation in PhC (has a TEST-set)
 32    "deepbacs",  # bacteria segmentation in label-free microscopy (has a TEST-set),
 33    "tissuenet",  # cell segmentation in tissue microscopy images (has a TEST-set),
 34    "neurips_cellseg",  # cell segmentation in various (has a TEST-set),
 35    "cellpose",  # cell segmentation in FM (has 'cyto2' on which we can TEST on),
 36    "dynamicnuclearnet",  # nuclei segmentation in FM (has a TEST-set)
 37    "orgasegment",  # organoid segmentation in BF (has a TEST-set)
 38    "yeaz",  # yeast segmentation in BF (has a TEST-set)
 39
 40    # out-of-domain
 41    "arvidsson",  # nuclei segmentation in HCS FM
 42    "bitdepth_nucseg",  # nuclei segmentation in FM
 43    "cellbindb",  # cell segmentation in various microscopy
 44    "covid_if",  # cell segmentation in IF
 45    "deepseas",  # cell segmentation in PhC,
 46    "hpa",  # cell segmentation in confocal,
 47    "ifnuclei",  # nuclei segmentation in IFM
 48    "lizard",  # nuclei segmentation in H&E histopathology,
 49    "organoidnet",  # organoid segmentation in BF
 50    "toiam",  # microbial cell segmentation in PhC
 51    "vicar",  # cell segmentation in label-free
 52]
 53
 54LM_3D_DATASETS = [
 55    # in-domain
 56    "plantseg_root",  # cell segmentation in lightsheet (has a TEST-set)
 57
 58    # out-of-domain
 59    "plantseg_ovules",  # cell segmentation in confocal
 60    "gonuclear",  # nuclei segmentation in FM
 61    "mouse_embryo",  # cell segmentation in lightsheet
 62    "cellseg3d",  # nuclei segmentation in FM
 63]
 64
 65EM_2D_DATASETS = ["mitolab_tem"]
 66
 67EM_3D_DATASETS = [
 68    "mitoem_rat", "mitoem_human", "platynereis_nuclei", "lucchi", "mitolab", "nuc_mm_mouse",
 69    "num_mm_zebrafish", "uro_cell", "sponge_em", "platynereis_cilia", "vnc", "asem_mito",
 70]
 71
 72DATASET_RETURNS_FOLDER = {
 73    "deepbacs": "*.tif"
 74}
 75
 76DATASET_CONTAINER_KEYS = {
 77    # 2d
 78    "tissuenet": ["raw/rgb", "labels/cell"],
 79    "covid_if": ["raw/serum_IgG/s0", "labels/cells/s0"],
 80    "dynamicnuclearnet": ["raw", "labels"],
 81    "hpa": [["raw/protein", "raw/microtubules", "raw/er"], "labels"],
 82    "lizard": ["image", "labels/segmentation"],
 83
 84    # 3d
 85    "plantseg_root": ["raw", "label"],
 86    "plantseg_ovules": ["raw", "label"],
 87    "gonuclear": ["raw/nuclei", "labels/nuclei"],
 88    "mouse_embryo": ["raw", "label"],
 89    "cellseg_3d": [None, None],
 90    "lucchi": ["raw", "labels"],
 91}
 92
 93
 94def _download_benchmark_datasets(path, dataset_choice):
 95    """Ensures whether all the datasets have been downloaded or not.
 96
 97    Args:
 98        path: The path to directory where the supported datasets will be downloaded
 99            for benchmarking Segment Anything models.
100        dataset_choice: The choice of dataset, expects the lower case name for the dataset.
101
102    Returns:
103        List of choice of dataset(s).
104    """
105    available_datasets = {
106        # Light Microscopy datasets
107
108        # 2d datasets: in-domain
109        "livecell": lambda: datasets.livecell.get_livecell_data(
110            path=os.path.join(path, "livecell"), download=True,
111        ),
112        "deepbacs": lambda: datasets.deepbacs.get_deepbacs_data(
113            path=os.path.join(path, "deepbacs"), bac_type="mixed", download=True,
114        ),
115        "tissuenet": lambda: datasets.tissuenet.get_tissuenet_data(
116            path=os.path.join(path, "tissuenet"), split="test", download=True,
117        ),
118        "neurips_cellseg": lambda: datasets.neurips_cell_seg.get_neurips_cellseg_data(
119            root=os.path.join(path, "neurips_cellseg"), split="test", download=True,
120        ),
121        "cellpose": lambda: datasets.cellpose.get_cellpose_data(
122            path=os.path.join(path, "cellpose"), split="train", choice="cyto2", download=True,
123        ),
124        "dynamicnuclearnet": lambda: datasets.dynamicnuclearnet.get_dynamicnuclearnet_data(
125            path=os.path.join(path, "dynamicnuclearnet"), split="test", download=True,
126        ),
127        "orgasegment": lambda: datasets.orgasegment.get_orgasegment_data(
128            path=os.path.join(path, "orgasegment"), split="eval", download=True,
129        ),
130        "yeaz": lambda: datasets.yeaz.get_yeaz_data(
131            path=os.path.join(path, "yeaz"), choice="bf", download=True,
132        ),
133
134        # 2d datasets: out-of-domain
135        "arvidsson": lambda: datasets.arvidsson.get_arvidsson_data(
136            path=os.path.join(path, "arvidsson"), split="test", download=True,
137        ),
138        "bitdepth_nucseg": lambda: datasets.bitdepth_nucseg.get_bitdepth_nucseg_data(
139            path=os.path.join(path, "bitdepth_nucseg"), download=True,
140        ),
141        "cellbindb": lambda: datasets.cellbindb.get_cellbindb_data(
142            path=os.path.join(path, "cellbindb"), download=True,
143        ),
144        "covid_if": lambda: datasets.covid_if.get_covid_if_data(
145            path=os.path.join(path, "covid_if"), download=True,
146        ),
147        "deepseas": lambda: datasets.deepseas.get_deepseas_data(
148            path=os.path.join(path, "deepseas"), split="test",
149        ),
150        "hpa": lambda: datasets.hpa.get_hpa_segmentation_data(
151            path=os.path.join(path, "hpa"), download=True,
152        ),
153        "ifnuclei": lambda: datasets.ifnuclei.get_ifnuclei_data(
154            path=os.path.join(path, "ifnuclei"), download=True,
155        ),
156        "lizard": lambda: datasets.lizard.get_lizard_data(
157            path=os.path.join(path, "lizard"), split="test", download=True,
158        ),
159        "organoidnet": lambda: datasets.organoidnet.get_organoidnet_data(
160            path=os.path.join(path, "organoidnet"), split="Test", download=True,
161        ),
162        "toiam": lambda: datasets.toiam.get_toiam_data(
163            path=os.path.join(path, "toiam"), download=True,
164        ),
165        "vicar": lambda: datasets.vicar.get_vicar_data(
166            path=os.path.join(path, "vicar"), download=True,
167        ),
168
169        # 3d datasets: in-domain
170        "plantseg_root": lambda: datasets.plantseg.get_plantseg_data(
171            path=os.path.join(path, "plantseg_root"), split="test", download=True, name="root",
172        ),
173
174        # 3d datasets: out-of-domain
175        "plantseg_ovules": lambda: datasets.plantseg.get_plantseg_data(
176            path=os.path.join(path, "plantseg_ovules"), split="test", download=True, name="ovules",
177        ),
178        "gonuclear": lambda: datasets.gonuclear.get_gonuclear_data(
179            path=os.path.join(path, "gonuclear"), download=True,
180        ),
181        "mouse_embryo": lambda: datasets.mouse_embryo.get_mouse_embryo_data(
182            path=os.path.join(path, "mouse_embryo"), download=True,
183        ),
184        "cellseg_3d": lambda: datasets.cellseg_3d.get_cellseg_3d_data(
185            path=os.path.join(path, "cellseg_3d"), download=True,
186        ),
187
188        # Electron Microscopy datasets
189        "mitoem_rat": lambda: datasets.mitoem.get_mitoem_data(
190            path=os.path.join(path, "mitoem"), samples="rat", split="test", download=True,
191        ),
192        "mitoem_human": lambda: datasets.mitoem.get_mitoem_data(
193            path=os.path.join(path, "mitoem"), samples="human", split="test", download=True,
194        ),
195        "platynereis_nuclei": lambda: datasets.platynereis.get_platy_data(
196            path=os.path.join(path, "platynereis"), name="nuclei", download=True,
197        ),
198        "platynereis_cilia": lambda: datasets.platynereis.get_platy_data(
199            path=os.path.join(path, "platynereis"), name="cilia", download=True,
200        ),
201        "lucchi": lambda: datasets.lucchi.get_lucchi_data(
202            path=os.path.join(path, "lucchi"), split="test", download=True,
203        ),
204        "mitolab_3d": lambda: [
205            datasets.cem.get_benchmark_data(
206                path=os.path.join(path, "mitolab"), dataset_id=dataset_id, download=True,
207            ) for dataset_id in range(1, 7)
208        ],
209        "mitolab_tem": lambda: datasets.cem.get_benchmark_data(
210            path=os.path.join(path, "mitolab"), dataset_id=7, download=True
211        ),
212        "nuc_mm_mouse": lambda: datasets.nuc_mm.get_nuc_mm_data(
213            path=os.path.join(path, "nuc_mm"), sample="mouse", download=True,
214        ),
215        "nuc_mm_zebrafish": lambda: datasets.nuc_mm.get_nuc_mm_data(
216            path=os.path.join(path, "nuc_mm"), sample="zebrafish", download=True,
217        ),
218        "uro_cell": lambda: datasets.uro_cell.get_uro_cell_data(
219            path=os.path.join(path, "uro_cell"), download=True,
220        ),
221        "sponge_em": lambda: datasets.sponge_em.get_sponge_em_data(
222            path=os.path.join(path, "sponge_em"), download=True,
223        ),
224        "vnc": lambda: datasets.vnc.get_vnc_data(
225            path=os.path.join(path, "vnc"), download=True,
226        ),
227        "asem_mito": lambda: datasets.asem.get_asem_data(
228            path=os.path.join(path, "asem"), volume_ids=datasets.asem.ORGANELLES["mito"], download=True,
229        )
230    }
231
232    if dataset_choice is None:
233        dataset_choice = available_datasets.keys()
234    else:
235        if not isinstance(dataset_choice, list):
236            dataset_choice = [dataset_choice]
237
238    for choice in dataset_choice:
239        if choice in available_datasets:
240            available_datasets[choice]()
241        else:
242            raise ValueError(f"'{choice}' is not a supported choice of dataset.")
243
244    return dataset_choice
245
246
def _extract_slices_from_dataset(path, dataset_choice, crops_per_input=10):
    """Extracts crops of desired shapes for performing evaluation in both 2d and 3d using `micro-sam`.

    The crops are written as tif files to `<path>/roi_<ndim>d/inputs` and `<path>/roi_<ndim>d/labels`.
    For 3d datasets, 2d crops are additionally extracted from every z-slice of the volumes and written to
    `<path>/roi_2d/inputs` and `<path>/roi_2d/labels`. If all these folders already exist, nothing is extracted.

    Args:
        path: The path to directory where the supported datasets have been downloaded
            for benchmarking Segment Anything models.
        dataset_choice: The name of the dataset of choice to extract crops.
        crops_per_input: The maximum number of crops to extract per input.

    Returns:
        The number of dimensions supported by the input (2 or 3).
    """
    ndim = 2 if dataset_choice in [*LM_2D_DATASETS, *EM_2D_DATASETS] else 3
    tile_shape = (512, 512) if ndim == 2 else (32, 512, 512)

    # For 3d inputs, we extract both 2d and 3d crops.
    extract_2d_crops_from_volumes = (ndim == 3)

    available_datasets = {
        # Light Microscopy datasets

        # 2d: in-domain
        "livecell": lambda: datasets.livecell.get_livecell_paths(path=path, split="test"),
        "deepbacs": lambda: datasets.deepbacs.get_deepbacs_paths(path=path, split="test", bac_type="mixed"),
        "tissuenet": lambda: datasets.tissuenet.get_tissuenet_paths(path=path, split="test"),
        "neurips_cellseg": lambda: datasets.neurips_cell_seg.get_neurips_cellseg_paths(root=path, split="test"),
        "cellpose": lambda: datasets.cellpose.get_cellpose_paths(path=path, split="train", choice="cyto2"),
        "dynamicnuclearnet": lambda: datasets.dynamicnuclearnet.get_dynamicnuclearnet_paths(path=path, split="test"),
        "orgasegment": lambda: datasets.orgasegment.get_orgasegment_paths(path=path, split="eval"),
        "yeaz": lambda: datasets.yeaz.get_yeaz_paths(path=path, choice="bf", split="test"),

        # 2d: out-of-domain
        "arvidsson": lambda: datasets.arvidsson.get_arvidsson_paths(path=path, split="test"),
        "bitdepth_nucseg": lambda: datasets.bitdepth_nucseg.get_bitdepth_nucseg_paths(path=path, magnification="20x"),
        "cellbindb": lambda: datasets.cellbindb.get_cellbindb_paths(
            path=path, data_choice=["10×Genomics_DAPI", "DAPI", "mIF"]
        ),
        "covid_if": lambda: datasets.covid_if.get_covid_if_paths(path=path),
        "deepseas": lambda: datasets.deepseas.get_deepseas_paths(path=path, split="test"),
        "hpa": lambda: datasets.hpa.get_hpa_segmentation_paths(path=path, split="val"),
        "ifnuclei": lambda: datasets.ifnuclei.get_ifnuclei_paths(path=path),
        "lizard": lambda: datasets.lizard.get_lizard_paths(path=path, split="test"),
        "organoidnet": lambda: datasets.organoidnet.get_organoidnet_paths(path=path, split="Test"),
        "toiam": lambda: datasets.toiam.get_toiam_paths(path=path),
        "vicar": lambda: datasets.vicar.get_vicar_paths(path=path),

        # 3d: in-domain
        "plantseg_root": lambda: datasets.plantseg.get_plantseg_paths(path=path, name="root", split="test"),

        # 3d: out-of-domain
        "plantseg_ovules": lambda: datasets.plantseg.get_plantseg_paths(path=path, name="ovules", split="test"),
        "gonuclear": lambda: datasets.gonuclear.get_gonuclear_paths(path=path),
        "mouse_embryo": lambda: datasets.mouse_embryo.get_mouse_embryo_paths(path=path, name="nuclei", split="val"),
        "cellseg_3d": lambda: datasets.cellseg_3d.get_cellseg_3d_paths(path=path),

        # Electron Microscopy datasets
        "mitoem_rat": lambda: datasets.mitoem.get_mitoem_paths(path=path, splits="test", samples="rat"),
        # NOTE: This key was misspelled as 'mitem_human', which made 'mitoem_human' unresolvable.
        "mitoem_human": lambda: datasets.mitoem.get_mitoem_paths(path=path, splits="test", samples="human"),
        "platynereis_nuclei": lambda: datasets.platynereis.get_platynereis_paths(path, sample_ids=None, name="nuclei"),
        "platynereis_cilia": lambda: datasets.platynereis.get_platynereis_paths(path, sample_ids=None, name="cilia"),
        "lucchi": lambda: datasets.lucchi.get_lucchi_paths(path=path, split="test"),
        "mitolab_3d": lambda: (
            [rpath for i in range(1, 7) for rpath in datasets.cem.get_benchmark_paths(path=path, dataset_id=i)[0]],
            [lpath for i in range(1, 7) for lpath in datasets.cem.get_benchmark_paths(path=path, dataset_id=i)[1]]
        ),
        "mitolab_tem": lambda: datasets.cem.get_benchmark_paths(path=path, dataset_id=7),
        "nuc_mm_mouse": lambda: datasets.nuc_mm.get_nuc_mm_paths(path=path, sample="mouse", split="val"),
        "nuc_mm_zebrafish": lambda: datasets.nuc_mm.get_nuc_mm_paths(path=path, sample="zebrafish", split="val"),
        "uro_cell": lambda: datasets.uro_cell.get_uro_cell_paths(path=path, target="mito"),
        "sponge_em": lambda: datasets.sponge_em.get_sponge_em_paths(path=path, sample_ids=None),
        "vnc": lambda: datasets.vnc.get_vnc_mito_paths(path=path),
        "asem_mito": lambda: datasets.asem.get_asem_paths(path=path, volume_ids=datasets.asem.ORGANELLES["mito"])
    }

    # Datasets stored as separate image / label files return a pair of path lists;
    # container-based datasets return the container path(s) holding both images and labels.
    uses_paired_paths = (ndim == 2 and dataset_choice not in DATASET_CONTAINER_KEYS) or dataset_choice == "cellseg_3d"

    if uses_paired_paths:
        image_paths, gt_paths = available_datasets[dataset_choice]()

        if dataset_choice in DATASET_RETURNS_FOLDER:
            image_paths = glob(os.path.join(image_paths, DATASET_RETURNS_FOLDER[dataset_choice]))
            gt_paths = glob(os.path.join(gt_paths, DATASET_RETURNS_FOLDER[dataset_choice]))

        image_paths, gt_paths = natsorted(image_paths), natsorted(gt_paths)
        assert len(image_paths) == len(gt_paths)

        # Materialize the pairs so that the progress bar knows the total number of inputs.
        paths_set = list(zip(image_paths, gt_paths))

    else:
        image_paths = available_datasets[dataset_choice]()
        paths_set = [image_paths] if isinstance(image_paths, str) else natsorted(image_paths)

    # Directory where we store the extracted ROIs.
    save_image_dir = [os.path.join(path, f"roi_{ndim}d", "inputs")]
    save_gt_dir = [os.path.join(path, f"roi_{ndim}d", "labels")]
    if extract_2d_crops_from_volumes:
        save_image_dir.append(os.path.join(path, "roi_2d", "inputs"))
        save_gt_dir.append(os.path.join(path, "roi_2d", "labels"))

    # Skip the extraction if the crops have been extracted before.
    dirs_exist = [os.path.exists(idir) and os.path.exists(gdir) for idir, gdir in zip(save_image_dir, save_gt_dir)]
    if all(dirs_exist):
        return ndim

    for out_dir in [*save_image_dir, *save_gt_dir]:
        os.makedirs(out_dir, exist_ok=True)

    # Logic to extract relevant patches for inference
    image_counter = 1
    for per_paths in tqdm(paths_set, desc=f"Extracting {ndim}d patches for {dataset_choice}"):
        if uses_paired_paths:
            image_path, gt_path = per_paths
            image, gt = util.load_image_data(image_path), util.load_image_data(gt_path)
        else:
            image_path = per_paths
            gt = util.load_image_data(image_path, DATASET_CONTAINER_KEYS[dataset_choice][1])
            if dataset_choice == "hpa":
                # Get inputs per channel and stack them together to make the desired 3 channel image.
                image = np.stack(
                    [util.load_image_data(image_path, k) for k in DATASET_CONTAINER_KEYS[dataset_choice][0]], axis=0,
                )
                # Resize inputs to desired tile shape, in favor of working with the shape of foreground.
                from torch_em.transform.generic import ResizeLongestSideInputs
                raw_transform = ResizeLongestSideInputs(target_shape=tile_shape, is_rgb=True)
                label_transform = ResizeLongestSideInputs(target_shape=tile_shape, is_label=True)
                image, gt = raw_transform(image).transpose(1, 2, 0), label_transform(gt)

            else:
                image = util.load_image_data(image_path, DATASET_CONTAINER_KEYS[dataset_choice][0])

        if dataset_choice in ["tissuenet", "lizard"]:
            if image.ndim == 3 and image.shape[0] == 3:  # Make channels last for tissuenet RGB-style images.
                image = image.transpose(1, 2, 0)

        # Allow RGBs to stay as it is with channels last
        if image.ndim == 3 and image.shape[-1] == 3:
            skip_smaller_shape = (np.array(image.shape) >= np.array((*tile_shape, 3))).all()
        else:
            skip_smaller_shape = (np.array(image.shape) >= np.array(tile_shape)).all()

        # Ensure ground truth has instance labels.
        gt = connected_components(gt)

        if len(np.unique(gt)) == 1:  # There could be labels which does not have any annotated foreground.
            continue

        # Let's extract and save all the crops.
        # The first round of extraction is always to match the desired input dimensions.
        image_crops, gt_crops = _get_crops_for_input(image, gt, ndim, tile_shape, skip_smaller_shape, crops_per_input)
        image_counter = _save_image_label_crops(
            image_crops, gt_crops, dataset_choice, ndim, image_counter, save_image_dir[0], save_gt_dir[0]
        )

        # The next round of extraction is to get 2d crops from 3d inputs.
        if extract_2d_crops_from_volumes:
            curr_tile_shape = tile_shape[1:]  # We expect 2d tile shape for this stage.

            curr_image_crops, curr_gt_crops = [], []
            for per_z_im, per_z_gt in zip(image, gt):
                curr_skip_smaller_shape = (np.array(per_z_im.shape) >= np.array(curr_tile_shape)).all()

                image_crops, gt_crops = _get_crops_for_input(
                    image=per_z_im, gt=per_z_gt, ndim=2,
                    tile_shape=curr_tile_shape,
                    skip_smaller_shape=curr_skip_smaller_shape,
                    crops_per_input=crops_per_input,
                )
                curr_image_crops.extend(image_crops)
                curr_gt_crops.extend(gt_crops)

            image_counter = _save_image_label_crops(
                curr_image_crops, curr_gt_crops, dataset_choice, 2, image_counter, save_image_dir[1], save_gt_dir[1]
            )

    return ndim
425
426
def _get_crops_for_input(image, gt, ndim, tile_shape, skip_smaller_shape, crops_per_input):
    """Pick the tiles of an input that contain the most label ids.

    The label image is split into a regular grid of `tile_shape` tiles. Tiles are visited in
    descending order of their number of unique label ids (background included); ties are broken
    by the higher tile index.

    Args:
        image: The input image (channels-last for RGB inputs).
        gt: The instance label image matching the spatial shape of `image`.
        ndim: The number of spatial dimensions to tile over.
        tile_shape: The spatial shape of each tile.
        skip_smaller_shape: Whether to drop tiles whose crop is not exactly `tile_shape` (e.g. border tiles).
        crops_per_input: The budget of visited tiles after which sampling stops.

    Returns:
        The list of image crops and the list of corresponding label crops.
    """
    grid = blocking([0] * ndim, gt.shape, tile_shape)
    boxes = []
    for tile_id in range(grid.numberOfBlocks):
        tile = grid.getBlock(tile_id)
        boxes.append(tuple(slice(start, stop) for start, stop in zip(tile.begin, tile.end)))

    # Rank the tiles by (number of unique label ids, tile index), largest first.
    ranking = sorted(
        ((len(np.unique(gt[box])), box_id) for box_id, box in enumerate(boxes)), reverse=True
    )

    # Channels-last RGB crops carry the channel axis in addition to the spatial tile shape.
    is_rgb = image.ndim == 3 and image.shape[-1] == 3
    expected_shape = (*tile_shape, 3) if is_rgb else tile_shape

    image_crops, gt_crops = [], []
    for rank, (n_unique, box_id) in enumerate(ranking, start=1):
        box = boxes[box_id]
        image_crop, gt_crop = image[box], gt[box]

        # Drop crops that do not match the desired tile shape.
        if skip_smaller_shape and image_crop.shape != expected_shape:
            continue

        # A tile with a single label id ends the search: no later tile ranks higher.
        if n_unique == 1:
            break

        image_crops.append(image_crop)
        gt_crops.append(gt_crop)

        # NOTE(review): `rank` counts every visited tile, including skipped ones, so fewer than
        # `crops_per_input` crops may be returned. Confirm this budget semantics is intended.
        if rank >= crops_per_input:
            break

    return image_crops, gt_crops
460
461
def _save_image_label_crops(image_crops, gt_crops, dataset_choice, ndim, image_counter, save_image_dir, save_gt_dir):
    """Write paired image and label crops as zlib-compressed tif files.

    Each pair is stored under the same file name `<dataset_choice>_<counter>.tif` (counter zero-padded
    to five digits), once in `save_image_dir` and once in `save_gt_dir`.

    Args:
        image_crops: The image crops (channels-last for RGB crops).
        gt_crops: The label crops, spatially matching the image crops.
        dataset_choice: The dataset name, used as file name prefix.
        ndim: The number of dimensions of the crops, only used for the progress bar.
        image_counter: The counter value used for the first file name.
        save_image_dir: The folder for the image crops.
        save_gt_dir: The folder for the label crops.

    Returns:
        The counter value to use for the next crop.
    """
    pairs = tqdm(
        zip(image_crops, gt_crops), total=len(image_crops), desc=f"Saving {ndim}d crops for {dataset_choice}"
    )

    next_counter = image_counter
    for image_crop, gt_crop in pairs:
        fname = f"{dataset_choice}_{next_counter:05}.tif"

        # For channels-last RGB crops only the spatial axes must match the labels.
        has_rgb_channels = image_crop.ndim == 3 and image_crop.shape[-1] == 3
        spatial_shape = image_crop.shape[:2] if has_rgb_channels else image_crop.shape
        assert spatial_shape == gt_crop.shape

        for out_dir, data in ((save_image_dir, image_crop), (save_gt_dir, gt_crop)):
            imageio.imwrite(os.path.join(out_dir, fname), data, compression="zlib")

        next_counter += 1

    return next_counter
479
480
def _get_image_label_paths(path, ndim):
    """Return the naturally sorted paths of the extracted `ndim`-d image crops and label crops."""
    roi_dir = os.path.join(path, f"roi_{ndim}d")
    image_paths, gt_paths = (
        natsorted(glob(os.path.join(roi_dir, subfolder, "*"))) for subfolder in ("inputs", "labels")
    )
    return image_paths, gt_paths
485
486
def _run_automatic_segmentation_per_dataset(
    image_paths: List[Union[os.PathLike, str]],
    gt_paths: List[Union[os.PathLike, str]],
    model_type: str,
    output_folder: Union[os.PathLike, str],
    ndim: Optional[int] = None,
    device: Optional[Union[torch.device, str]] = None,
    checkpoint_path: Optional[Union[os.PathLike, str]] = None,
    run_amg: Optional[bool] = None,
    **auto_seg_kwargs
):
    """Functionality to run automatic segmentation for multiple input files at once.
    It stores the evaluated automatic segmentation results (quantitative).

    The predictions are stored in `<output_folder>/<amg|ais>_<ndim>d/inference` and the evaluation
    in `<output_folder>/results/<amg|ais>_<ndim>d.csv`. If that csv file exists already, nothing is done.

    Args:
        image_paths: List of filepaths for the input image data.
        gt_paths: List of filepaths for the corresponding label data (paired with `image_paths`).
        model_type: The choice of image encoder for the Segment Anything model.
        output_folder: Filepath to the folder where we store all the results.
        ndim: The number of input dimensions.
        device: The torch device.
        checkpoint_path: The filepath where the model checkpoints are stored.
        run_amg: Whether to run automatic segmentation in AMG mode. If None or False, AIS is run when
            the checkpoint provides the instance segmentation decoder, and AMG otherwise.
        auto_seg_kwargs: Additional arguments for automatic segmentation parameters.
    """
    # Unless AMG is explicitly requested, prefer AIS whenever the decoder state exists (i.e. AIS > AMG).
    if not run_amg:
        _, state = util.get_sam_model(
            model_type=model_type, checkpoint_path=checkpoint_path, device=device, return_state=True
        )
        run_amg = ("decoder_state" not in state)

    experiment_name = "AMG" if run_amg else "AIS"
    fname = f"{experiment_name.lower()}_{ndim}d"

    result_path = os.path.join(output_folder, "results", f"{fname}.csv")
    if os.path.exists(result_path):
        return

    prediction_dir = os.path.join(output_folder, fname, "inference")
    os.makedirs(prediction_dir, exist_ok=True)

    # Get the predictor (and the additional instance segmentation decoder, if available).
    predictor, segmenter = get_predictor_and_segmenter(
        model_type=model_type, checkpoint=checkpoint_path, device=device, amg=run_amg, is_tiled=False,
    )

    # Derive the prediction paths from the inputs (instead of globbing the output folder), so that they
    # stay paired with `gt_paths` even if the folder contains files from other runs.
    prediction_paths = [os.path.join(prediction_dir, os.path.basename(image_path)) for image_path in image_paths]

    for image_path, output_path in tqdm(
        zip(image_paths, prediction_paths), total=len(image_paths), desc=f"Run {experiment_name} in {ndim}d"
    ):
        if os.path.exists(output_path):
            continue

        # Run Automatic Segmentation (AMG and AIS)
        automatic_instance_segmentation(
            predictor=predictor,
            segmenter=segmenter,
            input_path=image_path,
            output_path=output_path,
            ndim=ndim,
            verbose=False,
            **auto_seg_kwargs
        )

    run_evaluation(gt_paths=gt_paths, prediction_paths=prediction_paths, save_path=result_path)
553
554
def _run_interactive_segmentation_per_dataset(
    image_paths: List[Union[os.PathLike, str]],
    gt_paths: List[Union[os.PathLike, str]],
    output_folder: Union[os.PathLike, str],
    model_type: str,
    prompt_choice: Literal["box", "points"],
    device: Optional[Union[torch.device, str]] = None,
    ndim: Optional[int] = None,
    checkpoint_path: Optional[Union[os.PathLike, str]] = None,
    use_masks: bool = False,
):
    """Functionality to run interactive segmentation for multiple input files at once.
    It stores the evaluated interactive segmentation results.

    Args:
        image_paths: List of filepaths for the input image data.
        gt_paths: List of filepaths for the corresponding label data.
        output_folder: Filepath to the folder where we store all the results.
        model_type: The choice of model type for Segment Anything.
        prompt_choice: The choice of initial prompts to begin the interactive segmentation.
        device: The torch device.
        ndim: The number of input dimensions.
        checkpoint_path: The filepath for stored checkpoints.
        use_masks: Whether to use masks for iterative prompting. Only used for 2d inputs.

    Raises:
        ValueError: If no 3d inputs are given, so there is nothing to evaluate.
    """
    if ndim == 2:
        # Get the Segment Anything predictor.
        predictor = util.get_sam_model(model_type=model_type, device=device, checkpoint_path=checkpoint_path)

        prediction_root = os.path.join(
            output_folder, "interactive_segmentation_2d", f"start_with_{prompt_choice}",
            "iterative_prompting_" + ("with_masks" if use_masks else "without_masks")
        )

        # Run interactive instance segmentation
        # (starting with box and points followed by iterative prompt-based correction)
        run_inference_with_iterative_prompting(
            predictor=predictor,
            image_paths=image_paths,
            gt_paths=gt_paths,
            embedding_dir=None,  # We set this to None to compute embeddings on-the-fly.
            prediction_dir=prediction_root,
            start_with_box_prompt=(prompt_choice == "box"),
            use_masks=use_masks,
            # TODO: add parameter for deform over box prompts (to simulate prompts in practice).
        )

        # Evaluate the interactive instance segmentation.
        run_evaluation_for_iterative_prompting(
            gt_paths=gt_paths,
            prediction_root=prediction_root,
            experiment_folder=output_folder,
            start_with_box_prompt=(prompt_choice == "box"),
            use_masks=use_masks,
        )

    else:
        save_path = os.path.join(output_folder, "results", f"interactive_segmentation_3d_with_{prompt_choice}.csv")
        if os.path.exists(save_path):
            print(
                f"Results for 3d interactive segmentation with '{prompt_choice}' are already stored at '{save_path}'."
            )
            return

        if len(image_paths) == 0:
            raise ValueError(f"No inputs were given for 3d interactive segmentation with '{prompt_choice}'.")

        prediction_dir = os.path.join(output_folder, "interactive_segmentation_3d", f"{prompt_choice}")
        os.makedirs(prediction_dir, exist_ok=True)

        results = []
        for image_path, gt_path in tqdm(
            zip(image_paths, gt_paths), total=len(image_paths),
            desc=f"Run interactive segmentation in 3d with '{prompt_choice}'"
        ):
            # NOTE: Every volume is (re-)segmented, even if a prediction from an earlier, interrupted run
            # exists. Skipping those would leave their scores out of the averaged results below.
            prediction_path = os.path.join(prediction_dir, os.path.basename(image_path))

            per_vol_result = segment_slices_from_ground_truth(
                volume=imageio.imread(image_path),
                ground_truth=imageio.imread(gt_path),
                model_type=model_type,
                checkpoint_path=checkpoint_path,
                save_path=prediction_path,
                device=device,
                interactive_seg_mode=prompt_choice,
                min_size=10,
            )
            results.append(per_vol_result)

        # Average the per-volume results row-wise (rows are matched by their index).
        results = pd.concat(results)
        results = results.groupby(results.index).mean()

        # 'to_csv' does not create missing parent folders.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        results.to_csv(save_path)
646
647
def _run_benchmark_evaluation_series(
    image_paths, gt_paths, model_type, output_folder, ndim, device, checkpoint_path, run_amg, evaluation_methods,
):
    """Run the selected benchmark evaluations on one set of image / label crops.

    Args:
        image_paths: List of filepaths for the input image data.
        gt_paths: List of filepaths for the corresponding label data.
        model_type: The choice of model type for Segment Anything.
        output_folder: Filepath to the folder where we store all the results.
        ndim: The number of input dimensions.
        device: The torch device.
        checkpoint_path: The filepath for stored checkpoints.
        run_amg: Whether to additionally run automatic mask generation (AMG).
        evaluation_methods: 'automatic' runs only automatic segmentation, 'interactive' runs only
            interactive segmentation; any other value runs both.
    """
    shared_kwargs = dict(
        image_paths=image_paths,
        gt_paths=gt_paths,
        output_folder=output_folder,
        ndim=ndim,
        model_type=model_type,
        device=device,
        checkpoint_path=checkpoint_path,
    )

    # a. Automatic segmentation (supported in both 2d and 3d), skipped for an 'interactive'-only run.
    #    The first pass runs AIS if the checkpoint has a decoder, otherwise AMG.
    #    The second pass runs AMG if `run_amg` asks for it; otherwise it resolves to the first pass again,
    #    whose stored results make it return early.
    if evaluation_methods != "interactive":
        for amg_choice in (None, run_amg):
            _run_automatic_segmentation_per_dataset(run_amg=amg_choice, **shared_kwargs)

    # b. Interactive segmentation (supported in both 2d and 3d), skipped for an 'automatic'-only run.
    #    Starts from box or point prompts, each without and with masks for iterative prompting.
    if evaluation_methods != "automatic":
        for prompt_choice, use_masks in (("box", False), ("box", True), ("points", False), ("points", True)):
            _run_interactive_segmentation_per_dataset(
                prompt_choice=prompt_choice, use_masks=use_masks, **shared_kwargs
            )
681
682
683def _clear_cached_items(retain, path, output_folder):
684    import shutil
685    from pathlib import Path
686
687    REMOVE_LIST = ["data", "crops", "automatic", "interactive"]
688    if retain is None:
689        remove_list = REMOVE_LIST
690    else:
691        assert isinstance(retain, list)
692        remove_list = set(REMOVE_LIST) - set(retain)
693
694    paths = []
695    # Stage 1: Remove inputs.
696    if "data" in remove_list or "crops" in remove_list:
697        all_paths = glob(os.path.join(path, "*"))
698
699        # In case we want to remove both data and crops, we remove the data folder entirely.
700        if "data" in remove_list and "crops" in remove_list:
701            paths.extend(all_paths)
702            return
703
704        # Next, we verify whether the we only remove either of data or crops.
705        for curr_path in all_paths:
706            if os.path.basename(curr_path).startswith("roi") and "crops" in remove_list:
707                paths.append(curr_path)
708            elif "data" in remove_list:
709                paths.append(curr_path)
710
711    # Stage 2: Remove predictions
712    if "automatic" in remove_list:
713        paths.extend(glob(os.path.join(output_folder, "amg_*")))
714        paths.extend(glob(os.path.join(output_folder, "ais_*")))
715
716    if "interactive" in remove_list:
717        paths.extend(glob(os.path.join(output_folder, "interactive_segmentation_*")))
718
719    [shutil.rmtree(_path) if Path(_path).is_dir() else os.remove(_path) for _path in paths]
720
721
def run_benchmark_evaluations(
    input_folder: Union[os.PathLike, str],
    dataset_choice: Optional[Union[str, List[str]]],
    model_type: str = util._DEFAULT_MODEL,
    output_folder: Optional[Union[str, os.PathLike]] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    run_amg: bool = False,
    retain: Optional[List[str]] = None,
    evaluation_methods: Literal["all", "automatic", "interactive"] = "all",
    ignore_warnings: bool = False,
):
    """Run evaluation for benchmarking Segment Anything models on microscopy datasets.

    Args:
        input_folder: The path to the directory where all inputs will be stored and preprocessed.
        dataset_choice: The dataset choice(s). The selection is resolved by '_download_benchmark_datasets',
            which also makes sure that the chosen datasets are available in 'input_folder'.
        model_type: The model choice for SAM.
        output_folder: The path to the directory where all outputs will be stored.
            The outputs of each dataset are stored in a sub-folder named after the dataset.
            This argument is required, despite its default value of None.
        checkpoint_path: The path to the model checkpoint.
        run_amg: Whether to additionally run automatic segmentation in AMG mode.
        retain: Which parts of the benchmark runs to retain.
            By default, removes everything besides quantitative results.
            There is the choice to retain 'data', 'crops', 'automatic', or 'interactive'.
        evaluation_methods: The choice of evaluation methods.
            By default, runs 'all' evaluation methods (i.e. both 'automatic' and 'interactive').
            Otherwise, specify either 'automatic' / 'interactive' for specific evaluation runs.
        ignore_warnings: Whether to ignore warnings.

    Raises:
        ValueError: If no 'output_folder' is given.
    """
    # Fail early (before downloading any data) instead of with an opaque TypeError from 'os.path.join(None, ...)'.
    if output_folder is None:
        raise ValueError("Please provide an 'output_folder' where the benchmark results will be stored.")

    start = time.time()

    with _filter_warnings(ignore_warnings):
        device = util._get_default_device()

        # Ensure if all the datasets have been installed by default.
        dataset_choice = _download_benchmark_datasets(path=input_folder, dataset_choice=dataset_choice)

        for choice in dataset_choice:
            # Use a separate variable per dataset: re-assigning 'output_folder' here would nest the outputs
            # of every subsequent dataset inside the folder of the previous one.
            dataset_output_folder = os.path.join(output_folder, choice)
            os.makedirs(os.path.join(dataset_output_folder, "results"), exist_ok=True)

            data_path = os.path.join(input_folder, choice)

            # Extrapolate desired set from the datasets:
            # a. for 2d datasets - 2d patches with the most number of labels present
            #    (in case of volumetric data, choose 2d patches per slice).
            # b. for 3d datasets - 3d regions of interest with the most number of labels present.
            ndim = _extract_slices_from_dataset(path=data_path, dataset_choice=choice, crops_per_input=10)

            # Run inference and evaluation scripts on the benchmark inputs.
            # For volumetric datasets, additionally run them on the '2d' crops.
            eval_ndims = [ndim, 2] if ndim == 3 else [ndim]
            for eval_ndim in eval_ndims:
                image_paths, gt_paths = _get_image_label_paths(path=data_path, ndim=eval_ndim)
                _run_benchmark_evaluation_series(
                    image_paths=image_paths,
                    gt_paths=gt_paths,
                    model_type=model_type,
                    output_folder=dataset_output_folder,
                    ndim=eval_ndim,
                    device=device,
                    checkpoint_path=checkpoint_path,
                    run_amg=run_amg,
                    evaluation_methods=evaluation_methods,
                )

            _clear_cached_items(retain=retain, path=data_path, output_folder=dataset_output_folder)

    diff = time.time() - start
    hours, rest = divmod(diff, 3600)
    minutes, seconds = divmod(rest, 60)
    print("Time taken for running benchmarks: ", f"{int(hours)}h {int(minutes)}m {int(seconds)}s")
806
807
def main():
    """@private"""
    # Command-line entry point: parses the arguments and forwards them to 'run_benchmark_evaluations'.
    # Warnings are always suppressed for command-line runs ('ignore_warnings=True').
    import argparse

    available_models = ", ".join(util.get_model_names())

    parser = argparse.ArgumentParser(
        description="Run evaluation for benchmarking Segment Anything models on microscopy datasets."
    )
    parser.add_argument(
        "-i", "--input_folder", type=str, required=True,
        help="The path to a directory where the microscopy datasets are / will be stored."
    )
    parser.add_argument(
        "-m", "--model_type", type=str, default=util._DEFAULT_MODEL,
        help=f"The segment anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint_path", type=str, default=None,
        help="Checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "-d", "--dataset_choice", type=str, nargs='*', default=None,
        help="The choice(s) of dataset for evaluating SAM models. Multiple datasets can be specified."
    )
    parser.add_argument(
        "-o", "--output_folder", type=str, required=True,
        help="The path where the results for automatic and interactive instance segmentation will be stored as 'csv'."
    )
    parser.add_argument(
        "--amg", action="store_true",
        help="Whether to run automatic segmentation in AMG mode (i.e. the default auto-seg approach for SAM)."
    )
    # Restrict the values to the items understood by '_clear_cached_items', so that typos are reported
    # instead of being silently ignored (which would delete the item the user wanted to keep).
    parser.add_argument(
        "--retain", nargs="*", default=None, choices=["data", "crops", "automatic", "interactive"],
        help="By default, the functionality removes all besides quantitative results required for running benchmarks. "
        "In case you would like to retain parts of the benchmark evaluation for visualization / reproducibility, "
        "you should choose one or multiple of 'data', 'crops', 'automatic', 'interactive', "
        "which retain the original inputs / extracted crops / "
        "predictions of automatic segmentation / predictions of interactive segmentation, respectively."
    )
    parser.add_argument(
        "--evaluate", type=str, default="all", choices=["all", "automatic", "interactive"],
        help="The choice of methods for benchmarking evaluation for reproducibility. "
        "By default, 'all' runs all evaluations / 'automatic' runs automatic segmentation only / "
        "'interactive' runs interactive segmentation (starting from box and single point) with iterative prompting."
    )
    args = parser.parse_args()

    run_benchmark_evaluations(
        input_folder=args.input_folder,
        dataset_choice=args.dataset_choice,
        model_type=args.model_type,
        output_folder=args.output_folder,
        checkpoint_path=args.checkpoint_path,
        run_amg=args.amg,
        retain=args.retain,
        evaluation_methods=args.evaluate,
        ignore_warnings=True,
    )


if __name__ == "__main__":
    main()
# NOTE(review): These assignments repeat module-level constants that are already defined at the top of this
# file (e.g. 'LM_2D_DATASETS'). They look like an appended rendering of those values; confirm they are not
# meant to be executed a second time.
LM_2D_DATASETS = ['livecell', 'deepbacs', 'tissuenet', 'neurips_cellseg', 'cellpose', 'dynamicnuclearnet', 'orgasegment', 'yeaz', 'arvidsson', 'bitdepth_nucseg', 'cellbindb', 'covid_if', 'deepseas', 'hpa', 'ifnuclei', 'lizard', 'organoidnet', 'toiam', 'vicar']
LM_3D_DATASETS = ['plantseg_root', 'plantseg_ovules', 'gonuclear', 'mouse_embryo', 'cellseg3d']
EM_2D_DATASETS = ['mitolab_tem']
# NOTE(review): 'num_mm_zebrafish' may be a typo for 'nuc_mm_zebrafish' (cf. 'nuc_mm_mouse'). Verify against the
# dataset handling code before changing it, since the exact string may be matched elsewhere.
EM_3D_DATASETS = ['mitoem_rat', 'mitoem_human', 'platynereis_nuclei', 'lucchi', 'mitolab', 'nuc_mm_mouse', 'num_mm_zebrafish', 'uro_cell', 'sponge_em', 'platynereis_cilia', 'vnc', 'asem_mito']
DATASET_RETURNS_FOLDER = {'deepbacs': '*.tif'}
# NOTE(review): The key 'cellseg_3d' does not match the dataset name 'cellseg3d' used in 'LM_3D_DATASETS'.
# A lookup by dataset name would miss this entry; confirm which spelling is intended.
DATASET_CONTAINER_KEYS = {'tissuenet': ['raw/rgb', 'labels/cell'], 'covid_if': ['raw/serum_IgG/s0', 'labels/cells/s0'], 'dynamicnuclearnet': ['raw', 'labels'], 'hpa': [['raw/protein', 'raw/microtubules', 'raw/er'], 'labels'], 'lizard': ['image', 'labels/segmentation'], 'plantseg_root': ['raw', 'label'], 'plantseg_ovules': ['raw', 'label'], 'gonuclear': ['raw/nuclei', 'labels/nuclei'], 'mouse_embryo': ['raw', 'label'], 'cellseg_3d': [None, None], 'lucchi': ['raw', 'labels']}
def run_benchmark_evaluations(
    input_folder: Union[os.PathLike, str],
    dataset_choice: Optional[Union[str, List[str]]],
    model_type: str = util._DEFAULT_MODEL,
    output_folder: Optional[Union[str, os.PathLike]] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    run_amg: bool = False,
    retain: Optional[List[str]] = None,
    evaluation_methods: Literal["all", "automatic", "interactive"] = "all",
    ignore_warnings: bool = False,
):
    """Run evaluation for benchmarking Segment Anything models on microscopy datasets.

    Args:
        input_folder: The path to the directory where all inputs will be stored and preprocessed.
        dataset_choice: The dataset choice(s). The selection is resolved by '_download_benchmark_datasets',
            which also makes sure that the chosen datasets are available in 'input_folder'.
        model_type: The model choice for SAM.
        output_folder: The path to the directory where all outputs will be stored.
            The outputs of each dataset are stored in a sub-folder named after the dataset.
            This argument is required, despite its default value of None.
        checkpoint_path: The path to the model checkpoint.
        run_amg: Whether to additionally run automatic segmentation in AMG mode.
        retain: Which parts of the benchmark runs to retain.
            By default, removes everything besides quantitative results.
            There is the choice to retain 'data', 'crops', 'automatic', or 'interactive'.
        evaluation_methods: The choice of evaluation methods.
            By default, runs 'all' evaluation methods (i.e. both 'automatic' and 'interactive').
            Otherwise, specify either 'automatic' / 'interactive' for specific evaluation runs.
        ignore_warnings: Whether to ignore warnings.

    Raises:
        ValueError: If no 'output_folder' is given.
    """
    # Fail early (before downloading any data) instead of with an opaque TypeError from 'os.path.join(None, ...)'.
    if output_folder is None:
        raise ValueError("Please provide an 'output_folder' where the benchmark results will be stored.")

    start = time.time()

    with _filter_warnings(ignore_warnings):
        device = util._get_default_device()

        # Ensure if all the datasets have been installed by default.
        dataset_choice = _download_benchmark_datasets(path=input_folder, dataset_choice=dataset_choice)

        for choice in dataset_choice:
            # Use a separate variable per dataset: re-assigning 'output_folder' here would nest the outputs
            # of every subsequent dataset inside the folder of the previous one.
            dataset_output_folder = os.path.join(output_folder, choice)
            os.makedirs(os.path.join(dataset_output_folder, "results"), exist_ok=True)

            data_path = os.path.join(input_folder, choice)

            # Extrapolate desired set from the datasets:
            # a. for 2d datasets - 2d patches with the most number of labels present
            #    (in case of volumetric data, choose 2d patches per slice).
            # b. for 3d datasets - 3d regions of interest with the most number of labels present.
            ndim = _extract_slices_from_dataset(path=data_path, dataset_choice=choice, crops_per_input=10)

            # Run inference and evaluation scripts on the benchmark inputs.
            # For volumetric datasets, additionally run them on the '2d' crops.
            eval_ndims = [ndim, 2] if ndim == 3 else [ndim]
            for eval_ndim in eval_ndims:
                image_paths, gt_paths = _get_image_label_paths(path=data_path, ndim=eval_ndim)
                _run_benchmark_evaluation_series(
                    image_paths=image_paths,
                    gt_paths=gt_paths,
                    model_type=model_type,
                    output_folder=dataset_output_folder,
                    ndim=eval_ndim,
                    device=device,
                    checkpoint_path=checkpoint_path,
                    run_amg=run_amg,
                    evaluation_methods=evaluation_methods,
                )

            _clear_cached_items(retain=retain, path=data_path, output_folder=dataset_output_folder)

    diff = time.time() - start
    hours, rest = divmod(diff, 3600)
    minutes, seconds = divmod(rest, 60)
    print("Time taken for running benchmarks: ", f"{int(hours)}h {int(minutes)}m {int(seconds)}s")

Run evaluation for benchmarking Segment Anything models on microscopy datasets.

Arguments:
  • input_folder: The path to the directory where all inputs will be stored and preprocessed.
  • dataset_choice: The dataset choice.
  • model_type: The model choice for SAM.
  • output_folder: The path to the directory where all outputs will be stored.
  • checkpoint_path: The path to the model checkpoint.
  • run_amg: Whether to run automatic segmentation in AMG mode.
  • retain: Which parts of the benchmark runs to retain. By default, removes everything besides quantitative results. There is the choice to retain 'data', 'crops', 'automatic', or 'interactive'.
  • evaluation_methods: The choice of evaluation methods. By default, runs 'all' evaluation methods (i.e. both 'automatic' and 'interactive'). Otherwise, specify either 'automatic' / 'interactive' for specific evaluation runs.
  • ignore_warnings: Whether to ignore warnings.