micro_sam.evaluation.model_comparison

Functionality for qualitative comparison of Segment Anything models on microscopy data.

  1"""Functionality for qualitative comparison of Segment Anything models on microscopy data.
  2"""
  3
  4import os
  5from glob import glob
  6from tqdm import tqdm
  7from pathlib import Path
  8from functools import partial
  9from typing import Optional, Union, Dict, Any
 10
 11import h5py
 12import numpy as np
 13import pandas as pd
 14import matplotlib.pyplot as plt
 15import skimage.draw as draw
 16from skimage import exposure
 17from scipy.ndimage import binary_dilation
 18from skimage.segmentation import relabel_sequential, find_boundaries
 19
 20import torch
 21
 22from .. import util
 23from ..prompt_generators import PointAndBoxPromptGenerator
 24from ..prompt_based_segmentation import segment_from_box, segment_from_points
 25
 26
 27#
 28# Compute all required data for the model comparison
 29#
 30
 31
def _predict_models_with_loader(loader, n_samples, prompt_generator, predictor1, predictor2, predictor3, output_folder):
    i = 0
    os.makedirs(output_folder, exist_ok=True)

    for x, y in tqdm(loader, total=n_samples):
        out_path = os.path.join(output_folder, f"sample_{i}.h5")

        im = x.numpy().squeeze()
        if im.ndim == 3 and im.shape[0] == 3:
            im = im.transpose((1, 2, 0))

        gt = y.numpy().squeeze().astype("uint32")
        gt = relabel_sequential(gt)[0]

        emb1 = util.precompute_image_embeddings(predictor1, im, ndim=2)
        util.set_precomputed(predictor1, emb1)

        emb2 = util.precompute_image_embeddings(predictor2, im, ndim=2)
        util.set_precomputed(predictor2, emb2)

        if predictor3 is not None:
            emb3 = util.precompute_image_embeddings(predictor3, im, ndim=2)
            util.set_precomputed(predictor3, emb3)

        with h5py.File(out_path, "a") as f:
            f.create_dataset("image", data=im, compression="gzip")

        gt_ids = np.unique(gt)[1:]
        centers, boxes = util.get_centers_and_bounding_boxes(gt)
        centers = [centers[gt_id] for gt_id in gt_ids]
        boxes = [boxes[gt_id] for gt_id in gt_ids]

        object_masks = util.segmentation_to_one_hot(gt.astype("int64"), gt_ids)
        coords, labels, boxes, _ = prompt_generator(
            segmentation=object_masks,
            bbox_coordinates=boxes,
            center_coordinates=centers,
        )

        for idx, gt_id in tqdm(enumerate(gt_ids), total=len(gt_ids)):

            # Box prompts:
            # Reorder the coordinates so that they match the normal Python convention.
            box = boxes[idx][[1, 0, 3, 2]]
            mask1_box = segment_from_box(predictor1, box)
            mask2_box = segment_from_box(predictor2, box)
            mask1_box, mask2_box = mask1_box.squeeze(), mask2_box.squeeze()

            if predictor3 is not None:
                mask3_box = segment_from_box(predictor3, box)
                mask3_box = mask3_box.squeeze()

            # Point prompts:
            # Reorder the coordinates so that they match the normal Python convention.
            point_coords, point_labels = np.array(coords[idx])[:, ::-1], np.array(labels[idx])
            mask1_points = segment_from_points(predictor1, point_coords, point_labels)
            mask2_points = segment_from_points(predictor2, point_coords, point_labels)
            mask1_points, mask2_points = mask1_points.squeeze(), mask2_points.squeeze()

            if predictor3 is not None:
                mask3_points = segment_from_points(predictor3, point_coords, point_labels)
                mask3_points = mask3_points.squeeze()

            gt_mask = gt == gt_id
            with h5py.File(out_path, "a") as f:
                g = f.create_group(str(gt_id))
                g.attrs["point_coords"] = point_coords
                g.attrs["point_labels"] = point_labels
                g.attrs["box"] = box

                g.create_dataset("gt_mask", data=gt_mask, compression="gzip")
                g.create_dataset("box/mask1", data=mask1_box.astype("uint8"), compression="gzip")
                g.create_dataset("box/mask2", data=mask2_box.astype("uint8"), compression="gzip")
                g.create_dataset("points/mask1", data=mask1_points.astype("uint8"), compression="gzip")
                g.create_dataset("points/mask2", data=mask2_points.astype("uint8"), compression="gzip")

                if predictor3 is not None:
                    g.create_dataset("box/mask3", data=mask3_box.astype("uint8"), compression="gzip")
                    g.create_dataset("points/mask3", data=mask3_points.astype("uint8"), compression="gzip")

        i += 1
        if i >= n_samples:
            return

def generate_data_for_model_comparison(
    loader: torch.utils.data.DataLoader,
    output_folder: Union[str, os.PathLike],
    model_type1: str,
    model_type2: str,
    n_samples: int,
    model_type3: Optional[str] = None,
    checkpoint1: Optional[Union[str, os.PathLike]] = None,
    checkpoint2: Optional[Union[str, os.PathLike]] = None,
    checkpoint3: Optional[Union[str, os.PathLike]] = None,
    peft_kwargs1: Optional[Dict[str, Any]] = None,
    peft_kwargs2: Optional[Dict[str, Any]] = None,
    peft_kwargs3: Optional[Dict[str, Any]] = None,
) -> None:
    """Generate samples for qualitative model comparison.

    This precomputes the input for `model_comparison` and `model_comparison_with_napari`.

    Args:
        loader: The torch dataloader from which samples are drawn.
        output_folder: The folder where the samples will be saved.
        model_type1: The first model to use for comparison.
            The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.
        model_type2: The second model to use for comparison.
            The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.
        n_samples: The number of samples to draw from the dataloader.
        model_type3: The optional third model to use for comparison.
            The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.
        checkpoint1: Optional checkpoint for the first model.
        checkpoint2: Optional checkpoint for the second model.
        checkpoint3: Optional checkpoint for the third model.
        peft_kwargs1: Optional keyword arguments for parameter-efficient finetuning of the first model.
        peft_kwargs2: Optional keyword arguments for parameter-efficient finetuning of the second model.
        peft_kwargs3: Optional keyword arguments for parameter-efficient finetuning of the third model.
    """
    prompt_generator = PointAndBoxPromptGenerator(
        n_positive_points=1,
        n_negative_points=0,
        dilation_strength=3,
        get_point_prompts=True,
        get_box_prompts=True,
    )

    predictor1 = util.get_sam_model(model_type=model_type1, checkpoint_path=checkpoint1, peft_kwargs=peft_kwargs1)
    predictor2 = util.get_sam_model(model_type=model_type2, checkpoint_path=checkpoint2, peft_kwargs=peft_kwargs2)

    if model_type3 is not None:
        predictor3 = util.get_sam_model(model_type=model_type3, checkpoint_path=checkpoint3, peft_kwargs=peft_kwargs3)
    else:
        predictor3 = None

    _predict_models_with_loader(loader, n_samples, prompt_generator, predictor1, predictor2, predictor3, output_folder)

#
# Visual evaluation according to metrics
#


def _evaluate_samples(f, prefix, min_size):
    eval_result = {
        "gt_id": [],
        "score1": [],
        "score2": [],
    }
    for name, group in f.items():
        if name == "image":
            continue

        gt_mask = group["gt_mask"][:]

        size = gt_mask.sum()
        if size < min_size:
            continue

        m1 = group[f"{prefix}/mask1"][:]
        m2 = group[f"{prefix}/mask2"][:]

        score1 = util.compute_iou(gt_mask, m1)
        score2 = util.compute_iou(gt_mask, m2)

        eval_result["gt_id"].append(name)
        eval_result["score1"].append(score1)
        eval_result["score2"].append(score2)

    eval_result = pd.DataFrame.from_dict(eval_result)
    eval_result["advantage1"] = eval_result["score1"] - eval_result["score2"]
    eval_result["advantage2"] = eval_result["score2"] - eval_result["score1"]
    return eval_result

def _overlay_mask(image, mask, alpha=0.6):
    assert image.ndim in (2, 3)
    # Overlay the mask in red on top of the (gray-scale) image.
    if image.ndim == 2:
        overlay = np.stack([image, image, image]).transpose((1, 2, 0))
    else:
        overlay = image
    assert overlay.shape[-1] == 3
    mask_overlay = np.zeros_like(overlay)
    mask_overlay[mask == 1] = [255, 0, 0]
    overlay = alpha * overlay + (1.0 - alpha) * mask_overlay
    return overlay.astype("uint8")


def _enhance_image(im, do_norm=True):
    # Apply CLAHE to improve the image quality.
    if do_norm:
        im -= im.min(axis=(0, 1), keepdims=True)
        im /= (im.max(axis=(0, 1), keepdims=True) + 1e-6)
    im = exposure.equalize_adapthist(im)
    im *= 255
    return im

def _overlay_outline(im, mask, outline_dilation):
    outline = find_boundaries(mask)
    if outline_dilation > 0:
        outline = binary_dilation(outline, iterations=outline_dilation)
    overlay = im.copy()
    overlay[outline] = [255, 255, 0]
    return overlay


def _overlay_box(im, prompt, outline_dilation):
    start, end = prompt
    rr, cc = draw.rectangle_perimeter(start, end=end, shape=im.shape[:2])

    box_outline = np.zeros(im.shape[:2], dtype="bool")
    box_outline[rr, cc] = 1
    if outline_dilation > 0:
        box_outline = binary_dilation(box_outline, iterations=outline_dilation)

    overlay = im.copy()
    overlay[box_outline] = [0, 255, 255]

    return overlay


# NOTE: we currently only support a single point
def _overlay_points(im, prompt, radius):
    coords, labels = prompt
    # Make sure we have a single positive prompt; other options are currently not supported.
    assert coords.shape[0] == labels.shape[0] == 1
    assert labels[0] == 1

    rr, cc = draw.disk(coords[0], radius, shape=im.shape[:2])
    overlay = im.copy()
    draw.set_color(overlay, (rr, cc), [0, 255, 255], alpha=1.0)

    return overlay

def _compare_eval(
    f, eval_result, advantage_column, n_images_per_sample, prefix, sample_name,
    plot_folder, point_radius, outline_dilation, have_model3, enhance_image,
):
    result = eval_result.sort_values(advantage_column, ascending=False).iloc[:n_images_per_sample]
    n_rows = result.shape[0]

    image = f["image"][:]
    is_box_prompt = prefix == "box"
    overlay_prompts = (
        partial(_overlay_box, outline_dilation=outline_dilation) if is_box_prompt
        else partial(_overlay_points, radius=point_radius)
    )

    def make_square(bb, shape):
        box_shape = [b.stop - b.start for b in bb]
        longest_side = max(box_shape)
        padding = [(longest_side - sh) // 2 for sh in box_shape]
        bb = tuple(
            slice(max(b.start - pad, 0), min(b.stop + pad, sh)) for b, pad, sh in zip(bb, padding, shape)
        )
        return bb

    def plot_ax(axis, i, row):
        g = f[row.gt_id]

        gt = g["gt_mask"][:]
        mask1 = g[f"{prefix}/mask1"][:]
        mask2 = g[f"{prefix}/mask2"][:]

        # mask3 is only used for comparison; we plot its crops as they are.
        if have_model3:
            mask3 = g[f"{prefix}/mask3"][:]

        fg_mask = (gt + mask1 + mask2) > 0
        # If this is a box prompt we dilate the mask so that the bounding box can be seen.
        if is_box_prompt:
            fg_mask = binary_dilation(fg_mask, iterations=5)
        bb = np.where(fg_mask)
        bb = tuple(
            slice(int(b.min()), int(b.max() + 1)) for b in bb
        )
        bb = make_square(bb, fg_mask.shape)

        offset = np.array([b.start for b in bb])
        if is_box_prompt:
            prompt = g.attrs["box"]
            prompt = np.array(
                [prompt[:2], prompt[2:]]
            ) - offset
        else:
            prompt = (g.attrs["point_coords"] - offset, g.attrs["point_labels"])

        if enhance_image:
            im = _enhance_image(image[bb])
        else:
            im = image[bb]

        gt, mask1, mask2 = gt[bb], mask1[bb], mask2[bb]

        if have_model3:
            mask3 = mask3[bb]

        im1 = _overlay_mask(im, mask1)
        im1 = _overlay_outline(im1, gt, outline_dilation)
        im1 = overlay_prompts(im1, prompt)
        ax = axis[0] if i is None else axis[i, 0]
        ax.axis("off")
        ax.imshow(im1)

        # We put the third comparison in between, so that the panels read
        # default, generalist, specialist from left to right.
        if have_model3:
            im3 = _overlay_mask(im, mask3)
            im3 = _overlay_outline(im3, gt, outline_dilation)
            im3 = overlay_prompts(im3, prompt)
            ax = axis[1] if i is None else axis[i, 1]
            ax.axis("off")
            ax.imshow(im3)

            next_ax = 2
        else:
            next_ax = 1

        im2 = _overlay_mask(im, mask2)
        im2 = _overlay_outline(im2, gt, outline_dilation)
        im2 = overlay_prompts(im2, prompt)
        ax = axis[next_ax] if i is None else axis[i, next_ax]
        ax.axis("off")
        ax.imshow(im2)

    cols = 3 if have_model3 else 2
    if plot_folder is None:
        fig, axis = plt.subplots(n_rows, cols)
        for i, (_, row) in enumerate(result.iterrows()):
            plot_ax(axis, i, row)
        plt.show()
    else:
        for i, (_, row) in enumerate(result.iterrows()):
            fig, axis = plt.subplots(1, cols)
            plot_ax(axis, None, row)
            plt.subplots_adjust(wspace=0.05, hspace=0)
            plt.savefig(os.path.join(plot_folder, f"{sample_name}_{i}.svg"), bbox_inches="tight")
            plt.close()

def _compare_prompts(
    f, prefix, n_images_per_sample, min_size, sample_name, plot_folder,
    point_radius, outline_dilation, have_model3, enhance_image,
):
    box_eval = _evaluate_samples(f, prefix, min_size)
    if plot_folder is None:
        plot_folder1, plot_folder2 = None, None
    else:
        plot_folder1 = os.path.join(plot_folder, "advantage1")
        plot_folder2 = os.path.join(plot_folder, "advantage2")
        os.makedirs(plot_folder1, exist_ok=True)
        os.makedirs(plot_folder2, exist_ok=True)
    _compare_eval(
        f, box_eval, "advantage1", n_images_per_sample, prefix, sample_name, plot_folder1,
        point_radius, outline_dilation, have_model3, enhance_image,
    )
    _compare_eval(
        f, box_eval, "advantage2", n_images_per_sample, prefix, sample_name, plot_folder2,
        point_radius, outline_dilation, have_model3, enhance_image,
    )


def _compare_models(
    path, n_images_per_sample, min_size, plot_folder, point_radius, outline_dilation, have_model3, enhance_image,
):
    sample_name = Path(path).stem
    with h5py.File(path, "r") as f:
        if plot_folder is None:
            plot_folder_points, plot_folder_box = None, None
        else:
            plot_folder_points = os.path.join(plot_folder, "points")
            plot_folder_box = os.path.join(plot_folder, "box")
        _compare_prompts(
            f, "points", n_images_per_sample, min_size, sample_name, plot_folder_points,
            point_radius, outline_dilation, have_model3, enhance_image,
        )
        _compare_prompts(
            f, "box", n_images_per_sample, min_size, sample_name, plot_folder_box,
            point_radius, outline_dilation, have_model3, enhance_image,
        )

def model_comparison(
    output_folder: Union[str, os.PathLike],
    n_images_per_sample: int,
    min_size: int,
    plot_folder: Optional[Union[str, os.PathLike]] = None,
    point_radius: int = 4,
    outline_dilation: int = 0,
    have_model3: bool = False,
    enhance_image: bool = True,
) -> None:
    """Create images for a qualitative model comparison.

    Args:
        output_folder: The folder with the data precomputed by `generate_data_for_model_comparison`.
        n_images_per_sample: The number of images to generate per precomputed sample.
        min_size: The minimal size of ground-truth objects to take into account.
        plot_folder: The folder where to save the plots. If not given the plots will be displayed.
        point_radius: The radius of the point overlay.
        outline_dilation: The dilation factor of the outline overlay.
        have_model3: Whether the data was precomputed with a third model.
        enhance_image: Whether to enhance the input image.
    """
    files = glob(os.path.join(output_folder, "*.h5"))
    for path in tqdm(files):
        _compare_models(
            path, n_images_per_sample, min_size, plot_folder, point_radius,
            outline_dilation, have_model3, enhance_image,
        )

#
# Quick visual evaluation with napari
#


def _check_group(g, show_points):
    import napari

    # The image dataset is stored at the file root, not inside the per-object group.
    image = g.file["image"][:]
    gt = g["gt_mask"][:]
    if show_points:
        m1 = g["points/mask1"][:]
        m2 = g["points/mask2"][:]
        points = g.attrs["point_coords"]
    else:
        m1 = g["box/mask1"][:]
        m2 = g["box/mask2"][:]
        box = g.attrs["box"]
        box = np.array([
            [box[0], box[1]], [box[2], box[3]]
        ])

    v = napari.Viewer()
    v.add_image(image)
    v.add_labels(gt)
    v.add_labels(m1)
    v.add_labels(m2)
    if show_points:
        # TODO use point labels for coloring
        v.add_points(
            points,
            edge_color="#00FF00",
            symbol="o",
            face_color="transparent",
            edge_width=0.5,
            size=12,
        )
    else:
        v.add_shapes(
            box, face_color="transparent", edge_color="green", edge_width=4,
        )
    napari.run()

def model_comparison_with_napari(output_folder: Union[str, os.PathLike], show_points: bool = True) -> None:
    """Use napari to display the qualitative comparison results for two models.

    Args:
        output_folder: The folder with the data precomputed by `generate_data_for_model_comparison`.
        show_points: Whether to show the results for point or for box prompts.
    """
    files = glob(os.path.join(output_folder, "*.h5"))
    for path in files:
        print("Comparing models in", path)
        with h5py.File(path, "r") as f:
            for name, g in f.items():
                if name == "image":
                    continue
                _check_group(g, show_points=show_points)
def generate_data_for_model_comparison(loader: torch.utils.data.DataLoader, output_folder: Union[str, os.PathLike], model_type1: str, model_type2: str, n_samples: int, model_type3: Optional[str] = None, checkpoint1: Optional[Union[str, os.PathLike]] = None, checkpoint2: Optional[Union[str, os.PathLike]] = None, checkpoint3: Optional[Union[str, os.PathLike]] = None, peft_kwargs1: Optional[Dict[str, Any]] = None, peft_kwargs2: Optional[Dict[str, Any]] = None, peft_kwargs3: Optional[Dict[str, Any]] = None) -> None:

Generate samples for qualitative model comparison.

This precomputes the input for model_comparison and model_comparison_with_napari.

Arguments:
  • loader: The torch dataloader from which samples are drawn.
  • output_folder: The folder where the samples will be saved.
  • model_type1: The first model to use for comparison. The value needs to be a valid model_type for micro_sam.util.get_sam_model.
  • model_type2: The second model to use for comparison. The value needs to be a valid model_type for micro_sam.util.get_sam_model.
  • n_samples: The number of samples to draw from the dataloader.
  • model_type3: The optional third model to use for comparison. The value needs to be a valid model_type for micro_sam.util.get_sam_model.
  • checkpoint1: Optional checkpoint for the first model.
  • checkpoint2: Optional checkpoint for the second model.
  • checkpoint3: Optional checkpoint for the third model.
  • peft_kwargs1: Optional keyword arguments for parameter-efficient finetuning of the first model.
  • peft_kwargs2: Optional keyword arguments for parameter-efficient finetuning of the second model.
  • peft_kwargs3: Optional keyword arguments for parameter-efficient finetuning of the third model.
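A minimal usage sketch; the loader, model types, and output path below are placeholder values, not part of the API:

from micro_sam.evaluation.model_comparison import generate_data_for_model_comparison

# `my_loader` stands in for any torch dataloader that yields (image, label)
# pairs of your microscopy data; the model types and path are example values.
generate_data_for_model_comparison(
    loader=my_loader,
    output_folder="./comparison_data",
    model_type1="vit_b",      # vanilla SAM
    model_type2="vit_b_lm",   # micro_sam light-microscopy generalist
    n_samples=10,
)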
def model_comparison(output_folder: Union[str, os.PathLike], n_images_per_sample: int, min_size: int, plot_folder: Optional[Union[str, os.PathLike]] = None, point_radius: int = 4, outline_dilation: int = 0, have_model3: bool = False, enhance_image: bool = True) -> None:

Create images for a qualitative model comparison.

Arguments:
  • output_folder: The folder with the data precomputed by generate_data_for_model_comparison.
  • n_images_per_sample: The number of images to generate per precomputed sample.
  • min_size: The minimal size of ground-truth objects to take into account.
  • plot_folder: The folder where to save the plots. If not given the plots will be displayed.
  • point_radius: The radius of the point overlay.
  • outline_dilation: The dilation factor of the outline overlay.
  • have_model3: Whether the data was precomputed with a third model.
  • enhance_image: Whether to enhance the input image.
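A minimal usage sketch; the paths are placeholder values, and leaving out plot_folder displays the figures instead of saving them:

from micro_sam.evaluation.model_comparison import model_comparison

# Reads the .h5 samples written by `generate_data_for_model_comparison` and
# saves one comparison figure per selected object as an .svg file.
model_comparison(
    output_folder="./comparison_data",
    n_images_per_sample=5,
    min_size=100,
    plot_folder="./comparison_plots",
)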
def model_comparison_with_napari(output_folder: Union[str, os.PathLike], show_points: bool = True) -> None:

Use napari to display the qualitative comparison results for two models.

Arguments:
  • output_folder: The folder with the data precomputed by generate_data_for_model_comparison.
  • show_points: Whether to show the results for point or for box prompts.
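A minimal usage sketch; the path is a placeholder value. One napari viewer opens per precomputed object; close it to step to the next one:

from micro_sam.evaluation.model_comparison import model_comparison_with_napari

# Steps through all objects in the precomputed .h5 samples.
model_comparison_with_napari("./comparison_data", show_points=True)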