micro_sam.evaluation.benchmark_datasets

import os
import time
from glob import glob
from tqdm import tqdm
from natsort import natsorted
from typing import Union, Optional, List, Literal

import numpy as np
import pandas as pd
import imageio.v3 as imageio
from skimage.measure import label as connected_components

from nifty.tools import blocking

import torch

from torch_em.data import datasets

from micro_sam import util

from . import run_evaluation
from ..training.training import _filter_warnings
from .inference import run_inference_with_iterative_prompting
from .evaluation import run_evaluation_for_iterative_prompting
from .multi_dimensional_segmentation import segment_slices_from_ground_truth
from ..automatic_segmentation import (
    automatic_instance_segmentation, get_predictor_and_segmenter, DEFAULT_SEGMENTATION_MODE_WITH_DECODER,
)


LM_2D_DATASETS = [
    # in-domain
    "livecell",  # cell segmentation in PhC (has a TEST-set)
    "deepbacs",  # bacteria segmentation in label-free microscopy (has a TEST-set)
    "tissuenet",  # cell segmentation in tissue microscopy images (has a TEST-set)
    "neurips_cellseg",  # cell segmentation in various modalities (has a TEST-set)
    "cellpose",  # cell segmentation in FM (has 'cyto2', on which we can TEST)
    "dynamicnuclearnet",  # nuclei segmentation in FM (has a TEST-set)
    "orgasegment",  # organoid segmentation in BF (has a TEST-set)
    "yeaz",  # yeast segmentation in BF (has a TEST-set)

    # out-of-domain
    "arvidsson",  # nuclei segmentation in HCS FM
    "bitdepth_nucseg",  # nuclei segmentation in FM
    "cellbindb",  # cell segmentation in various microscopy modalities
    "covid_if",  # cell segmentation in IF
    "deepseas",  # cell segmentation in PhC
    "hpa",  # cell segmentation in confocal
    "ifnuclei",  # nuclei segmentation in IFM
    "lizard",  # nuclei segmentation in H&E histopathology
    "organoidnet",  # organoid segmentation in BF
    "toiam",  # microbial cell segmentation in PhC
    "vicar",  # cell segmentation in label-free microscopy
]

LM_3D_DATASETS = [
    # in-domain
    "plantseg_root",  # cell segmentation in lightsheet (has a TEST-set)

    # out-of-domain
    "plantseg_ovules",  # cell segmentation in confocal
    "gonuclear",  # nuclei segmentation in FM
    "mouse_embryo",  # cell segmentation in lightsheet
    "cellseg_3d",  # nuclei segmentation in FM
]

EM_2D_DATASETS = ["mitolab_tem"]

EM_3D_DATASETS = [
    # out-of-domain
    "lucchi",  # mitochondria segmentation in vEM
    "mitolab_3d",  # mitochondria segmentation in various EM modalities
    "uro_cell",  # mitochondria segmentation (and other organelles) in FIB-SEM
    "sponge_em",  # microvilli segmentation (and other structures) in sponge chamber vEM
    "vnc",  # mitochondria segmentation in Drosophila brain TEM
    "nuc_mm_mouse",  # nuclei segmentation in microCT
    "nuc_mm_zebrafish",  # nuclei segmentation in EM
    "platynereis_cilia",  # cilia segmentation (and other structures) in platynereis larvae vEM
    "asem_mito",  # mitochondria segmentation (and other organelles) in FIB-SEM
]

DATASET_RETURNS_FOLDER = {
    "deepbacs": "*.tif",
    "mitolab_tem": "*.tiff",
}

DATASET_CONTAINER_KEYS = {
    # 2d (LM)
    "tissuenet": ["raw/rgb", "labels/cell"],
    "covid_if": ["raw/serum_IgG/s0", "labels/cells/s0"],
    "dynamicnuclearnet": ["raw", "labels"],
    "hpa": [["raw/protein", "raw/microtubules", "raw/er"], "labels"],
    "lizard": ["image", "labels/segmentation"],

    # 3d (LM)
    "plantseg_root": ["raw", "label"],
    "plantseg_ovules": ["raw", "label"],
    "gonuclear": ["raw/nuclei", "labels/nuclei"],
    "mouse_embryo": ["raw", "label"],
    "cellseg_3d": [None, None],

    # 3d (EM)
    "lucchi": ["raw", "labels"],
    "uro_cell": ["raw", "labels/mito"],
    "mitolab_3d": [None, None],
    "sponge_em": ["volumes/raw", "volumes/labels/instances"],
    "vnc": ["raw", "labels/mitochondria"],
}
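
# The container keys above address the image and label datasets inside container file formats
# (e.g. hdf5 / zarr / n5) and are consumed via `util.load_image_data`. A minimal sketch
# (the container filename is an illustrative placeholder):
#
#     raw_key, label_key = DATASET_CONTAINER_KEYS["lucchi"]
#     image = util.load_image_data("lucchi_test.h5", raw_key)     # here the key "raw"
#     labels = util.load_image_data("lucchi_test.h5", label_key)  # here the key "labels"
#
# Entries set to [None, None] mark datasets whose loaders return per-file paths instead,
# which are handled by the file-based branch in `_extract_slices_from_dataset` below.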


def _download_benchmark_datasets(path, dataset_choice):
    """Ensures that the chosen benchmark datasets are downloaded.

    Args:
        path: The path to the directory where the supported datasets will be downloaded
            for benchmarking Segment Anything models.
        dataset_choice: The choice of dataset(s), expects the lower-case name(s) of the dataset(s).

    Returns:
        The list of chosen dataset(s).
    """
    available_datasets = {
        # Light Microscopy datasets

        # 2d datasets: in-domain
        "livecell": lambda: datasets.livecell.get_livecell_data(
            path=os.path.join(path, "livecell"), download=True,
        ),
        "deepbacs": lambda: datasets.deepbacs.get_deepbacs_data(
            path=os.path.join(path, "deepbacs"), bac_type="mixed", download=True,
        ),
        "tissuenet": lambda: datasets.tissuenet.get_tissuenet_data(
            path=os.path.join(path, "tissuenet"), split="test", download=True,
        ),
        "neurips_cellseg": lambda: datasets.neurips_cell_seg.get_neurips_cellseg_data(
            root=os.path.join(path, "neurips_cellseg"), split="test", download=True,
        ),
        "cellpose": lambda: datasets.cellpose.get_cellpose_data(
            path=os.path.join(path, "cellpose"), split="train", choice="cyto2", download=True,
        ),
        "dynamicnuclearnet": lambda: datasets.dynamicnuclearnet.get_dynamicnuclearnet_data(
            path=os.path.join(path, "dynamicnuclearnet"), split="test", download=True,
        ),
        "orgasegment": lambda: datasets.orgasegment.get_orgasegment_data(
            path=os.path.join(path, "orgasegment"), split="eval", download=True,
        ),
        "yeaz": lambda: datasets.yeaz.get_yeaz_data(
            path=os.path.join(path, "yeaz"), choice="bf", download=True,
        ),

        # 2d datasets: out-of-domain
        "arvidsson": lambda: datasets.arvidsson.get_arvidsson_data(
            path=os.path.join(path, "arvidsson"), split="test", download=True,
        ),
        "bitdepth_nucseg": lambda: datasets.bitdepth_nucseg.get_bitdepth_nucseg_data(
            path=os.path.join(path, "bitdepth_nucseg"), download=True,
        ),
        "cellbindb": lambda: datasets.cellbindb.get_cellbindb_data(
            path=os.path.join(path, "cellbindb"), download=True,
        ),
        "covid_if": lambda: datasets.covid_if.get_covid_if_data(
            path=os.path.join(path, "covid_if"), download=True,
        ),
        "deepseas": lambda: datasets.deepseas.get_deepseas_data(
            path=os.path.join(path, "deepseas"), split="test", download=True,
        ),
        "hpa": lambda: datasets.hpa.get_hpa_segmentation_data(
            path=os.path.join(path, "hpa"), download=True,
        ),
        "ifnuclei": lambda: datasets.ifnuclei.get_ifnuclei_data(
            path=os.path.join(path, "ifnuclei"), download=True,
        ),
        "lizard": lambda: datasets.lizard.get_lizard_data(
            path=os.path.join(path, "lizard"), split="test", download=True,
        ),
        "organoidnet": lambda: datasets.organoidnet.get_organoidnet_data(
            path=os.path.join(path, "organoidnet"), split="Test", download=True,
        ),
        "toiam": lambda: datasets.toiam.get_toiam_data(
            path=os.path.join(path, "toiam"), download=True,
        ),
        "vicar": lambda: datasets.vicar.get_vicar_data(
            path=os.path.join(path, "vicar"), download=True,
        ),

        # 3d datasets: in-domain
        "plantseg_root": lambda: datasets.plantseg.get_plantseg_data(
            path=os.path.join(path, "plantseg_root"), split="test", download=True, name="root",
        ),

        # 3d datasets: out-of-domain
        "plantseg_ovules": lambda: datasets.plantseg.get_plantseg_data(
            path=os.path.join(path, "plantseg_ovules"), split="test", download=True, name="ovules",
        ),
        "gonuclear": lambda: datasets.gonuclear.get_gonuclear_data(
            path=os.path.join(path, "gonuclear"), download=True,
        ),
        "mouse_embryo": lambda: datasets.mouse_embryo.get_mouse_embryo_data(
            path=os.path.join(path, "mouse_embryo"), download=True,
        ),
        "cellseg_3d": lambda: datasets.cellseg_3d.get_cellseg_3d_data(
            path=os.path.join(path, "cellseg_3d"), download=True,
        ),

        # Electron Microscopy datasets

        # 2d datasets: out-of-domain
        "mitolab_tem": lambda: datasets.cem.get_benchmark_data(
            path=os.path.join(path, "mitolab"), dataset_id=7, download=True,
        ),

        # 3d datasets: out-of-domain
        "lucchi": lambda: datasets.lucchi.get_lucchi_data(
            path=os.path.join(path, "lucchi"), split="test", download=True,
        ),
        "mitolab_3d": lambda: [
            datasets.cem.get_benchmark_data(
                path=os.path.join(path, "mitolab"), dataset_id=dataset_id, download=True,
            ) for dataset_id in range(1, 7)
        ],
        "uro_cell": lambda: datasets.uro_cell.get_uro_cell_data(
            path=os.path.join(path, "uro_cell"), download=True,
        ),
        "vnc": lambda: datasets.vnc.get_vnc_data(
            path=os.path.join(path, "vnc"), download=True,
        ),
        "sponge_em": lambda: datasets.sponge_em.get_sponge_em_data(
            path=os.path.join(path, "sponge_em"), download=True,
        ),
        "nuc_mm_mouse": lambda: datasets.nuc_mm.get_nuc_mm_data(
            path=os.path.join(path, "nuc_mm"), sample="mouse", download=True,
        ),
        "nuc_mm_zebrafish": lambda: datasets.nuc_mm.get_nuc_mm_data(
            path=os.path.join(path, "nuc_mm"), sample="zebrafish", download=True,
        ),
        "asem_mito": lambda: datasets.asem.get_asem_data(
            path=os.path.join(path, "asem"), volume_ids=datasets.asem.ORGANELLES["mito"], download=True,
        ),
        "platynereis_cilia": lambda: datasets.platynereis.get_platynereis_data(
            path=os.path.join(path, "platynereis"), name="cilia", download=True,
        ),
    }

    if dataset_choice is None:
        dataset_choice = list(available_datasets.keys())
    elif not isinstance(dataset_choice, list):
        dataset_choice = [dataset_choice]

    for choice in dataset_choice:
        if choice in available_datasets:
            available_datasets[choice]()
        else:
            raise ValueError(f"'{choice}' is not a supported choice of dataset.")

    return dataset_choice


def _extract_slices_from_dataset(path, dataset_choice, crops_per_input=10):
    """Extracts crops of the desired shapes for running 2d and 3d evaluation with `micro-sam`.

    Args:
        path: The path to the directory where the supported datasets have been downloaded
            for benchmarking Segment Anything models.
        dataset_choice: The name of the dataset from which to extract crops.
        crops_per_input: The maximum number of crops to extract per input.

    Returns:
        The number of dimensions of the extracted data (2 or 3).
    """
    ndim = 2 if dataset_choice in [*LM_2D_DATASETS, *EM_2D_DATASETS] else 3
    tile_shape = (512, 512) if ndim == 2 else (32, 512, 512)

    # For 3d inputs, we extract both 2d and 3d crops.
    extract_2d_crops_from_volumes = (ndim == 3)

    available_datasets = {
        # Light Microscopy datasets

        # 2d: in-domain
        "livecell": lambda: datasets.livecell.get_livecell_paths(path=path, split="test"),
        "deepbacs": lambda: datasets.deepbacs.get_deepbacs_paths(path=path, split="test", bac_type="mixed"),
        "tissuenet": lambda: datasets.tissuenet.get_tissuenet_paths(path=path, split="test"),
        "neurips_cellseg": lambda: datasets.neurips_cell_seg.get_neurips_cellseg_paths(root=path, split="test"),
        "cellpose": lambda: datasets.cellpose.get_cellpose_paths(path=path, split="train", choice="cyto2"),
        "dynamicnuclearnet": lambda: datasets.dynamicnuclearnet.get_dynamicnuclearnet_paths(path=path, split="test"),
        "orgasegment": lambda: datasets.orgasegment.get_orgasegment_paths(path=path, split="eval"),
        "yeaz": lambda: datasets.yeaz.get_yeaz_paths(path=path, choice="bf", split="test"),

        # 2d: out-of-domain
        "arvidsson": lambda: datasets.arvidsson.get_arvidsson_paths(path=path, split="test"),
        "bitdepth_nucseg": lambda: datasets.bitdepth_nucseg.get_bitdepth_nucseg_paths(path=path, magnification="20x"),
        "cellbindb": lambda: datasets.cellbindb.get_cellbindb_paths(
            path=path, data_choice=["10×Genomics_DAPI", "DAPI", "mIF"]
        ),
        "covid_if": lambda: datasets.covid_if.get_covid_if_paths(path=path),
        "deepseas": lambda: datasets.deepseas.get_deepseas_paths(path=path, split="test"),
        "hpa": lambda: datasets.hpa.get_hpa_segmentation_paths(path=path, split="val"),
        "ifnuclei": lambda: datasets.ifnuclei.get_ifnuclei_paths(path=path),
        "lizard": lambda: datasets.lizard.get_lizard_paths(path=path, split="test"),
        "organoidnet": lambda: datasets.organoidnet.get_organoidnet_paths(path=path, split="Test"),
        "toiam": lambda: datasets.toiam.get_toiam_paths(path=path),
        "vicar": lambda: datasets.vicar.get_vicar_paths(path=path),

        # 3d: in-domain
        "plantseg_root": lambda: datasets.plantseg.get_plantseg_paths(path=path, name="root", split="test"),

        # 3d: out-of-domain
        "plantseg_ovules": lambda: datasets.plantseg.get_plantseg_paths(path=path, name="ovules", split="test"),
        "gonuclear": lambda: datasets.gonuclear.get_gonuclear_paths(path=path),
        "mouse_embryo": lambda: datasets.mouse_embryo.get_mouse_embryo_paths(path=path, name="nuclei", split="val"),
        "cellseg_3d": lambda: datasets.cellseg_3d.get_cellseg_3d_paths(path=path),

        # Electron Microscopy datasets

        # 2d: out-of-domain
        "mitolab_tem": lambda: datasets.cem.get_benchmark_paths(
            path=os.path.join(os.path.dirname(path), "mitolab"), dataset_id=7
        )[:2],

        # 3d: out-of-domain
        "lucchi": lambda: datasets.lucchi.get_lucchi_paths(path=path, split="test"),
        "platynereis_cilia": lambda: datasets.platynereis.get_platynereis_paths(path, sample_ids=None, name="cilia"),
        "uro_cell": lambda: datasets.uro_cell.get_uro_cell_paths(path=path, target="mito"),
        "vnc": lambda: datasets.vnc.get_vnc_mito_paths(path=path),
        "sponge_em": lambda: datasets.sponge_em.get_sponge_em_paths(path=path, sample_ids=None),
        "mitolab_3d": lambda: (
            [
                datasets.cem.get_benchmark_paths(
                    path=os.path.join(os.path.dirname(path), "mitolab"), dataset_id=i
                )[0] for i in range(1, 7)
            ],
            [
                datasets.cem.get_benchmark_paths(
                    path=os.path.join(os.path.dirname(path), "mitolab"), dataset_id=i
                )[1] for i in range(1, 7)
            ]
        ),
        "nuc_mm_mouse": lambda: datasets.nuc_mm.get_nuc_mm_paths(path=path, sample="mouse", split="val"),
        "nuc_mm_zebrafish": lambda: datasets.nuc_mm.get_nuc_mm_paths(path=path, sample="zebrafish", split="val"),
        "asem_mito": lambda: datasets.asem.get_asem_paths(path=path, volume_ids=datasets.asem.ORGANELLES["mito"]),
    }

    if (ndim == 2 and dataset_choice not in DATASET_CONTAINER_KEYS) or dataset_choice in ["cellseg_3d", "mitolab_3d"]:
        image_paths, gt_paths = available_datasets[dataset_choice]()

        if dataset_choice in DATASET_RETURNS_FOLDER:
            image_paths = glob(os.path.join(image_paths, DATASET_RETURNS_FOLDER[dataset_choice]))
            gt_paths = glob(os.path.join(gt_paths, DATASET_RETURNS_FOLDER[dataset_choice]))

        image_paths, gt_paths = natsorted(image_paths), natsorted(gt_paths)
        assert len(image_paths) == len(gt_paths)

        paths_set = zip(image_paths, gt_paths)

    else:
        image_paths = available_datasets[dataset_choice]()
        if isinstance(image_paths, str):
            paths_set = [image_paths]
        else:
            paths_set = natsorted(image_paths)

    # Directories where we store the extracted ROIs.
    save_image_dir = [os.path.join(path, f"roi_{ndim}d", "inputs")]
    save_gt_dir = [os.path.join(path, f"roi_{ndim}d", "labels")]
    if extract_2d_crops_from_volumes:
        save_image_dir.append(os.path.join(path, "roi_2d", "inputs"))
        save_gt_dir.append(os.path.join(path, "roi_2d", "labels"))

    _dir_exists = [os.path.exists(idir) and os.path.exists(gdir) for idir, gdir in zip(save_image_dir, save_gt_dir)]
    if all(_dir_exists):
        return ndim

    for idir, gdir in zip(save_image_dir, save_gt_dir):
        os.makedirs(idir, exist_ok=True)
        os.makedirs(gdir, exist_ok=True)

    # Logic to extract the relevant patches for inference.
    image_counter = 1
    for per_paths in tqdm(paths_set, desc=f"Extracting {ndim}d patches for {dataset_choice}"):
        if (ndim == 2 and dataset_choice not in DATASET_CONTAINER_KEYS) or dataset_choice in ["cellseg_3d", "mitolab_3d"]:  # noqa
            image_path, gt_path = per_paths
            image, gt = util.load_image_data(image_path), util.load_image_data(gt_path)

        else:
            image_path = per_paths
            gt = util.load_image_data(image_path, DATASET_CONTAINER_KEYS[dataset_choice][1])
            if dataset_choice == "hpa":
                # Get the inputs per channel and stack them together to get the desired 3-channel image.
                image = np.stack(
                    [util.load_image_data(image_path, k) for k in DATASET_CONTAINER_KEYS[dataset_choice][0]], axis=0,
                )
                # Resize the inputs to the desired tile shape, so that the foreground fits the crop size.
                from torch_em.transform.generic import ResizeLongestSideInputs
                raw_transform = ResizeLongestSideInputs(target_shape=tile_shape, is_rgb=True)
                label_transform = ResizeLongestSideInputs(target_shape=tile_shape, is_label=True)
                image, gt = raw_transform(image).transpose(1, 2, 0), label_transform(gt)

            else:
                image = util.load_image_data(image_path, DATASET_CONTAINER_KEYS[dataset_choice][0])

        if dataset_choice in ["tissuenet", "lizard"]:
            if image.ndim == 3 and image.shape[0] == 3:  # Make channels last for RGB-style images.
                image = image.transpose(1, 2, 0)

        # RGB images stay as they are, with channels last.
        if image.ndim == 3 and image.shape[-1] == 3:
            skip_smaller_shape = (np.array(image.shape) >= np.array((*tile_shape, 3))).all()
        else:
            skip_smaller_shape = (np.array(image.shape) >= np.array(tile_shape)).all()

        # Ensure that the ground truth has instance labels.
        gt = connected_components(gt)

        if len(np.unique(gt)) == 1:  # Skip labels without any annotated foreground.
            continue

        # Let's extract and save all the crops.
        # The first round of extraction always matches the desired input dimensions.
        image_crops, gt_crops = _get_crops_for_input(image, gt, ndim, tile_shape, skip_smaller_shape, crops_per_input)
        image_counter = _save_image_label_crops(
            image_crops, gt_crops, dataset_choice, ndim, image_counter, save_image_dir[0], save_gt_dir[0]
        )

        # The next round of extraction gets 2d crops from the 3d inputs.
        if extract_2d_crops_from_volumes:
            curr_tile_shape = tile_shape[1:]  # We expect a 2d tile shape for this stage.

            curr_image_crops, curr_gt_crops = [], []
            for per_z_im, per_z_gt in zip(image, gt):
                curr_skip_smaller_shape = (np.array(per_z_im.shape) >= np.array(curr_tile_shape)).all()

                image_crops, gt_crops = _get_crops_for_input(
                    image=per_z_im, gt=per_z_gt, ndim=2,
                    tile_shape=curr_tile_shape,
                    skip_smaller_shape=curr_skip_smaller_shape,
                    crops_per_input=crops_per_input,
                )
                curr_image_crops.extend(image_crops)
                curr_gt_crops.extend(gt_crops)

            image_counter = _save_image_label_crops(
                curr_image_crops, curr_gt_crops, dataset_choice, 2, image_counter, save_image_dir[1], save_gt_dir[1]
            )

    return ndim


def _get_crops_for_input(image, gt, ndim, tile_shape, skip_smaller_shape, crops_per_input):
    tiling = blocking([0] * ndim, gt.shape, tile_shape)
    n_tiles = tiling.numberOfBlocks
    tiles = [tiling.getBlock(tile_id) for tile_id in range(n_tiles)]
    crop_boxes = [
        tuple(slice(beg, end) for beg, end in zip(tile.begin, tile.end)) for tile in tiles
    ]
    n_ids = list(range(len(crop_boxes)))
    n_instances = [len(np.unique(gt[crop])) for crop in crop_boxes]

    # Extract the desired number of patches, preferring those with the largest number of instances.
    image_crops, gt_crops = [], []
    for i, (per_n_instance, per_id) in enumerate(sorted(zip(n_instances, n_ids), reverse=True), start=1):
        crop_box = crop_boxes[per_id]
        crop_image, crop_gt = image[crop_box], gt[crop_box]

        # NOTE: We avoid using crops which do not match the desired tile shape.
        _rtile_shape = (*tile_shape, 3) if image.ndim == 3 and image.shape[-1] == 3 else tile_shape  # For RGB images.
        if skip_smaller_shape and crop_image.shape != _rtile_shape:
            continue

        # NOTE: Since the crops are sorted by the number of unique values, a crop with a single unique value
        # (typically background only) is not useful for evaluation, and neither are the ones after it.
        if per_n_instance == 1:
            break

        image_crops.append(crop_image)
        gt_crops.append(crop_gt)

        # NOTE: We stop after inspecting 'crops_per_input' candidates, as long as at least one crop was extracted.
        if len(image_crops) > 0 and i >= crops_per_input:
            break

    return image_crops, gt_crops
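
# A worked example for the tiling above, assuming a 2d label image of shape (1024, 1024) and
# tile_shape (512, 512): `blocking([0, 0], (1024, 1024), (512, 512))` covers the image with four
# non-overlapping 512x512 blocks. For image shapes that are not an exact multiple of the tile
# shape, the border blocks are smaller than `tile_shape`, which is why such crops are filtered
# out above whenever `skip_smaller_shape` is set.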


def _save_image_label_crops(image_crops, gt_crops, dataset_choice, ndim, image_counter, save_image_dir, save_gt_dir):
    for image_crop, gt_crop in tqdm(
        zip(image_crops, gt_crops), total=len(image_crops), desc=f"Saving {ndim}d crops for {dataset_choice}"
    ):
        fname = f"{dataset_choice}_{image_counter:05}.tif"

        if image_crop.ndim == 3 and image_crop.shape[-1] == 3:
            assert image_crop.shape[:2] == gt_crop.shape
        else:
            assert image_crop.shape == gt_crop.shape

        imageio.imwrite(os.path.join(save_image_dir, fname), image_crop, compression="zlib")
        imageio.imwrite(os.path.join(save_gt_dir, fname), gt_crop, compression="zlib")

        image_counter += 1

    return image_counter


def _get_image_label_paths(path, ndim):
    image_paths = natsorted(glob(os.path.join(path, f"roi_{ndim}d", "inputs", "*")))
    gt_paths = natsorted(glob(os.path.join(path, f"roi_{ndim}d", "labels", "*")))
    return image_paths, gt_paths


def _run_automatic_segmentation_per_dataset(
    image_paths: List[Union[os.PathLike, str]],
    gt_paths: List[Union[os.PathLike, str]],
    model_type: str,
    output_folder: Union[os.PathLike, str],
    ndim: Optional[int] = None,
    device: Optional[Union[torch.device, str]] = None,
    checkpoint_path: Optional[Union[os.PathLike, str]] = None,
    segmentation_mode: Optional[Literal["amg", "ais", "apg"]] = "ais",
    **auto_seg_kwargs
):
    """Runs automatic segmentation for multiple input files at once.
    It stores the quantitative results of the evaluated automatic segmentation.

    Args:
        image_paths: List of filepaths for the input image data.
        gt_paths: List of filepaths for the corresponding label data.
        model_type: The choice of image encoder for the Segment Anything model.
        output_folder: Filepath to the folder where we store all the results.
        ndim: The number of input dimensions.
        device: The torch device.
        checkpoint_path: The filepath where the model checkpoint is stored.
        segmentation_mode: The mode for automatic segmentation.
        auto_seg_kwargs: Additional arguments for the automatic segmentation parameters.
    """
    # If no mode is given, use the decoder-based default if a decoder state exists, and fall back to AMG otherwise.
    if segmentation_mode is None:
        _, state = util.get_sam_model(
            model_type=model_type, checkpoint_path=checkpoint_path, device=device, return_state=True
        )
        segmentation_mode = DEFAULT_SEGMENTATION_MODE_WITH_DECODER if "decoder_state" in state else "amg"

    fname = f"{segmentation_mode}_{ndim}d"

    result_path = os.path.join(output_folder, "results", f"{fname}.csv")
    if os.path.exists(result_path):
        return

    prediction_dir = os.path.join(output_folder, fname, "inference")
    os.makedirs(prediction_dir, exist_ok=True)

    # Get the predictor (and the additional instance segmentation decoder, if available).
    predictor, segmenter = get_predictor_and_segmenter(
        model_type=model_type, checkpoint=checkpoint_path, device=device,
        segmentation_mode=segmentation_mode, is_tiled=False,
    )

    for image_path in tqdm(image_paths, desc=f"Run {segmentation_mode} in {ndim}d"):
        output_path = os.path.join(prediction_dir, os.path.basename(image_path))
        if os.path.exists(output_path):
            continue

        # Run automatic segmentation (AMG or AIS).
        automatic_instance_segmentation(
            predictor=predictor,
            segmenter=segmenter,
            input_path=image_path,
            output_path=output_path,
            ndim=ndim,
            verbose=False,
            **auto_seg_kwargs
        )

    prediction_paths = natsorted(glob(os.path.join(prediction_dir, "*")))
    run_evaluation(gt_paths=gt_paths, prediction_paths=prediction_paths, save_path=result_path)
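
# A minimal sketch of using the automatic segmentation API wrapped above on a single image
# (all argument values are illustrative placeholders):
#
#     predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm", segmentation_mode="ais")
#     automatic_instance_segmentation(
#         predictor=predictor, segmenter=segmenter,
#         input_path="image.tif", output_path="segmentation.tif", ndim=2,
#     )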


def _run_interactive_segmentation_per_dataset(
    image_paths: List[Union[os.PathLike, str]],
    gt_paths: List[Union[os.PathLike, str]],
    output_folder: Union[os.PathLike, str],
    model_type: str,
    prompt_choice: Literal["box", "points"],
    device: Optional[Union[torch.device, str]] = None,
    ndim: Optional[int] = None,
    checkpoint_path: Optional[Union[os.PathLike, str]] = None,
    use_masks: bool = False,
):
    """Runs interactive segmentation for multiple input files at once.
    It stores the evaluated interactive segmentation results.

    Args:
        image_paths: List of filepaths for the input image data.
        gt_paths: List of filepaths for the corresponding label data.
        output_folder: Filepath to the folder where we store all the results.
        model_type: The choice of model type for Segment Anything.
        prompt_choice: The choice of initial prompts with which to begin the interactive segmentation.
        device: The torch device.
        ndim: The number of input dimensions.
        checkpoint_path: The filepath of the stored checkpoint.
        use_masks: Whether to use masks for iterative prompting.
    """
    if ndim == 2:
        # Get the Segment Anything predictor.
        predictor = util.get_sam_model(model_type=model_type, device=device, checkpoint_path=checkpoint_path)

        prediction_root = os.path.join(
            output_folder, "interactive_segmentation_2d", f"start_with_{prompt_choice}",
            "iterative_prompting_" + ("with_masks" if use_masks else "without_masks")
        )

        # Run interactive instance segmentation
        # (starting with box or point prompts, followed by iterative prompt-based correction).
        run_inference_with_iterative_prompting(
            predictor=predictor,
            image_paths=image_paths,
            gt_paths=gt_paths,
            embedding_dir=None,  # We set this to None to compute the embeddings on-the-fly.
            prediction_dir=prediction_root,
            start_with_box_prompt=(prompt_choice == "box"),
            use_masks=use_masks,
            # TODO: Add a parameter to deform the box prompts (to simulate prompts in practice).
        )

        # Evaluate the interactive instance segmentation.
        run_evaluation_for_iterative_prompting(
            gt_paths=gt_paths,
            prediction_root=prediction_root,
            experiment_folder=output_folder,
            start_with_box_prompt=(prompt_choice == "box"),
            use_masks=use_masks,
        )

    else:
        save_path = os.path.join(output_folder, "results", f"interactive_segmentation_3d_with_{prompt_choice}.csv")
        if os.path.exists(save_path):
            print(
                f"Results for 3d interactive segmentation with '{prompt_choice}' are already stored at '{save_path}'."
            )
            return

        results = []
        for image_path, gt_path in tqdm(
            zip(image_paths, gt_paths), total=len(image_paths),
            desc=f"Run interactive segmentation in 3d with '{prompt_choice}'"
        ):
            prediction_dir = os.path.join(output_folder, "interactive_segmentation_3d", f"{prompt_choice}")
            os.makedirs(prediction_dir, exist_ok=True)

            prediction_path = os.path.join(prediction_dir, os.path.basename(image_path))
            if os.path.exists(prediction_path):
                continue

            per_vol_result = segment_slices_from_ground_truth(
                volume=imageio.imread(image_path),
                ground_truth=imageio.imread(gt_path),
                model_type=model_type,
                checkpoint_path=checkpoint_path,
                save_path=prediction_path,
                device=device,
                interactive_seg_mode=prompt_choice,
                min_size=10,
            )
            results.append(per_vol_result)

        results = pd.concat(results)
        results = results.groupby(results.index).mean()
        results.to_csv(save_path)


def _run_benchmark_evaluation_series(
    image_paths, gt_paths, model_type, output_folder, ndim, device, checkpoint_path,
    segmentation_mode, evaluation_methods,
):
    seg_kwargs = {
        "image_paths": image_paths,
        "gt_paths": gt_paths,
        "output_folder": output_folder,
        "ndim": ndim,
        "model_type": model_type,
        "device": device,
        "checkpoint_path": checkpoint_path,
    }

    # Perform:
    # a. Automatic segmentation (supported in both 2d and 3d, wherever relevant).
    #    The automatic segmentation steps below are configured so that AIS has priority (if a decoder is found);
    #    otherwise, AMG is run.
    #    Next, we check whether the user wants to run an additional mode (e.g. AMG after the AIS run).

    if evaluation_methods != "interactive":  # Skip automatic segmentation for the 'interactive'-only run choice.
        # i. Run the automatic segmentation method supported by the SAM model (AMG or AIS).
        _run_automatic_segmentation_per_dataset(segmentation_mode=None, **seg_kwargs)

        # ii. Run the additional user-specified automatic segmentation mode (e.g. AMG after the AIS run).
        #     NOTE: With the default of 'None', this resolves to the same mode as above and returns early,
        #     since the results already exist.
        _run_automatic_segmentation_per_dataset(segmentation_mode=segmentation_mode, **seg_kwargs)

    if evaluation_methods != "automatic":  # Skip interactive segmentation for the 'automatic'-only run choice.
        # b. Run interactive segmentation (supported in both 2d and 3d, wherever relevant).
        _run_interactive_segmentation_per_dataset(prompt_choice="box", **seg_kwargs)
        _run_interactive_segmentation_per_dataset(prompt_choice="box", use_masks=True, **seg_kwargs)
        _run_interactive_segmentation_per_dataset(prompt_choice="points", **seg_kwargs)
        _run_interactive_segmentation_per_dataset(prompt_choice="points", use_masks=True, **seg_kwargs)


def _clear_cached_items(retain, path, output_folder):
    import shutil
    from pathlib import Path

    REMOVE_LIST = ["data", "crops", "automatic", "interactive"]
    if retain is None:
        remove_list = REMOVE_LIST
    else:
        assert isinstance(retain, list)
        remove_list = set(REMOVE_LIST) - set(retain)

    paths = []
    # Stage 1: Remove inputs.
    if "data" in remove_list or "crops" in remove_list:
        all_paths = glob(os.path.join(path, "*"))

        # In case we want to remove both the data and the crops, we remove the data folder contents entirely.
        if "data" in remove_list and "crops" in remove_list:
            paths.extend(all_paths)

        # Otherwise, we only remove either the data or the crops.
        else:
            for curr_path in all_paths:
                is_crop = os.path.basename(curr_path).startswith("roi")
                if is_crop and "crops" in remove_list:
                    paths.append(curr_path)
                elif not is_crop and "data" in remove_list:
                    paths.append(curr_path)

    # Stage 2: Remove predictions.
    if "automatic" in remove_list:
        paths.extend(glob(os.path.join(output_folder, "amg_*")))
        paths.extend(glob(os.path.join(output_folder, "ais_*")))
        paths.extend(glob(os.path.join(output_folder, "apg_*")))

    if "interactive" in remove_list:
        paths.extend(glob(os.path.join(output_folder, "interactive_segmentation_*")))

    for _path in paths:
        if Path(_path).is_dir():
            shutil.rmtree(_path)
        else:
            os.remove(_path)


def run_benchmark_evaluations(
    input_folder: Union[os.PathLike, str],
    dataset_choice: Optional[Union[str, List[str]]],
    model_type: str = util._DEFAULT_MODEL,
    output_folder: Optional[Union[str, os.PathLike]] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    segmentation_mode: Optional[Literal["amg", "ais", "apg"]] = None,
    retain: Optional[List[str]] = None,
    evaluation_methods: Literal["all", "automatic", "interactive"] = "all",
    ignore_warnings: bool = False,
):
    """Run evaluation for benchmarking Segment Anything models on microscopy datasets.

    Args:
        input_folder: The path to the directory where all inputs will be stored and preprocessed.
        dataset_choice: The choice of dataset(s): a single dataset name, a list of names,
            or None to run on all supported datasets.
        model_type: The model choice for SAM.
        output_folder: The path to the directory where all outputs will be stored.
        checkpoint_path: The filepath to the model checkpoint.
        segmentation_mode: The automatic segmentation mode. One of 'amg', 'ais', or 'apg'.
        retain: Which parts of the benchmark runs to retain.
            By default, removes everything besides the quantitative results.
            The choices are 'data', 'crops', 'automatic', and 'interactive'.
        evaluation_methods: The choice of evaluation methods.
            By default, runs 'all' evaluation methods (i.e. both 'automatic' and 'interactive').
            Otherwise, specify 'automatic' or 'interactive' for a specific evaluation run.
        ignore_warnings: Whether to ignore warnings.
    """
    start = time.time()

    with _filter_warnings(ignore_warnings):
        device = util._get_default_device()

        # Ensure that all chosen datasets have been downloaded.
        dataset_choice = _download_benchmark_datasets(path=input_folder, dataset_choice=dataset_choice)

        for choice in dataset_choice:
            # Derive a per-dataset output folder, so that consecutive datasets do not nest their results.
            curr_output_folder = os.path.join(output_folder, choice)
            result_dir = os.path.join(curr_output_folder, "results")
            os.makedirs(result_dir, exist_ok=True)

            data_path = os.path.join(input_folder, choice)

            # Extract the desired crops from the datasets:
            # a. for 2d datasets - 2d patches with the largest number of labels present
            #    (in case of volumetric data, choose 2d patches per slice).
            # b. for 3d datasets - 3d regions of interest with the largest number of labels present.
            ndim = _extract_slices_from_dataset(path=data_path, dataset_choice=choice, crops_per_input=10)

            # Run the inference and evaluation scripts on the benchmark datasets.
            image_paths, gt_paths = _get_image_label_paths(path=data_path, ndim=ndim)
            _run_benchmark_evaluation_series(
                image_paths=image_paths,
                gt_paths=gt_paths,
                model_type=model_type,
                output_folder=curr_output_folder,
                ndim=ndim,
                device=device,
                checkpoint_path=checkpoint_path,
                segmentation_mode=segmentation_mode,
                evaluation_methods=evaluation_methods,
            )

            # Run the inference and evaluation scripts on the 2d crops for volumetric datasets.
            if ndim == 3:
                image_paths, gt_paths = _get_image_label_paths(path=data_path, ndim=2)
                _run_benchmark_evaluation_series(
                    image_paths=image_paths,
                    gt_paths=gt_paths,
                    model_type=model_type,
                    output_folder=curr_output_folder,
                    ndim=2,
                    device=device,
                    checkpoint_path=checkpoint_path,
                    segmentation_mode=segmentation_mode,
                    evaluation_methods=evaluation_methods,
                )

            _clear_cached_items(retain=retain, path=data_path, output_folder=curr_output_folder)

    diff = time.time() - start
    hours, rest = divmod(diff, 3600)
    minutes, seconds = divmod(rest, 60)
    print(f"Time taken for running the benchmarks: {int(hours)}h {int(minutes)}m {int(seconds)}s")


def main():
    """@private"""
    import argparse

    available_models = list(util.get_model_names())
    available_models = ", ".join(available_models)

    parser = argparse.ArgumentParser(
        description="Run evaluation for benchmarking Segment Anything models on microscopy datasets."
    )
    parser.add_argument(
        "-i", "--input_folder", type=str, required=True,
        help="The path to a directory where the microscopy datasets are and/or will be stored."
    )
    parser.add_argument(
        "-m", "--model_type", type=str, default=util._DEFAULT_MODEL,
        help=f"The Segment Anything model that will be used, one of '{available_models}'. "
        f"By default, it uses '{util._DEFAULT_MODEL}'."
    )
    parser.add_argument(
        "-c", "--checkpoint_path", type=str, default=None,
        help="The filepath to the checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "-d", "--dataset_choice", type=str, nargs='*', default=None,
        help="The choice(s) of dataset for evaluating SAM models. Multiple datasets can be specified. "
        "By default, it evaluates on all datasets."
    )
    parser.add_argument(
        "-o", "--output_folder", type=str, required=True,
        help="The path where the results for (automatic and interactive) instance segmentation "
        "will be stored as 'csv' files."
    )
    parser.add_argument(
        "--amg", action="store_true",
        help="Whether to run automatic segmentation in AMG mode (i.e. the default auto-seg approach for SAM)."
    )
    parser.add_argument(
        "--retain", nargs="*", default=None,
        help="By default, the functionality removes everything besides the quantitative results required for "
        "running the benchmarks. In case you would like to retain parts of the benchmark evaluation for "
        "visualization / reproducibility, choose one or multiple of 'data', 'crops', 'automatic', 'interactive', "
        "which retain the original inputs, the extracted crops, the predictions of the automatic segmentation, "
        "or the predictions of the interactive segmentation, respectively."
    )
    parser.add_argument(
        "--evaluate", type=str, default="all", choices=["all", "automatic", "interactive"],
        help="The choice of methods for the benchmark evaluation. "
        "By default, 'all' runs every evaluation. Choose 'automatic' to run automatic segmentation only, "
        "or 'interactive' to run interactive segmentation (starting from box and single point prompts) "
        "with iterative prompting."
    )
    args = parser.parse_args()

    run_benchmark_evaluations(
        input_folder=args.input_folder,
        dataset_choice=args.dataset_choice,
        model_type=args.model_type,
        output_folder=args.output_folder,
        checkpoint_path=args.checkpoint_path,
        segmentation_mode="amg" if args.amg else None,  # The '--amg' flag enforces AMG; else the mode is derived.
        retain=args.retain,
        evaluation_methods=args.evaluate,
        ignore_warnings=True,
    )


if __name__ == "__main__":
    main()
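
Usage example (a minimal sketch; the input and output folders are placeholders, and the
dataset choice can be any of the supported names listed at the top of the module):

    from micro_sam.evaluation.benchmark_datasets import run_benchmark_evaluations

    run_benchmark_evaluations(
        input_folder="./data",           # Where the datasets are (or will be) downloaded.
        dataset_choice="covid_if",       # A single name, a list of names, or None for all datasets.
        model_type="vit_b_lm",
        output_folder="./results",
        evaluation_methods="automatic",  # Or "interactive" / "all".
    )

The same run is available from the command line, since the module exposes a `main` entry point:

    python -m micro_sam.evaluation.benchmark_datasets -i ./data -o ./results -d covid_if --evaluate automatic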