micro_sam.evaluation.benchmark_datasets

import os
import time
from glob import glob
from tqdm import tqdm
from natsort import natsorted
from typing import Union, Optional, List, Literal

import numpy as np
import pandas as pd
import imageio.v3 as imageio
from skimage.measure import label as connected_components

from nifty.tools import blocking

import torch

from torch_em.data import datasets

from micro_sam import util

from . import run_evaluation
from ..training.training import _filter_warnings
from .inference import run_inference_with_iterative_prompting
from .evaluation import run_evaluation_for_iterative_prompting
from .multi_dimensional_segmentation import segment_slices_from_ground_truth
from ..automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter


LM_2D_DATASETS = [
    "livecell", "deepbacs", "tissuenet", "neurips_cellseg", "dynamicnuclearnet",
    "hpa", "covid_if", "pannuke", "lizard", "orgasegment", "omnipose", "dic_hepg2",
]

LM_3D_DATASETS = [
    "plantseg_root", "plantseg_ovules", "gonuclear", "mouse_embryo", "embedseg_data", "cellseg_3d",
]

EM_2D_DATASETS = ["mitolab_tem"]

EM_3D_DATASETS = [
    "mitoem_rat", "mitoem_human", "platynereis_nuclei", "lucchi", "mitolab_3d", "nuc_mm_mouse",
    "nuc_mm_zebrafish", "uro_cell", "sponge_em", "platynereis_cilia", "vnc", "asem_mito",
]

DATASET_RETURNS_FOLDER = {
    "deepbacs": "*.tif"
}

DATASET_CONTAINER_KEYS = {
    "lucchi": ["raw", "labels"],
}


def _download_benchmark_datasets(path, dataset_choice):
    """Ensures that the chosen supported datasets have been downloaded.

    Args:
        path: The path to the directory where the supported datasets will be downloaded
            for benchmarking Segment Anything models.
        dataset_choice: The choice of dataset(s); expects the lower-case name of the dataset.

    Returns:
        The list of chosen dataset(s).
    """
    available_datasets = {
        # Light Microscopy datasets
        "livecell": lambda: datasets.livecell.get_livecell_data(
            path=os.path.join(path, "livecell"), split="test", download=True,
        ),
        "deepbacs": lambda: datasets.deepbacs.get_deepbacs_data(
            path=os.path.join(path, "deepbacs"), bac_type="mixed", download=True,
        ),
        "tissuenet": lambda: datasets.tissuenet.get_tissuenet_data(
            path=os.path.join(path, "tissuenet"), split="test", download=True,
        ),
        "neurips_cellseg": lambda: datasets.neurips_cell_seg.get_neurips_cellseg_data(
            root=os.path.join(path, "neurips_cellseg"), split="test", download=True,
        ),
        "plantseg_root": lambda: datasets.plantseg.get_plantseg_data(
            path=os.path.join(path, "plantseg"), download=True, name="root",
        ),
        "plantseg_ovules": lambda: datasets.plantseg.get_plantseg_data(
            path=os.path.join(path, "plantseg"), download=True, name="ovules",
        ),
        "covid_if": lambda: datasets.covid_if.get_covid_if_data(
            path=os.path.join(path, "covid_if"), download=True,
        ),
        "hpa": lambda: datasets.hpa.get_hpa_segmentation_data(
            path=os.path.join(path, "hpa"), download=True,
        ),
        "dynamicnuclearnet": lambda: datasets.dynamicnuclearnet.get_dynamicnuclearnet_data(
            path=os.path.join(path, "dynamicnuclearnet"), split="test", download=True,
        ),
        "pannuke": lambda: datasets.pannuke.get_pannuke_data(
            path=os.path.join(path, "pannuke"), download=True, folds=["fold_1", "fold_2", "fold_3"],
        ),
        "lizard": lambda: datasets.lizard.get_lizard_data(
            path=os.path.join(path, "lizard"), download=True,
        ),
        "orgasegment": lambda: datasets.orgasegment.get_orgasegment_data(
            path=os.path.join(path, "orgasegment"), split="eval", download=True,
        ),
        "omnipose": lambda: datasets.omnipose.get_omnipose_data(
            path=os.path.join(path, "omnipose"), download=True,
        ),
        "gonuclear": lambda: datasets.gonuclear.get_gonuclear_data(
            path=os.path.join(path, "gonuclear"), download=True,
        ),
        "mouse_embryo": lambda: datasets.mouse_embryo.get_mouse_embryo_data(
            path=os.path.join(path, "mouse_embryo"), download=True,
        ),
        "embedseg_data": lambda: [
            datasets.embedseg_data.get_embedseg_data(path=os.path.join(path, "embedseg_data"), download=True, name=name)
            for name in datasets.embedseg_data.URLS.keys()
        ],
        "cellseg_3d": lambda: datasets.cellseg_3d.get_cellseg_3d_data(
            path=os.path.join(path, "cellseg_3d"), download=True,
        ),
        "dic_hepg2": lambda: datasets.dic_hepg2.get_dic_hepg2_data(
            path=os.path.join(path, "dic_hepg2"), download=True,
        ),

        # Electron Microscopy datasets
        "mitoem_rat": lambda: datasets.mitoem.get_mitoem_data(
            path=os.path.join(path, "mitoem"), samples="rat", split="test", download=True,
        ),
        "mitoem_human": lambda: datasets.mitoem.get_mitoem_data(
            path=os.path.join(path, "mitoem"), samples="human", split="test", download=True,
        ),
        "platynereis_nuclei": lambda: datasets.platynereis.get_platy_data(
            path=os.path.join(path, "platynereis"), name="nuclei", download=True,
        ),
        "platynereis_cilia": lambda: datasets.platynereis.get_platy_data(
            path=os.path.join(path, "platynereis"), name="cilia", download=True,
        ),
        "lucchi": lambda: datasets.lucchi.get_lucchi_data(
            path=os.path.join(path, "lucchi"), split="test", download=True,
        ),
        "mitolab_3d": lambda: [
            datasets.cem.get_benchmark_data(
                path=os.path.join(path, "mitolab"), dataset_id=dataset_id, download=True,
            ) for dataset_id in range(1, 7)
        ],
        "mitolab_tem": lambda: datasets.cem.get_benchmark_data(
            path=os.path.join(path, "mitolab"), dataset_id=7, download=True
        ),
        "nuc_mm_mouse": lambda: datasets.nuc_mm.get_nuc_mm_data(
            path=os.path.join(path, "nuc_mm"), sample="mouse", download=True,
        ),
        "nuc_mm_zebrafish": lambda: datasets.nuc_mm.get_nuc_mm_data(
            path=os.path.join(path, "nuc_mm"), sample="zebrafish", download=True,
        ),
        "uro_cell": lambda: datasets.uro_cell.get_uro_cell_data(
            path=os.path.join(path, "uro_cell"), download=True,
        ),
        "sponge_em": lambda: datasets.sponge_em.get_sponge_em_data(
            path=os.path.join(path, "sponge_em"), download=True,
        ),
        "vnc": lambda: datasets.vnc.get_vnc_data(
            path=os.path.join(path, "vnc"), download=True,
        ),
        "asem_mito": lambda: datasets.asem.get_asem_data(
            path=os.path.join(path, "asem"), volume_ids=datasets.asem.ORGANELLES["mito"], download=True,
        )
    }

    if dataset_choice is None:
        dataset_choice = available_datasets.keys()
    else:
        if not isinstance(dataset_choice, list):
            dataset_choice = [dataset_choice]

    for choice in dataset_choice:
        if choice in available_datasets:
            available_datasets[choice]()
        else:
            raise ValueError(f"'{choice}' is not a supported choice of dataset.")

    return dataset_choice
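
# A minimal usage sketch for the download helper (the dataset name and target directory
# below are assumptions for illustration, not defaults of this module):
#
#   choices = _download_benchmark_datasets(path="./benchmark_data", dataset_choice="livecell")
#   # Downloads the LIVECell test split to './benchmark_data/livecell' and returns ['livecell'].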


def _extract_slices_from_dataset(path, dataset_choice, crops_per_input=10):
    """Extracts crops of the desired shapes for evaluation in both 2d and 3d using `micro-sam`.

    Args:
        path: The path to the directory where the supported datasets have been downloaded
            for benchmarking Segment Anything models.
        dataset_choice: The name of the chosen dataset from which to extract crops.
        crops_per_input: The maximum number of crops to extract per input.

    Returns:
        The number of dimensions of the input data.
    """
    ndim = 2 if dataset_choice in [*LM_2D_DATASETS, *EM_2D_DATASETS] else 3
    tile_shape = (512, 512) if ndim == 2 else (32, 512, 512)

    # For 3d inputs, we extract both 2d and 3d crops.
    extract_2d_crops_from_volumes = (ndim == 3)

    available_datasets = {
        # Light Microscopy datasets
        "livecell": lambda: datasets.livecell.get_livecell_paths(path=path, split="test"),
        "deepbacs": lambda: datasets.deepbacs.get_deepbacs_paths(path=path, split="test", bac_type="mixed"),
        "tissuenet": lambda: datasets.tissuenet.get_tissuenet_paths(path=path, split="test"),
        "neurips_cellseg": lambda: datasets.neurips_cell_seg.get_neurips_cellseg_paths(root=path, split="test"),
        "plantseg_root": lambda: datasets.plantseg.get_plantseg_paths(path=path, name="root", split="test"),
        "plantseg_ovules": lambda: datasets.plantseg.get_plantseg_paths(path=path, name="ovules", split="test"),
        "covid_if": lambda: datasets.covid_if.get_covid_if_paths(path=path),
        "hpa": lambda: datasets.hpa.get_hpa_segmentation_paths(path=path, split="test"),
        "dynamicnuclearnet": lambda: datasets.dynamicnuclearnet.get_dynamicnuclearnet_paths(path=path, split="test"),
        "pannuke": lambda: datasets.pannuke.get_pannuke_paths(path=path),
        "lizard": lambda: datasets.lizard.get_lizard_paths(path=path),
        "orgasegment": lambda: datasets.orgasegment.get_orgasegment_paths(path=path, split="eval"),
        "omnipose": lambda: datasets.omnipose.get_omnipose_paths(path=path, split="test"),
        "gonuclear": lambda: datasets.gonuclear.get_gonuclear_paths(path=path),
        "mouse_embryo": lambda: datasets.mouse_embryo.get_mouse_embryo_paths(path=path, name="nuclei", split="val"),
        "embedseg_data": lambda: datasets.embedseg_data.get_embedseg_paths(
            path=path, name=list(datasets.embedseg_data.URLS.keys())[0], split="test"
        ),
        "cellseg_3d": lambda: datasets.cellseg_3d.get_cellseg_3d_paths(path=path),
        "dic_hepg2": lambda: datasets.dic_hepg2.get_dic_hepg2_paths(path=path, split="test"),

        # Electron Microscopy datasets
        "mitoem_rat": lambda: datasets.mitoem.get_mitoem_paths(path=path, splits="test", samples="rat"),
        "mitoem_human": lambda: datasets.mitoem.get_mitoem_paths(path=path, splits="test", samples="human"),
        "platynereis_nuclei": lambda: datasets.platynereis.get_platynereis_paths(path, sample_ids=None, name="nuclei"),
        "platynereis_cilia": lambda: datasets.platynereis.get_platynereis_paths(path, sample_ids=None, name="cilia"),
        "lucchi": lambda: datasets.lucchi.get_lucchi_paths(path=path, split="test"),
        "mitolab_3d": lambda: (
            [rpath for i in range(1, 7) for rpath in datasets.cem.get_benchmark_paths(path=path, dataset_id=i)[0]],
            [lpath for i in range(1, 7) for lpath in datasets.cem.get_benchmark_paths(path=path, dataset_id=i)[1]]
        ),
        "mitolab_tem": lambda: datasets.cem.get_benchmark_paths(path=path, dataset_id=7),
        "nuc_mm_mouse": lambda: datasets.nuc_mm.get_nuc_mm_paths(path=path, sample="mouse", split="val"),
        "nuc_mm_zebrafish": lambda: datasets.nuc_mm.get_nuc_mm_paths(path=path, sample="zebrafish", split="val"),
        "uro_cell": lambda: datasets.uro_cell.get_uro_cell_paths(path=path, target="mito"),
        "sponge_em": lambda: datasets.sponge_em.get_sponge_em_paths(path=path, sample_ids=None),
        "vnc": lambda: datasets.vnc.get_vnc_mito_paths(path=path),
        "asem_mito": lambda: datasets.asem.get_asem_paths(path=path, volume_ids=datasets.asem.ORGANELLES["mito"])
    }

    if ndim == 2:
        image_paths, gt_paths = available_datasets[dataset_choice]()

        if dataset_choice in DATASET_RETURNS_FOLDER:
            image_paths = glob(os.path.join(image_paths, DATASET_RETURNS_FOLDER[dataset_choice]))
            gt_paths = glob(os.path.join(gt_paths, DATASET_RETURNS_FOLDER[dataset_choice]))

        image_paths, gt_paths = natsorted(image_paths), natsorted(gt_paths)
        assert len(image_paths) == len(gt_paths)

        paths_set = zip(image_paths, gt_paths)

    else:
        image_paths = available_datasets[dataset_choice]()
        if isinstance(image_paths, str):
            paths_set = [image_paths]
        else:
            paths_set = natsorted(image_paths)

    # Directories where we store the extracted ROIs.
    save_image_dir = [os.path.join(path, f"roi_{ndim}d", "inputs")]
    save_gt_dir = [os.path.join(path, f"roi_{ndim}d", "labels")]
    if extract_2d_crops_from_volumes:
        save_image_dir.append(os.path.join(path, "roi_2d", "inputs"))
        save_gt_dir.append(os.path.join(path, "roi_2d", "labels"))

    _dir_exists = [
        os.path.exists(idir) and os.path.exists(gdir) for idir, gdir in zip(save_image_dir, save_gt_dir)
    ]
    if all(_dir_exists):
        return ndim

    for idir, gdir in zip(save_image_dir, save_gt_dir):
        os.makedirs(idir, exist_ok=True)
        os.makedirs(gdir, exist_ok=True)

    # Logic to extract the relevant patches for inference.
    image_counter = 1
    for per_paths in tqdm(paths_set, desc=f"Extracting patches for {dataset_choice}"):
        if ndim == 2:
            image_path, gt_path = per_paths
            image, gt = util.load_image_data(image_path), util.load_image_data(gt_path)
        else:
            image_path = per_paths
            image = util.load_image_data(image_path, DATASET_CONTAINER_KEYS[dataset_choice][0])
            gt = util.load_image_data(image_path, DATASET_CONTAINER_KEYS[dataset_choice][1])

        skip_smaller_shape = (np.array(image.shape) >= np.array(tile_shape)).all()

        # Ensure that the ground truth has instance labels.
        gt = connected_components(gt)

        if len(np.unique(gt)) == 1:  # There could be labels which do not have any annotated foreground.
            continue

        # Let's extract and save all the crops.
        # NOTE: The first round of extraction is always done to match the desired input dimensions.
        image_crops, gt_crops = _get_crops_for_input(image, gt, ndim, tile_shape, skip_smaller_shape, crops_per_input)
        image_counter = _save_image_label_crops(
            image_crops, gt_crops, dataset_choice, ndim, image_counter, save_image_dir[0], save_gt_dir[0]
        )

        # NOTE: The next round of extraction is to get 2d crops from 3d inputs.
        if extract_2d_crops_from_volumes:
            curr_tile_shape = tile_shape[-2:]  # NOTE: We expect a 2d tile shape for this stage.

            curr_image_crops, curr_gt_crops = [], []
            for per_z_im, per_z_gt in zip(image, gt):
                curr_skip_smaller_shape = (np.array(per_z_im.shape) >= np.array(curr_tile_shape)).all()

                image_crops, gt_crops = _get_crops_for_input(
                    image=per_z_im, gt=per_z_gt, ndim=2,
                    tile_shape=curr_tile_shape,
                    skip_smaller_shape=curr_skip_smaller_shape,
                    crops_per_input=crops_per_input,
                )
                curr_image_crops.extend(image_crops)
                curr_gt_crops.extend(gt_crops)

            image_counter = _save_image_label_crops(
                curr_image_crops, curr_gt_crops, dataset_choice, 2, image_counter, save_image_dir[1], save_gt_dir[1]
            )

    return ndim
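
# After extraction, the crops are stored as tif files under '<path>/roi_{ndim}d/inputs' and
# '<path>/roi_{ndim}d/labels' (plus '<path>/roi_2d/...' for volumetric data). A sketch with
# a hypothetical path:
#
#   ndim = _extract_slices_from_dataset(path="./benchmark_data/livecell", dataset_choice="livecell")
#   image_paths, gt_paths = _get_image_label_paths(path="./benchmark_data/livecell", ndim=ndim)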


def _get_crops_for_input(image, gt, ndim, tile_shape, skip_smaller_shape, crops_per_input):
    tiling = blocking([0] * ndim, gt.shape, tile_shape)
    n_tiles = tiling.numberOfBlocks
    tiles = [tiling.getBlock(tile_id) for tile_id in range(n_tiles)]
    crop_boxes = [
        tuple(slice(beg, end) for beg, end in zip(tile.begin, tile.end)) for tile in tiles
    ]
    n_ids = list(range(len(crop_boxes)))
    n_instances = [len(np.unique(gt[crop])) for crop in crop_boxes]

    # Extract the desired number of patches, preferring those with the highest number of instances.
    image_crops, gt_crops = [], []
    for i, (per_n_instance, per_id) in enumerate(sorted(zip(n_instances, n_ids), reverse=True), start=1):
        crop_box = crop_boxes[per_id]
        crop_image, crop_gt = image[crop_box], gt[crop_box]
        # NOTE: We avoid using crops which do not match the desired tile shape.
        if skip_smaller_shape and crop_image.shape != tile_shape:
            continue

        # NOTE: Since the crops are sorted by instance count, an instance count of 1 (background only)
        # means that all remaining crops are empty, so we stop here.
        if per_n_instance == 1:
            break

        image_crops.append(crop_image)
        gt_crops.append(crop_gt)

        # NOTE: Once the desired number of patches has been extracted, we stop sampling.
        if len(image_crops) > 0 and i >= crops_per_input:
            break

    return image_crops, gt_crops
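
# A self-contained sketch of the crop selection above on synthetic data (the shapes are
# arbitrary choices for illustration): 'blocking' tiles the input into patches of 'tile_shape'
# and the tiles containing the most instances are selected first.
#
#   gt = np.zeros((1024, 1024), dtype="uint32")
#   gt[100:200, 100:200] = 1  # a single synthetic instance
#   image = np.random.rand(1024, 1024)
#   image_crops, gt_crops = _get_crops_for_input(
#       image, gt, ndim=2, tile_shape=(512, 512), skip_smaller_shape=True, crops_per_input=2,
#   )
#   # Returns one (512, 512) crop, since only one of the four tiles contains foreground.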


def _save_image_label_crops(image_crops, gt_crops, dataset_choice, ndim, image_counter, save_image_dir, save_gt_dir):
    for image_crop, gt_crop in tqdm(
        zip(image_crops, gt_crops), total=len(image_crops), desc=f"Saving {ndim}d crops for {dataset_choice}"
    ):
        fname = f"{dataset_choice}_{image_counter:05}.tif"
        assert image_crop.shape == gt_crop.shape
        imageio.imwrite(os.path.join(save_image_dir, fname), image_crop, compression="zlib")
        imageio.imwrite(os.path.join(save_gt_dir, fname), gt_crop, compression="zlib")
        image_counter += 1

    return image_counter


def _get_image_label_paths(path, ndim):
    image_paths = natsorted(glob(os.path.join(path, f"roi_{ndim}d", "inputs", "*")))
    gt_paths = natsorted(glob(os.path.join(path, f"roi_{ndim}d", "labels", "*")))
    return image_paths, gt_paths


def _run_automatic_segmentation_per_dataset(
    image_paths: List[Union[os.PathLike, str]],
    gt_paths: List[Union[os.PathLike, str]],
    model_type: str,
    output_folder: Union[os.PathLike, str],
    ndim: Optional[int] = None,
    device: Optional[Union[torch.device, str]] = None,
    checkpoint_path: Optional[Union[os.PathLike, str]] = None,
    run_amg: bool = False,
    **auto_seg_kwargs
):
    """Runs automatic segmentation for multiple input files at once
    and stores the quantitative evaluation of the results.

    Args:
        image_paths: List of filepaths for the input image data.
        gt_paths: List of filepaths for the corresponding label data.
        model_type: The choice of image encoder for the Segment Anything model.
        output_folder: Filepath to the folder where we store all the results.
        ndim: The number of input dimensions.
        device: The torch device.
        checkpoint_path: The filepath where the model checkpoint is stored.
        run_amg: Whether to run automatic segmentation in AMG mode.
        auto_seg_kwargs: Additional arguments for the automatic segmentation parameters.
    """
    experiment_name = "AMG" if run_amg else "AIS"
    fname = f"{experiment_name.lower()}_{ndim}d"

    result_path = os.path.join(output_folder, "results", f"{fname}.csv")
    prediction_dir = os.path.join(output_folder, fname, "inference")
    if os.path.exists(prediction_dir):
        return

    os.makedirs(prediction_dir, exist_ok=True)

    # Get the predictor (and the additional instance segmentation decoder, if available).
    predictor, segmenter = get_predictor_and_segmenter(
        model_type=model_type, checkpoint=checkpoint_path, device=device, amg=run_amg, is_tiled=False,
    )

    for image_path in tqdm(image_paths, desc=f"Run {experiment_name} in {ndim}d"):
        output_path = os.path.join(prediction_dir, os.path.basename(image_path))
        if os.path.exists(output_path):
            continue

        # Run automatic segmentation (AMG or AIS).
        automatic_instance_segmentation(
            predictor=predictor,
            segmenter=segmenter,
            input_path=image_path,
            output_path=output_path,
            ndim=ndim,
            verbose=False,
            **auto_seg_kwargs
        )

    prediction_paths = natsorted(glob(os.path.join(prediction_dir, "*")))
    run_evaluation(gt_paths=gt_paths, prediction_paths=prediction_paths, save_path=result_path)
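
# For reference, the automatic segmentation used above can also be run stand-alone on a single
# image (a sketch; the model name and file paths are assumptions for illustration):
#
#   predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm", amg=False)
#   automatic_instance_segmentation(
#       predictor=predictor, segmenter=segmenter,
#       input_path="image.tif", output_path="segmentation.tif", ndim=2,
#   )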


def _run_interactive_segmentation_per_dataset(
    image_paths: List[Union[os.PathLike, str]],
    gt_paths: List[Union[os.PathLike, str]],
    output_folder: Union[os.PathLike, str],
    model_type: str,
    prompt_choice: Literal["box", "points"],
    device: Optional[Union[torch.device, str]] = None,
    ndim: Optional[int] = None,
    checkpoint_path: Optional[Union[os.PathLike, str]] = None,
):
    """Runs interactive segmentation for multiple input files at once
    and stores the quantitative evaluation of the results.

    Args:
        image_paths: List of filepaths for the input image data.
        gt_paths: List of filepaths for the corresponding label data.
        output_folder: Filepath to the folder where we store all the results.
        model_type: The choice of model type for Segment Anything.
        prompt_choice: The choice of initial prompts with which to begin the interactive segmentation.
        device: The torch device.
        ndim: The number of input dimensions.
        checkpoint_path: The filepath for the stored checkpoint.
    """
    if ndim == 2:
        # Get the Segment Anything predictor.
        predictor = util.get_sam_model(model_type=model_type, device=device, checkpoint_path=checkpoint_path)

        # Run interactive instance segmentation
        # (starting with box or point prompts, followed by iterative prompt-based correction).
        run_inference_with_iterative_prompting(
            predictor=predictor,
            image_paths=image_paths,
            gt_paths=gt_paths,
            embedding_dir=None,  # We set this to None to compute embeddings on-the-fly.
            prediction_dir=os.path.join(output_folder, "interactive_segmentation_2d", f"start_with_{prompt_choice}"),
            start_with_box_prompt=(prompt_choice == "box"),
            # TODO: add a parameter to deform the box prompts (to simulate prompts in practice).
        )

        # Evaluate the interactive instance segmentation.
        run_evaluation_for_iterative_prompting(
            gt_paths=gt_paths,
            prediction_root=os.path.join(output_folder, "interactive_segmentation_2d", f"start_with_{prompt_choice}"),
            experiment_folder=output_folder,
            start_with_box_prompt=(prompt_choice == "box"),
        )

    else:
        save_path = os.path.join(output_folder, "results", f"interactive_segmentation_3d_with_{prompt_choice}.csv")
        if os.path.exists(save_path):
            print(
                f"Results for 3d interactive segmentation with '{prompt_choice}' are already stored at '{save_path}'."
            )
            return

        results = []
        for image_path, gt_path in tqdm(
            zip(image_paths, gt_paths), total=len(image_paths),
            desc=f"Run interactive segmentation in 3d with '{prompt_choice}'"
        ):
            prediction_dir = os.path.join(output_folder, "interactive_segmentation_3d", f"{prompt_choice}")
            os.makedirs(prediction_dir, exist_ok=True)

            prediction_path = os.path.join(prediction_dir, os.path.basename(image_path))
            if os.path.exists(prediction_path):
                continue

            per_vol_result = segment_slices_from_ground_truth(
                volume=imageio.imread(image_path),
                ground_truth=imageio.imread(gt_path),
                model_type=model_type,
                checkpoint_path=checkpoint_path,
                save_path=prediction_path,
                device=device,
                interactive_seg_mode=prompt_choice,
                min_size=10,
            )
            results.append(per_vol_result)

        # Each per-volume result is a single-row dataframe of metrics;
        # averaging over the index yields the mean of each metric across all volumes.
        results = pd.concat(results)
        results = results.groupby(results.index).mean()
        results.to_csv(save_path)


def _run_benchmark_evaluation_series(
    image_paths, gt_paths, model_type, output_folder, ndim, device, checkpoint_path, run_amg,
):
    seg_kwargs = {
        "image_paths": image_paths,
        "gt_paths": gt_paths,
        "output_folder": output_folder,
        "ndim": ndim,
        "model_type": model_type,
        "device": device,
        "checkpoint_path": checkpoint_path,
    }

    # Perform:
    # a. Automatic segmentation (supported in both 2d and 3d, wherever relevant).
    #    The steps below are configured so that AIS takes priority (if an instance segmentation
    #    decoder is found) and AMG is used otherwise. In addition, AMG is also run explicitly
    #    if the user requests it (i.e. after the run for AIS).

    # i. Run the automatic segmentation method supported by the SAM model (AMG or AIS).
    _run_automatic_segmentation_per_dataset(run_amg=False, **seg_kwargs)

    # ii. Run automatic mask generation (AMG), in case the first run used AIS.
    if run_amg:
        _run_automatic_segmentation_per_dataset(run_amg=True, **seg_kwargs)

    # b. Run interactive segmentation (supported in both 2d and 3d, wherever relevant).
    _run_interactive_segmentation_per_dataset(prompt_choice="box", **seg_kwargs)
    _run_interactive_segmentation_per_dataset(prompt_choice="points", **seg_kwargs)


def _clear_cached_items(retain, path, output_folder):
    import shutil
    from pathlib import Path

    REMOVE_LIST = ["data", "crops", "auto", "int"]
    if retain is None:
        remove_list = REMOVE_LIST
    else:
        assert isinstance(retain, list)
        remove_list = set(REMOVE_LIST) - set(retain)

    paths = []
    # Stage 1: Remove inputs.
    if "data" in remove_list or "crops" in remove_list:
        all_paths = glob(os.path.join(path, "*"))

        # In case we want to remove both the data and the crops, we remove the data folder entirely.
        if "data" in remove_list and "crops" in remove_list:
            paths.extend(all_paths)
        else:
            # Otherwise, we remove only the extracted crops or only the original data.
            for curr_path in all_paths:
                is_crop = os.path.basename(curr_path).startswith("roi")
                if is_crop and "crops" in remove_list:
                    paths.append(curr_path)
                elif not is_crop and "data" in remove_list:
                    paths.append(curr_path)

    # Stage 2: Remove predictions.
    if "auto" in remove_list:
        paths.extend(glob(os.path.join(output_folder, "amg_*")))
        paths.extend(glob(os.path.join(output_folder, "ais_*")))

    if "int" in remove_list:
        paths.extend(glob(os.path.join(output_folder, "interactive_segmentation_*")))

    for _path in paths:
        if Path(_path).is_dir():
            shutil.rmtree(_path)
        else:
            os.remove(_path)
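
# For example, retain=["auto", "int"] keeps all segmentation predictions but removes the
# downloaded data and the extracted crops (a sketch with hypothetical paths):
#
#   _clear_cached_items(
#       retain=["auto", "int"], path="./benchmark_data/livecell",
#       output_folder="./benchmark_results/livecell",
#   )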


def run_benchmark_evaluations(
    input_folder: Union[os.PathLike, str],
    dataset_choice: str,
    model_type: str = util._DEFAULT_MODEL,
    output_folder: Optional[Union[str, os.PathLike]] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    run_amg: bool = False,
    retain: Optional[List[str]] = None,
    ignore_warnings: bool = False,
):
    """Run evaluation for benchmarking Segment Anything models on microscopy datasets.

    Args:
        input_folder: The path to the directory where all inputs will be stored and preprocessed.
        dataset_choice: The dataset choice.
        model_type: The model choice for SAM.
        output_folder: The path to the directory where all outputs will be stored.
        checkpoint_path: The filepath to the model checkpoint.
        run_amg: Whether to run automatic segmentation in AMG mode.
        retain: Which parts of the benchmark runs to retain.
            By default, removes everything besides the quantitative results.
            There is the choice to retain 'data', 'crops', 'auto', or 'int'.
        ignore_warnings: Whether to ignore warnings.
    """
    start = time.time()

    with _filter_warnings(ignore_warnings):
        device = util._get_default_device()

        # Ensure that all chosen datasets have been downloaded.
        dataset_choice = _download_benchmark_datasets(path=input_folder, dataset_choice=dataset_choice)

        for choice in dataset_choice:
            # NOTE: We use a per-dataset output folder, so that the outputs of multiple datasets do not nest.
            curr_output_folder = os.path.join(output_folder, choice)
            result_dir = os.path.join(curr_output_folder, "results")
            if os.path.exists(result_dir):
                continue

            os.makedirs(result_dir, exist_ok=True)

            data_path = os.path.join(input_folder, choice)

            # Extract the desired crops from the datasets:
            # a. for 2d datasets - 2d patches with the highest number of labels present
            #    (in the case of volumetric data, choose 2d patches per slice).
            # b. for 3d datasets - 3d regions of interest with the highest number of labels present.
            ndim = _extract_slices_from_dataset(path=data_path, dataset_choice=choice, crops_per_input=10)

            # Run the inference and evaluation scripts on the benchmark datasets.
            image_paths, gt_paths = _get_image_label_paths(path=data_path, ndim=ndim)
            _run_benchmark_evaluation_series(
                image_paths, gt_paths, model_type, curr_output_folder, ndim, device, checkpoint_path, run_amg
            )

            # Run the inference and evaluation scripts on the '2d' crops for volumetric datasets.
            if ndim == 3:
                image_paths, gt_paths = _get_image_label_paths(path=data_path, ndim=2)
                _run_benchmark_evaluation_series(
                    image_paths, gt_paths, model_type, curr_output_folder, 2, device, checkpoint_path, run_amg
                )

            _clear_cached_items(retain=retain, path=data_path, output_folder=curr_output_folder)

    diff = time.time() - start
    hours, rest = divmod(diff, 3600)
    minutes, seconds = divmod(rest, 60)
    print("Time taken for running benchmarks: ", f"{int(hours)}h {int(minutes)}m {seconds:.2f}s")
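
# A minimal usage sketch from Python (the paths, dataset and model names below are assumptions
# for illustration):
#
#   run_benchmark_evaluations(
#       input_folder="./benchmark_data",
#       dataset_choice="livecell",
#       model_type="vit_b",
#       output_folder="./benchmark_results",
#       retain=["auto"],
#       ignore_warnings=True,
#   )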


def main():
    """@private"""
    import argparse

    available_models = list(util.get_model_names())
    available_models = ", ".join(available_models)

    parser = argparse.ArgumentParser(
        description="Run evaluation for benchmarking Segment Anything models on microscopy datasets."
    )
    parser.add_argument(
        "-i", "--input_folder", type=str, required=True,
        help="The path to a directory where the microscopy datasets are / will be stored."
    )
    parser.add_argument(
        "-m", "--model_type", type=str, default=util._DEFAULT_MODEL,
        help=f"The Segment Anything model that will be used, one of {available_models}."
    )
    parser.add_argument(
        "-c", "--checkpoint_path", type=str, default=None,
        help="The checkpoint from which the SAM model will be loaded."
    )
    parser.add_argument(
        "-d", "--dataset_choice", type=str, nargs='*', default=None,
        help="The choice(s) of dataset for evaluating SAM models. Multiple datasets can be specified."
    )
    parser.add_argument(
        "-o", "--output_folder", type=str, required=True,
        help="The path where the results for automatic and interactive instance segmentation will be stored as 'csv'."
    )
    parser.add_argument(
        "--amg", action="store_true",
        help="Whether to run automatic segmentation in AMG mode (i.e. the default auto-segmentation approach for SAM)."
    )
    parser.add_argument(
        "--retain", nargs="*", default=None,
        help="By default, this functionality removes everything besides the quantitative results required for "
        "the benchmarks. In case you would like to retain parts of the benchmark evaluation for visualization "
        "or reproducibility, choose one or multiple of 'data', 'crops', 'auto', 'int', which retain the "
        "original inputs, the extracted crops, the predictions of automatic segmentation, or the predictions "
        "of interactive segmentation, respectively."
    )
    args = parser.parse_args()

    run_benchmark_evaluations(
        input_folder=args.input_folder,
        dataset_choice=args.dataset_choice,
        model_type=args.model_type,
        output_folder=args.output_folder,
        checkpoint_path=args.checkpoint_path,
        run_amg=args.amg,
        retain=args.retain,
        ignore_warnings=True,
    )


if __name__ == "__main__":
    main()
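
# Example command line call (a sketch; the dataset and model names are assumptions):
#   python -m micro_sam.evaluation.benchmark_datasets \
#       -i ./benchmark_data -o ./benchmark_results -d livecell -m vit_b --retain auto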