
Functionality for qualitative comparison of Segment Anything models on microscopy data.

  1"""Functionality for qualitative comparison of Segment Anything models on microscopy data.
  4import os
  5from glob import glob
  6from tqdm import tqdm
  7from pathlib import Path
  8from functools import partial
  9from typing import Optional, Union, Dict, Any
 11import h5py
 12import numpy as np
 13import pandas as pd
 14import matplotlib.pyplot as plt
 15import skimage.draw as draw
 16from skimage import exposure
 17from scipy.ndimage import binary_dilation
 18from skimage.segmentation import relabel_sequential, find_boundaries
 20import torch
 22from .. import util
 23from ..prompt_generators import PointAndBoxPromptGenerator
 24from ..prompt_based_segmentation import segment_from_box, segment_from_points
 28# Compute all required data for the model comparison
 32def _predict_models_with_loader(loader, n_samples, prompt_generator, predictor1, predictor2, predictor3, output_folder):
 33    i = 0
 34    os.makedirs(output_folder, exist_ok=True)
 36    for x, y in tqdm(loader, total=n_samples):
 37        out_path = os.path.join(output_folder, f"sample_{i}.h5")
 39        im = x.numpy().squeeze()
 40        if im.ndim == 3 and im.shape[0] == 3:
 41            im = im.transpose((1, 2, 0))
 43        gt = y.numpy().squeeze().astype("uint32")
 44        gt = relabel_sequential(gt)[0]
 46        emb1 = util.precompute_image_embeddings(predictor1, im, ndim=2)
 47        util.set_precomputed(predictor1, emb1)
 49        emb2 = util.precompute_image_embeddings(predictor2, im, ndim=2)
 50        util.set_precomputed(predictor2, emb2)
 52        if predictor3 is not None:
 53            emb3 = util.precompute_image_embeddings(predictor3, im, ndim=2)
 54            util.set_precomputed(predictor3, emb3)
 56        with h5py.File(out_path, "a") as f:
 57            f.create_dataset("image", data=im, compression="gzip")
 59        gt_ids = np.unique(gt)[1:]
 60        centers, boxes = util.get_centers_and_bounding_boxes(gt)
 61        centers = [centers[gt_id] for gt_id in gt_ids]
 62        boxes = [boxes[gt_id] for gt_id in gt_ids]
 64        object_masks = util.segmentation_to_one_hot(gt.astype("int64"), gt_ids)
 65        coords, labels, boxes, _ = prompt_generator(
 66            segmentation=object_masks,
 67            bbox_coordinates=boxes,
 68            center_coordinates=centers,
 69        )
 71        for idx, gt_id in tqdm(enumerate(gt_ids), total=len(gt_ids)):
 73            # Box prompts:
 74            # Reorder the coordinates so that they match the normal python convention.
 75            box = boxes[idx][[1, 0, 3, 2]]
 76            mask1_box = segment_from_box(predictor1, box)
 77            mask2_box = segment_from_box(predictor2, box)
 78            mask1_box, mask2_box = mask1_box.squeeze(), mask2_box.squeeze()
 80            if predictor3 is not None:
 81                mask3_box = segment_from_box(predictor3, box)
 82                mask3_box = mask3_box.squeeze()
 84            # Point prompts:
 85            # Reorder the coordinates so that they match the normal python convention.
 86            point_coords, point_labels = np.array(coords[idx])[:, ::-1], np.array(labels[idx])
 87            mask1_points = segment_from_points(predictor1, point_coords, point_labels)
 88            mask2_points = segment_from_points(predictor2, point_coords, point_labels)
 89            mask1_points, mask2_points = mask1_points.squeeze(), mask2_points.squeeze()
 91            if predictor3 is not None:
 92                mask3_points = segment_from_points(predictor3, point_coords, point_labels)
 93                mask3_points = mask3_points.squeeze()
 95            gt_mask = gt == gt_id
 96            with h5py.File(out_path, "a") as f:
 97                g = f.create_group(str(gt_id))
 98                g.attrs["point_coords"] = point_coords
 99                g.attrs["point_labels"] = point_labels
100                g.attrs["box"] = box
102                g.create_dataset("gt_mask", data=gt_mask, compression="gzip")
103                g.create_dataset("box/mask1", data=mask1_box.astype("uint8"), compression="gzip")
104                g.create_dataset("box/mask2", data=mask2_box.astype("uint8"), compression="gzip")
105                g.create_dataset("points/mask1", data=mask1_points.astype("uint8"), compression="gzip")
106                g.create_dataset("points/mask2", data=mask2_points.astype("uint8"), compression="gzip")
108                if predictor3 is not None:
109                    g.create_dataset("box/mask3", data=mask3_box.astype("uint8"), compression="gzip")
110                    g.create_dataset("points/mask3", data=mask3_points.astype("uint8"), compression="gzip")
112        i += 1
113        if i >= n_samples:
114            return
def generate_data_for_model_comparison(
    loader: torch.utils.data.DataLoader,
    output_folder: Union[str, os.PathLike],
    model_type1: str,
    model_type2: str,
    n_samples: int,
    model_type3: Optional[str] = None,
    checkpoint1: Optional[Union[str, os.PathLike]] = None,
    checkpoint2: Optional[Union[str, os.PathLike]] = None,
    checkpoint3: Optional[Union[str, os.PathLike]] = None,
    peft_kwargs1: Optional[Dict[str, Any]] = None,
    peft_kwargs2: Optional[Dict[str, Any]] = None,
    peft_kwargs3: Optional[Dict[str, Any]] = None,
) -> None:
    """Generate samples for qualitative model comparison.

    This precomputes the input for `model_comparison` and `model_comparison_with_napari`.

    Args:
        loader: The torch dataloader from which samples are drawn.
        output_folder: The folder where the samples will be saved.
        model_type1: The first model to use for comparison.
            The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.
        model_type2: The second model to use for comparison.
            The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.
        n_samples: The number of samples to draw from the dataloader.
        model_type3: The optional third model to use for comparison.
            If it is not given, only the first two models are compared.
        checkpoint1: Optional checkpoint for the first model.
        checkpoint2: Optional checkpoint for the second model.
        checkpoint3: Optional checkpoint for the third model.
        peft_kwargs1: Optional value for the `peft_kwargs` argument of `get_sam_model` for the first model.
        peft_kwargs2: Optional value for the `peft_kwargs` argument of `get_sam_model` for the second model.
        peft_kwargs3: Optional value for the `peft_kwargs` argument of `get_sam_model` for the third model.
    """
    # A single positive point and a box prompt are derived for each object.
    prompt_settings = dict(
        n_positive_points=1,
        n_negative_points=0,
        dilation_strength=3,
        get_point_prompts=True,
        get_box_prompts=True,
    )
    prompt_generator = PointAndBoxPromptGenerator(**prompt_settings)

    predictor1 = util.get_sam_model(model_type=model_type1, checkpoint_path=checkpoint1, peft_kwargs=peft_kwargs1)
    predictor2 = util.get_sam_model(model_type=model_type2, checkpoint_path=checkpoint2, peft_kwargs=peft_kwargs2)

    # The third model is only loaded if it was requested.
    predictor3 = None
    if model_type3 is not None:
        predictor3 = util.get_sam_model(model_type=model_type3, checkpoint_path=checkpoint3, peft_kwargs=peft_kwargs3)

    _predict_models_with_loader(loader, n_samples, prompt_generator, predictor1, predictor2, predictor3, output_folder)
167# Visual evaluation according to metrics
def _evaluate_samples(f, prefix, min_size):
    """Compute the IoU of both models for all objects stored in a sample.

    Args:
        f: Mapping from names to object groups (an opened sample file). The "image" entry is skipped.
        prefix: The name of the prompt type for which the masks are evaluated, e.g. "box" or "points".
        min_size: Objects whose ground-truth mask sums to less than this value are skipped.

    Returns:
        A dataframe with the columns "gt_id", "score1", "score2", "advantage1" and "advantage2",
        where the advantage of a model is its score minus the score of the other model.
    """
    object_names, ious1, ious2 = [], [], []

    for name, group in f.items():
        # The image is stored next to the object groups and is not an object itself.
        if name == "image":
            continue

        gt_mask = group["gt_mask"][:]
        if gt_mask.sum() < min_size:
            continue

        pred1 = group[f"{prefix}/mask1"][:]
        pred2 = group[f"{prefix}/mask2"][:]

        iou1 = util.compute_iou(gt_mask, pred1)
        iou2 = util.compute_iou(gt_mask, pred2)

        object_names.append(name)
        ious1.append(iou1)
        ious2.append(iou2)

    table = pd.DataFrame.from_dict({"gt_id": object_names, "score1": ious1, "score2": ious2})
    table["advantage1"] = table["score1"] - table["score2"]
    table["advantage2"] = table["score2"] - table["score1"]
    return table
def _overlay_mask(image, mask, alpha=0.6):
    """Blend a red mask overlay into the image.

    Args:
        image: The image, either 2D or 3D with three channels in the last axis.
        mask: The mask. Pixels where it equals 1 are marked in red.
        alpha: The weight of the image in the blend; the mask overlay is weighted with `1 - alpha`.

    Returns:
        The blended image with three channels in the last axis, cast to uint8.
    """
    assert image.ndim in (2, 3)
    # Replicate single-channel images so that a colored overlay can be applied.
    rgb = np.stack([image] * 3, axis=-1) if image.ndim == 2 else image
    assert rgb.shape[-1] == 3

    red_layer = np.zeros_like(rgb)
    red_layer[mask == 1] = [255, 0, 0]

    blended = alpha * rgb + (1.0 - alpha) * red_layer
    return blended.astype("uint8")
def _enhance_image(im, do_norm=True):
    """Enhance the image contrast with CLAHE and scale the result by 255.

    Args:
        im: The image to enhance. It is not modified.
        do_norm: Whether to min-max normalize the image over its first two axes
            before applying CLAHE.

    Returns:
        The output of `skimage.exposure.equalize_adapthist` multiplied by 255.
    """
    if do_norm:
        # Normalize on a floating point copy: the in-place operations below would otherwise
        # change the caller's array (crops are passed as views of the full image), and the
        # in-place division fails for integer images.
        float_dtype = im.dtype if np.issubdtype(im.dtype, np.floating) else np.float64
        im = im.astype(float_dtype)
        im -= im.min(axis=(0, 1), keepdims=True)
        im /= (im.max(axis=(0, 1), keepdims=True) + 1e-6)
    # apply CLAHE to improve the image quality
    im = exposure.equalize_adapthist(im)
    im *= 255
    return im
def _overlay_outline(im, mask, outline_dilation):
    """Draw the boundaries of the mask in yellow onto a copy of the image.

    Args:
        im: The image with three channels in the last axis.
        mask: The mask whose boundaries are drawn.
        outline_dilation: The number of dilation iterations applied to the boundaries.
            The boundaries are not dilated if it is zero or negative.

    Returns:
        The copy of the image with the outline.
    """
    boundaries = find_boundaries(mask)
    needs_dilation = outline_dilation > 0
    if needs_dilation:
        boundaries = binary_dilation(boundaries, iterations=outline_dilation)

    result = im.copy()
    result[boundaries] = [255, 255, 0]
    return result
def _overlay_box(im, prompt, outline_dilation):
    """Draw the perimeter of a box prompt in cyan onto a copy of the image.

    Args:
        im: The image with three channels in the last axis.
        prompt: The start and end coordinate of the box.
        outline_dilation: The number of dilation iterations applied to the perimeter.
            The perimeter is not dilated if it is zero or negative.

    Returns:
        The copy of the image with the box.
    """
    box_start, box_end = prompt
    spatial_shape = im.shape[:2]
    rows, cols = draw.rectangle_perimeter(box_start, end=box_end, shape=spatial_shape)

    perimeter = np.zeros(spatial_shape, dtype="bool")
    perimeter[rows, cols] = 1
    if outline_dilation > 0:
        perimeter = binary_dilation(perimeter, iterations=outline_dilation)

    result = im.copy()
    result[perimeter] = [0, 255, 255]
    return result
252# NOTE: we currently only support a single point
def _overlay_points(im, prompt, radius):
    """Draw a point prompt as cyan disk onto a copy of the image.

    Args:
        im: The image with three channels in the last axis.
        prompt: The coordinates and labels of the point prompt.
            Only a single positive point is supported.
        radius: The radius of the disk.

    Returns:
        The copy of the image with the point.
    """
    point_coords, point_labels = prompt
    # Other options than a single positive prompt are currently not supported.
    n_points = point_coords.shape[0]
    assert n_points == point_labels.shape[0] == 1
    assert point_labels[0] == 1

    result = im.copy()
    disk = draw.disk(point_coords[0], radius, shape=im.shape[:2])
    draw.set_color(result, disk, [0, 255, 255], alpha=1.0)
    return result
def _compare_eval(
    f, eval_result, advantage_column, n_images_per_sample, prefix, sample_name,
    plot_folder, point_radius, outline_dilation, have_model3, enhance_image,
):
    """Plot the objects with the largest values in the given advantage column.

    For each selected object the crop around the object is shown once per model, overlaid with
    the predicted mask, the outline of the ground-truth and the prompt. The models are arranged
    as model1, (model3), model2 from left to right.

    Args:
        f: The opened sample file with the image and the object groups.
        eval_result: The dataframe with the evaluation results, as returned by `_evaluate_samples`.
        advantage_column: The name of the column by which the objects are ranked.
        n_images_per_sample: The maximal number of objects that are plotted.
        prefix: The prompt type, either "box" or "points".
        sample_name: The name of the sample, used as prefix for the names of the saved plots.
        plot_folder: The folder for saving the plots. If None, all objects are shown in one figure.
        point_radius: The radius for drawing point prompts.
        outline_dilation: The dilation applied to the ground-truth outline and box prompts.
        have_model3: Whether predictions of a third model are stored and should be plotted.
        enhance_image: Whether to enhance the image crops before plotting.
    """
    result = eval_result.sort_values(advantage_column, ascending=False).iloc[:n_images_per_sample]
    n_rows = result.shape[0]
    # Nothing to plot, for example because no object passed the size filter.
    # Creating a figure with zero rows would raise an error.
    if n_rows == 0:
        return

    image = f["image"][:]
    is_box_prompt = prefix == "box"
    overlay_prompts = partial(_overlay_box, outline_dilation=outline_dilation) if is_box_prompt else\
        partial(_overlay_points, radius=point_radius)

    def make_square(bb, shape):
        # Pad the shorter side of the bounding box on both ends, clipped to the image shape.
        box_shape = [b.stop - b.start for b in bb]
        longest_side = max(box_shape)
        padding = [(longest_side - sh) // 2 for sh in box_shape]
        return tuple(
            slice(max(b.start - pad, 0), min(b.stop + pad, sh)) for b, pad, sh in zip(bb, padding, shape)
        )

    def show_overlay(ax, im, mask, gt, prompt):
        # Overlay the prediction, the ground-truth outline and the prompt and display the result.
        overlay = _overlay_mask(im, mask)
        overlay = _overlay_outline(overlay, gt, outline_dilation)
        overlay = overlay_prompts(overlay, prompt)
        ax.axis("off")
        ax.imshow(overlay)

    def plot_row(axes, row):
        g = f[row.gt_id]

        gt = g["gt_mask"][:]
        mask1 = g[f"{prefix}/mask1"][:]
        mask2 = g[f"{prefix}/mask2"][:]

        # The crop is derived from the ground-truth and the first two models only.
        fg_mask = (gt + mask1 + mask2) > 0
        # if this is a box prompt we dilate the mask so that the bounding box
        # can be seen
        if is_box_prompt:
            fg_mask = binary_dilation(fg_mask, iterations=5)
        bb = tuple(
            slice(int(b.min()), int(b.max() + 1)) for b in np.where(fg_mask)
        )
        bb = make_square(bb, fg_mask.shape)

        # The prompts are stored in image coordinates and have to be shifted into the crop.
        offset = np.array([b.start for b in bb])
        if is_box_prompt:
            box = g.attrs["box"]
            prompt = np.array([box[:2], box[2:]]) - offset
        else:
            prompt = (g.attrs["point_coords"] - offset, g.attrs["point_labels"])

        im = _enhance_image(image[bb]) if enhance_image else image[bb]
        gt = gt[bb]

        # We put the third model in between the other two,
        # so that the comparison looks -> default, generalist, specialist
        masks = [mask1[bb]]
        if have_model3:
            masks.append(g[f"{prefix}/mask3"][:][bb])
        masks.append(mask2[bb])

        for ax, mask in zip(axes, masks):
            show_overlay(ax, im, mask, gt, prompt)

    cols = 3 if have_model3 else 2
    if plot_folder is None:
        # squeeze=False keeps the axes two-dimensional, also for a single row.
        fig, axis = plt.subplots(n_rows, cols, squeeze=False)
        for i, (_, row) in enumerate(result.iterrows()):
            plot_row(axis[i], row)
        plt.show()
    else:
        for i, (_, row) in enumerate(result.iterrows()):
            fig, axis = plt.subplots(1, cols, squeeze=False)
            plot_row(axis[0], row)
            plt.subplots_adjust(wspace=0.05, hspace=0)
            plt.savefig(os.path.join(plot_folder, f"{sample_name}_{i}.svg"), bbox_inches="tight")
            plt.close(fig)
def _compare_prompts(
    f, prefix, n_images_per_sample, min_size, sample_name, plot_folder,
    point_radius, outline_dilation, have_model3, enhance_image,
):
    """Evaluate one prompt type for a sample and plot the objects where either model is ahead.

    The plots for the objects where the first model has the largest advantage are saved in the
    sub-folder "advantage1" of `plot_folder`, the ones where the second model has the largest
    advantage in "advantage2". The plots are displayed instead if `plot_folder` is None.
    """
    scores = _evaluate_samples(f, prefix, min_size)

    advantage_columns = ("advantage1", "advantage2")
    if plot_folder is None:
        folders = {column: None for column in advantage_columns}
    else:
        # The sub-folders are named after the advantage columns.
        folders = {column: os.path.join(plot_folder, column) for column in advantage_columns}
        for folder in folders.values():
            os.makedirs(folder, exist_ok=True)

    for column in advantage_columns:
        _compare_eval(
            f, scores, column, n_images_per_sample, prefix, sample_name, folders[column],
            point_radius, outline_dilation, have_model3, enhance_image,
        )
def _compare_models(
    path, n_images_per_sample, min_size, plot_folder, point_radius, outline_dilation, have_model3, enhance_image,
):
    """Create the comparison plots for the point and the box prompts of a single sample file.

    The plots are saved to the sub-folders "points" and "box" of `plot_folder`,
    or displayed if `plot_folder` is None.
    """
    sample_name = Path(path).stem
    with h5py.File(path, "r") as f:
        for prefix in ("points", "box"):
            prompt_folder = None if plot_folder is None else os.path.join(plot_folder, prefix)
            _compare_prompts(
                f, prefix, n_images_per_sample, min_size, sample_name, prompt_folder,
                point_radius, outline_dilation, have_model3, enhance_image,
            )
def model_comparison(
    output_folder: Union[str, os.PathLike],
    n_images_per_sample: int,
    min_size: int,
    plot_folder: Optional[Union[str, os.PathLike]] = None,
    point_radius: int = 4,
    outline_dilation: int = 0,
    have_model3=False,
    enhance_image=True,
) -> None:
    """Create images for a qualitative model comparison.

    Args:
        output_folder: The folder with the data precomputed by `generate_data_for_model_comparison`.
        n_images_per_sample: The number of images to generate per precomputed sample.
        min_size: The min size of ground-truth objects to take into account.
        plot_folder: The folder where to save the plots. If not given the plots will be displayed.
        point_radius: The radius of the point overlay.
        outline_dilation: The dilation factor of the outline overlay.
        have_model3: Whether the precomputed data contains predictions of a third model,
            which are then plotted in addition.
        enhance_image: Whether to enhance the input image.
    """
    sample_pattern = os.path.join(output_folder, "*.h5")
    for sample_path in tqdm(glob(sample_pattern)):
        _compare_models(
            sample_path, n_images_per_sample, min_size, plot_folder, point_radius,
            outline_dilation, have_model3, enhance_image,
        )
444# Quick visual evaluation with napari
def _check_group(g, show_points):
    """Show the ground-truth, the predictions of the first two models and the prompt in napari.

    Args:
        g: The hdf5 group of a single object, as created by `_predict_models_with_loader`.
        show_points: Whether to show the results for the point prompts (True)
            or for the box prompts (False).
    """
    import napari

    # The image is saved once at the root of the file and not in the object groups,
    # so it is read from the file unless the group provides its own copy.
    image_source = g if "image" in g else g.file
    image = image_source["image"][:]
    gt = g["gt_mask"][:]

    prompt_type = "points" if show_points else "box"
    m1 = g[f"{prompt_type}/mask1"][:]
    m2 = g[f"{prompt_type}/mask2"][:]

    v = napari.Viewer()
    v.add_image(image)
    v.add_labels(gt)
    v.add_labels(m1)
    v.add_labels(m2)
    if show_points:
        # TODO use point labels for coloring
        v.add_points(
            g.attrs["point_coords"],
            edge_color="#00FF00",
            symbol="o",
            face_color="transparent",
            edge_width=0.5,
            size=12,
        )
    else:
        # The box is passed to napari as its two corner points.
        box = g.attrs["box"]
        corners = np.array([
            [box[0], box[1]], [box[2], box[3]]
        ])
        v.add_shapes(
            corners, face_color="transparent", edge_color="green", edge_width=4,
        )
    napari.run()
def model_comparison_with_napari(output_folder: Union[str, os.PathLike], show_points: bool = True) -> None:
    """Use napari to display the qualitative comparison results for two models.

    Args:
        output_folder: The folder with the data precomputed by `generate_data_for_model_comparison`.
        show_points: Whether to show the results for point or for box prompts.
    """
    sample_pattern = os.path.join(output_folder, "*.h5")
    for path in glob(sample_pattern):
        print("Comparing models in", path)
        with h5py.File(path, "r") as f:
            # Everything except the image entry is the group of an object.
            object_groups = (group for name, group in f.items() if name != "image")
            for group in object_groups:
                _check_group(group, show_points=show_points)
Use napari to display the qualitative comparison results for two models.
