micro_sam.evaluation.model_comparison

Functionality for qualitative comparison of Segment Anything models on microscopy data.

  1"""Functionality for qualitative comparison of Segment Anything models on microscopy data.
  2"""
  3
  4import os
  5from glob import glob
  6from tqdm import tqdm
  7from pathlib import Path
  8from functools import partial
  9from typing import Optional, Union
 10
 11import h5py
 12import numpy as np
 13import pandas as pd
 14import matplotlib.pyplot as plt
 15import skimage.draw as draw
 16from skimage import exposure
 17from scipy.ndimage import binary_dilation
 18from skimage.segmentation import relabel_sequential, find_boundaries
 19
 20import torch
 21
 22from .. import util
 23from ..prompt_generators import PointAndBoxPromptGenerator
 24from ..prompt_based_segmentation import segment_from_box, segment_from_points
 25
 26
 27#
 28# Compute all required data for the model comparison
 29#
 30
 31
 32def _predict_models_with_loader(loader, n_samples, prompt_generator, predictor1, predictor2, predictor3, output_folder):
 33    i = 0
 34    os.makedirs(output_folder, exist_ok=True)
 35
 36    for x, y in tqdm(loader, total=n_samples):
 37        out_path = os.path.join(output_folder, f"sample_{i}.h5")
 38
 39        im = x.numpy().squeeze()
 40        if im.ndim == 3 and im.shape[0] == 3:
 41            im = im.transpose((1, 2, 0))
 42
 43        gt = y.numpy().squeeze().astype("uint32")
 44        gt = relabel_sequential(gt)[0]
 45
 46        emb1 = util.precompute_image_embeddings(predictor1, im, ndim=2)
 47        util.set_precomputed(predictor1, emb1)
 48
 49        emb2 = util.precompute_image_embeddings(predictor2, im, ndim=2)
 50        util.set_precomputed(predictor2, emb2)
 51
 52        if predictor3 is not None:
 53            emb3 = util.precompute_image_embeddings(predictor3, im, ndim=2)
 54            util.set_precomputed(predictor3, emb3)
 55
 56        with h5py.File(out_path, "a") as f:
 57            f.create_dataset("image", data=im, compression="gzip")
 58
 59        gt_ids = np.unique(gt)[1:]
 60        centers, boxes = util.get_centers_and_bounding_boxes(gt)
 61        centers = [centers[gt_id] for gt_id in gt_ids]
 62        boxes = [boxes[gt_id] for gt_id in gt_ids]
 63
 64        object_masks = util.segmentation_to_one_hot(gt.astype("int64"), gt_ids)
 65        coords, labels, boxes, _ = prompt_generator(
 66            segmentation=object_masks,
 67            bbox_coordinates=boxes,
 68            center_coordinates=centers,
 69        )
 70
 71        for idx, gt_id in tqdm(enumerate(gt_ids), total=len(gt_ids)):
 72
 73            # Box prompts:
 74            # Reorder the coordinates so that they match the normal python convention.
 75            box = boxes[idx][[1, 0, 3, 2]]
 76            mask1_box = segment_from_box(predictor1, box)
 77            mask2_box = segment_from_box(predictor2, box)
 78            mask1_box, mask2_box = mask1_box.squeeze(), mask2_box.squeeze()
 79
 80            if predictor3 is not None:
 81                mask3_box = segment_from_box(predictor3, box)
 82                mask3_box = mask3_box.squeeze()
 83
 84            # Point prompts:
 85            # Reorder the coordinates so that they match the normal python convention.
 86            point_coords, point_labels = np.array(coords[idx])[:, ::-1], np.array(labels[idx])
 87            mask1_points = segment_from_points(predictor1, point_coords, point_labels)
 88            mask2_points = segment_from_points(predictor2, point_coords, point_labels)
 89            mask1_points, mask2_points = mask1_points.squeeze(), mask2_points.squeeze()
 90
 91            if predictor3 is not None:
 92                mask3_points = segment_from_points(predictor3, point_coords, point_labels)
 93                mask3_points = mask3_points.squeeze()
 94
 95            gt_mask = gt == gt_id
 96            with h5py.File(out_path, "a") as f:
 97                g = f.create_group(str(gt_id))
 98                g.attrs["point_coords"] = point_coords
 99                g.attrs["point_labels"] = point_labels
100                g.attrs["box"] = box
101
102                g.create_dataset("gt_mask", data=gt_mask, compression="gzip")
103                g.create_dataset("box/mask1", data=mask1_box.astype("uint8"), compression="gzip")
104                g.create_dataset("box/mask2", data=mask2_box.astype("uint8"), compression="gzip")
105                g.create_dataset("points/mask1", data=mask1_points.astype("uint8"), compression="gzip")
106                g.create_dataset("points/mask2", data=mask2_points.astype("uint8"), compression="gzip")
107
108                if predictor3 is not None:
109                    g.create_dataset("box/mask3", data=mask3_box.astype("uint8"), compression="gzip")
110                    g.create_dataset("points/mask3", data=mask3_points.astype("uint8"), compression="gzip")
111
112        i += 1
113        if i >= n_samples:
114            return
115
116
def generate_data_for_model_comparison(
    loader: torch.utils.data.DataLoader,
    output_folder: Union[str, os.PathLike],
    model_type1: str,
    model_type2: str,
    n_samples: int,
    model_type3: Optional[str] = None,
    checkpoint1: Optional[Union[str, os.PathLike]] = None,
    checkpoint2: Optional[Union[str, os.PathLike]] = None,
    checkpoint3: Optional[Union[str, os.PathLike]] = None,
) -> None:
    """Generate samples for qualitative model comparison.

    This precomputes the input for `model_comparison` and `model_comparison_with_napari`.
    One file `sample_{i}.h5` is written to `output_folder` per sample. It contains the image
    and, per ground-truth object, the ground-truth mask, the prompts, and the masks each model
    predicts for a single positive point prompt and for a box prompt.

    Args:
        loader: The torch dataloader from which samples are drawn.
        output_folder: The folder where the samples will be saved.
        model_type1: The first model to use for comparison.
            The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.
        model_type2: The second model to use for comparison.
            The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.
        n_samples: The number of samples to draw from the dataloader.
        model_type3: An optional third model to use for comparison.
            The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.
        checkpoint1: Optional checkpoint for the first model.
        checkpoint2: Optional checkpoint for the second model.
        checkpoint3: Optional checkpoint for the third model. Ignored if `model_type3` is not given.
    """
    # One positive point and one box per object; the plotting code only supports a single positive point.
    prompt_generator = PointAndBoxPromptGenerator(
        n_positive_points=1,
        n_negative_points=0,
        dilation_strength=3,
        get_point_prompts=True,
        get_box_prompts=True,
    )
    predictor1 = util.get_sam_model(model_type=model_type1, checkpoint_path=checkpoint1)
    predictor2 = util.get_sam_model(model_type=model_type2, checkpoint_path=checkpoint2)
    predictor3 = None if model_type3 is None else util.get_sam_model(
        model_type=model_type3, checkpoint_path=checkpoint3
    )

    _predict_models_with_loader(loader, n_samples, prompt_generator, predictor1, predictor2, predictor3, output_folder)
159
160
161#
162# Visual evaluation according to metrics
163#
164
165
def _evaluate_samples(f, prefix, min_size):
    """Score the masks of the first two models against the ground-truth for every object.

    Args:
        f: The opened h5 file of one sample; every member except "image" is an object group.
        prefix: The prompt type whose masks are evaluated, "box" or "points".
        min_size: Objects whose ground-truth mask has fewer pixels than this are skipped.

    Returns:
        A DataFrame with the columns "gt_id", "score1", "score2" (IoU of model 1 / model 2),
        "advantage1" (score1 - score2) and "advantage2" (score2 - score1).
    """
    columns = {"gt_id": [], "score1": [], "score2": []}
    object_groups = ((name, group) for name, group in f.items() if name != "image")
    for object_name, object_group in object_groups:
        gt_mask = object_group["gt_mask"][:]
        if gt_mask.sum() < min_size:
            continue
        first = object_group[f"{prefix}/mask1"][:]
        second = object_group[f"{prefix}/mask2"][:]
        columns["gt_id"].append(object_name)
        columns["score1"].append(util.compute_iou(gt_mask, first))
        columns["score2"].append(util.compute_iou(gt_mask, second))

    table = pd.DataFrame.from_dict(columns)
    table["advantage1"] = table["score1"] - table["score2"]
    table["advantage2"] = table["score2"] - table["score1"]
    return table
196
197
def _overlay_mask(image, mask, alpha=0.6, color=(255, 0, 0)):
    """Blend a colored mask into an image.

    Every pixel becomes `alpha * image + (1 - alpha) * layer`. Here `layer` is `color` where
    `mask == 1` and zero elsewhere, so pixels outside the mask are also scaled by `alpha`.

    Args:
        image: A grayscale (2D) or channel-last RGB (3D, 3 channels) image.
        mask: The mask to overlay, with the same spatial shape as `image`.
        alpha: The weight of the image in the blend.
        color: The RGB color of the mask. Defaults to red.

    Returns:
        The blended RGB image as uint8.

    Raises:
        ValueError: If `image` is neither 2D nor a 3-channel channel-last image.
    """
    if image.ndim not in (2, 3):
        raise ValueError(f"Expected a 2D or 3D image, got {image.ndim} dimensions.")
    # Grayscale images are replicated into three identical channels.
    overlay = np.stack([image] * 3, axis=-1) if image.ndim == 2 else image
    if overlay.shape[-1] != 3:
        raise ValueError(f"Expected 3 channels in the last axis, got shape {overlay.shape}.")
    mask_overlay = np.zeros_like(overlay)
    mask_overlay[mask == 1] = color
    overlay = alpha * overlay + (1.0 - alpha) * mask_overlay
    return overlay.astype("uint8")
211
212
def _enhance_image(im, do_norm=True):
    """Contrast-enhance an image crop for display and scale it to roughly [0, 255].

    If `do_norm` is set, each channel is first min-max normalized to [0, 1]. Then CLAHE
    (`skimage.exposure.equalize_adapthist`) is applied and the result is multiplied by 255.

    The input array is never modified. This matters because callers pass views such as
    `image[bb]`; normalizing in place would corrupt the shared source image for later crops.

    Args:
        im: The image crop (2D, or channel-last RGB).
        do_norm: Whether to min-max normalize the image before CLAHE.

    Returns:
        The enhanced image as a new floating point array.
    """
    if do_norm:
        # Work on a floating point copy: this supports integer inputs (in-place float division
        # fails on integer arrays) and leaves the caller's array untouched.
        im = im.astype(np.promote_types(im.dtype, np.float32))
        im -= im.min(axis=(0, 1), keepdims=True)
        im /= (im.max(axis=(0, 1), keepdims=True) + 1e-6)
    im = exposure.equalize_adapthist(im)
    im *= 255
    return im
221
222
def _overlay_outline(im, mask, outline_dilation):
    """Paint the boundary of `mask` in yellow into a copy of the RGB image `im`.

    Args:
        im: Channel-last RGB image.
        mask: The mask whose boundary is drawn.
        outline_dilation: Number of dilation iterations applied to the boundary; 0 keeps it thin.

    Returns:
        The image copy with the outline drawn.
    """
    edges = find_boundaries(mask)
    edges = binary_dilation(edges, iterations=outline_dilation) if outline_dilation > 0 else edges
    painted = im.copy()
    painted[edges] = [255, 255, 0]
    return painted
230
231
def _overlay_box(im, prompt, outline_dilation):
    """Draw the perimeter of a box prompt in cyan into a copy of the RGB image `im`.

    Args:
        im: Channel-last RGB image.
        prompt: Pair of (row, col) corner coordinates: the start and end of the box.
        outline_dilation: Number of dilation iterations applied to the perimeter; 0 keeps it thin.

    Returns:
        The image copy with the box drawn.
    """
    top_left, bottom_right = prompt
    spatial_shape = im.shape[:2]
    rows, cols = draw.rectangle_perimeter(top_left, end=bottom_right, shape=spatial_shape)

    perimeter = np.zeros(spatial_shape, dtype="bool")
    perimeter[rows, cols] = True
    if outline_dilation > 0:
        perimeter = binary_dilation(perimeter, iterations=outline_dilation)

    drawn = im.copy()
    drawn[perimeter] = [0, 255, 255]
    return drawn
245
246
# NOTE: we currently only support a single point
def _overlay_points(im, prompt, radius):
    """Draw a single positive point prompt as a filled cyan disk into a copy of `im`.

    Args:
        im: Channel-last RGB image.
        prompt: Tuple of point coordinates (shape (1, 2), row/col) and point labels (shape (1,)).
        radius: The radius of the disk.

    Returns:
        The image copy with the point drawn.
    """
    coords, labels = prompt
    # Only exactly one positive point is supported.
    assert coords.shape[0] == labels.shape[0] == 1
    assert labels[0] == 1

    drawn = im.copy()
    disk_pixels = draw.disk(coords[0], radius, shape=drawn.shape[:2])
    draw.set_color(drawn, disk_pixels, [0, 255, 255], alpha=1.0)
    return drawn
260
261
def _compare_eval(
    f, eval_result, advantage_column, n_images_per_sample, prefix,
    sample_name, plot_folder, point_radius, outline_dilation, have_model3,
):
    """Plot the objects with the largest values in `advantage_column`.

    Each plotted object is cropped to a (roughly) square region around its ground-truth and
    model masks. It is shown in one row with the columns model 1, (model 3,) model 2. Each
    panel overlays the respective mask in red, the ground-truth outline in yellow and the
    prompt in cyan.

    Args:
        f: The opened h5 file of one sample.
        eval_result: The per-object scores, as returned by `_evaluate_samples`.
        advantage_column: The column used to rank the objects ("advantage1" or "advantage2").
        n_images_per_sample: The maximal number of objects to plot.
        prefix: The prompt type, "box" or "points".
        sample_name: Used to name the saved figures.
        plot_folder: Folder where one svg per object is saved. If None, all objects are shown
            in a single figure instead.
        point_radius: The radius of the point prompt overlay.
        outline_dilation: The dilation of the ground-truth outline and box prompt overlays.
        have_model3: Whether to also plot the masks of the third model.
    """
    result = eval_result.sort_values(advantage_column, ascending=False).iloc[:n_images_per_sample]
    n_rows = result.shape[0]
    # No object passed the size filter (or no images were requested): there is nothing to plot,
    # and `plt.subplots` would raise for zero rows.
    if n_rows == 0:
        return

    image = f["image"][:]
    is_box_prompt = prefix == "box"
    if is_box_prompt:
        overlay_prompts = partial(_overlay_box, outline_dilation=outline_dilation)
    else:
        overlay_prompts = partial(_overlay_points, radius=point_radius)

    def make_square(bb, shape):
        # Pad the shorter side(s) so that the crop is (close to) square, clipped to the image shape.
        box_shape = [b.stop - b.start for b in bb]
        longest_side = max(box_shape)
        padding = [(longest_side - sh) // 2 for sh in box_shape]
        return tuple(
            slice(max(b.start - pad, 0), min(b.stop + pad, sh)) for b, pad, sh in zip(bb, padding, shape)
        )

    def plot_row(axes, row):
        # Plot one object into the 1D sequence of axes `axes`.
        g = f[row.gt_id]

        gt = g["gt_mask"][:]
        mask1 = g[f"{prefix}/mask1"][:]
        mask2 = g[f"{prefix}/mask2"][:]
        # The mask of model 3 is only shown for comparison; it does not influence the crop.
        mask3 = g[f"{prefix}/mask3"][:] if have_model3 else None

        fg_mask = (gt + mask1 + mask2) > 0
        # For box prompts the mask is dilated so that the bounding box is visible in the crop.
        if is_box_prompt:
            fg_mask = binary_dilation(fg_mask, iterations=5)
        bb = tuple(slice(int(c.min()), int(c.max() + 1)) for c in np.where(fg_mask))
        bb = make_square(bb, fg_mask.shape)

        # Shift the prompt into the coordinate system of the crop.
        offset = np.array([b.start for b in bb])
        if is_box_prompt:
            box = g.attrs["box"]
            prompt = np.array([box[:2], box[2:]]) - offset
        else:
            prompt = (g.attrs["point_coords"] - offset, g.attrs["point_labels"])

        # Pass a copy: `image[bb]` is a view into the image shared by all crops and must not
        # be modified by the contrast enhancement.
        im = _enhance_image(image[bb].copy())
        gt = gt[bb]

        # Column order: model 1, (model 3,) model 2, i.e. default -> generalist -> specialist
        # in the intended use case.
        masks = [mask1[bb]]
        if have_model3:
            masks.append(mask3[bb])
        masks.append(mask2[bb])

        for ax, mask in zip(axes, masks):
            overlay = _overlay_mask(im, mask)
            overlay = _overlay_outline(overlay, gt, outline_dilation)
            overlay = overlay_prompts(overlay, prompt)
            ax.axis("off")
            ax.imshow(overlay)

    n_cols = 3 if have_model3 else 2
    if plot_folder is None:
        # squeeze=False keeps a 2D axes array, even if only a single row is plotted.
        _, axes = plt.subplots(n_rows, n_cols, squeeze=False)
        for axes_row, (_, row) in zip(axes, result.iterrows()):
            plot_row(axes_row, row)
        plt.show()
    else:
        for i, (_, row) in enumerate(result.iterrows()):
            # With a single row and >= 2 columns, `axes` is a 1D array of axes.
            _, axes = plt.subplots(1, n_cols)
            plot_row(axes, row)
            plt.subplots_adjust(wspace=0.05, hspace=0)
            plt.savefig(os.path.join(plot_folder, f"{sample_name}_{i}.svg"), bbox_inches="tight")
            plt.close()
361
362
def _compare_prompts(
    f, prefix, n_images_per_sample, min_size, sample_name, plot_folder,
    point_radius, outline_dilation, have_model3,
):
    """Evaluate one prompt type and plot the best cases for model 1 and for model 2.

    If `plot_folder` is given, the plots go into its "advantage1" and "advantage2" subfolders.
    These are created before any plotting starts.
    """
    eval_result = _evaluate_samples(f, prefix, min_size)

    advantage_columns = ("advantage1", "advantage2")
    if plot_folder is None:
        target_folders = [None, None]
    else:
        target_folders = [os.path.join(plot_folder, column) for column in advantage_columns]
        for folder in target_folders:
            os.makedirs(folder, exist_ok=True)

    for column, folder in zip(advantage_columns, target_folders):
        _compare_eval(
            f, eval_result, column, n_images_per_sample, prefix, sample_name, folder,
            point_radius, outline_dilation, have_model3,
        )
383
384
def _compare_models(
    path, n_images_per_sample, min_size, plot_folder, point_radius, outline_dilation, have_model3,
):
    """Create the comparison plots for point prompts and box prompts of one sample file.

    If `plot_folder` is given, the plots go into its "points" and "box" subfolders.
    """
    sample_name = Path(path).stem
    with h5py.File(path, "r") as f:
        for prefix in ("points", "box"):
            prompt_folder = None if plot_folder is None else os.path.join(plot_folder, prefix)
            _compare_prompts(
                f, prefix, n_images_per_sample, min_size, sample_name, prompt_folder,
                point_radius, outline_dilation, have_model3,
            )
403
404
def model_comparison(
    output_folder: Union[str, os.PathLike],
    n_images_per_sample: int,
    min_size: int,
    plot_folder: Optional[Union[str, os.PathLike]] = None,
    point_radius: int = 4,
    outline_dilation: int = 0,
    have_model3: bool = False,
) -> None:
    """Create images for a qualitative model comparison.

    Each precomputed sample is handled for both point and box prompts. The plots show the
    objects where the first model has the largest IoU advantage over the second, and vice versa.

    Args:
        output_folder: The folder with the data precomputed by `generate_data_for_model_comparison`.
        n_images_per_sample: The maximal number of objects to plot per precomputed sample,
            prompt type and model advantage.
        min_size: The min size of ground-truth objects to take into account.
        plot_folder: The folder where to save the plots. If not given the plots will be displayed.
        point_radius: The radius of the point overlay.
        outline_dilation: The dilation factor of the outline overlay.
        have_model3: Whether the precomputed data contains the masks of a third model, i.e.
            whether `model_type3` was passed to `generate_data_for_model_comparison`.
            If True, its masks are shown in an additional middle column.
    """
    # Sort for a deterministic processing order; glob returns files in arbitrary order.
    files = sorted(glob(os.path.join(output_folder, "*.h5")))
    for path in tqdm(files):
        _compare_models(
            path, n_images_per_sample, min_size, plot_folder, point_radius, outline_dilation, have_model3,
        )
429
430
431#
432# Quick visual evaluation with napari
433#
434
435
def _check_group(g, show_points, image=None):
    """Show the image, the ground-truth and the predicted masks of one object in napari.

    Args:
        g: The h5 group of a single object, as written by `generate_data_for_model_comparison`.
        show_points: Whether to show the masks and prompt for the point prompt (True)
            or for the box prompt (False).
        image: The image of the sample. If not given, it is read from the "image" dataset
            at the root of the file that contains `g`.
    """
    import napari

    # The image is stored once per sample at the file root, not inside the per-object groups,
    # so `g["image"]` does not exist.
    if image is None:
        image = g.file["image"][:]
    gt = g["gt_mask"][:]

    prefix = "points" if show_points else "box"
    m1 = g[f"{prefix}/mask1"][:]
    m2 = g[f"{prefix}/mask2"][:]
    # The mask of a third model only exists if one was used when generating the data.
    m3 = g[f"{prefix}/mask3"][:] if f"{prefix}/mask3" in g else None

    v = napari.Viewer()
    v.add_image(image, name="image")
    v.add_labels(gt, name="gt")
    v.add_labels(m1, name="mask1")
    v.add_labels(m2, name="mask2")
    if m3 is not None:
        v.add_labels(m3, name="mask3")

    if show_points:
        # TODO use point labels for coloring
        v.add_points(
            g.attrs["point_coords"],
            edge_color="#00FF00",
            symbol="o",
            face_color="transparent",
            edge_width=0.5,
            size=12,
        )
    else:
        box = g.attrs["box"]
        # Two opposite corners of the box rectangle.
        box = np.array([
            [box[0], box[1]], [box[2], box[3]]
        ])
        v.add_shapes(
            box, face_color="transparent", edge_color="green", edge_width=4,
        )
    napari.run()
473
474
def model_comparison_with_napari(output_folder: Union[str, os.PathLike], show_points: bool = True) -> None:
    """Use napari to display the qualitative comparison results for two models.

    The sample files are processed in lexicographically sorted order. A separate napari
    viewer is opened for every object of every sample.

    Args:
        output_folder: The folder with the data precomputed by `generate_data_for_model_comparison`.
        show_points: Whether to show the results for point or for box prompts.
    """
    # Sort for a deterministic order; glob returns files in arbitrary order.
    files = sorted(glob(os.path.join(output_folder, "*.h5")))
    for path in files:
        print("Comparing models in", path)
        with h5py.File(path, "r") as f:
            for name, g in f.items():
                # The image dataset sits at the file root next to the per-object groups.
                if name == "image":
                    continue
                _check_group(g, show_points=show_points)
def generate_data_for_model_comparison( loader: torch.utils.data.dataloader.DataLoader, output_folder: Union[str, os.PathLike], model_type1: str, model_type2: str, n_samples: int, model_type3: Optional[str] = None, checkpoint1: Union[str, os.PathLike, NoneType] = None, checkpoint2: Union[str, os.PathLike, NoneType] = None, checkpoint3: Union[str, os.PathLike, NoneType] = None) -> None:
def generate_data_for_model_comparison(
    loader: torch.utils.data.DataLoader,
    output_folder: Union[str, os.PathLike],
    model_type1: str,
    model_type2: str,
    n_samples: int,
    model_type3: Optional[str] = None,
    checkpoint1: Optional[Union[str, os.PathLike]] = None,
    checkpoint2: Optional[Union[str, os.PathLike]] = None,
    checkpoint3: Optional[Union[str, os.PathLike]] = None,
) -> None:
    """Generate samples for qualitative model comparison.

    This precomputes the input for `model_comparison` and `model_comparison_with_napari`.
    One file `sample_{i}.h5` is written to `output_folder` per sample. It contains the image
    and, per ground-truth object, the ground-truth mask, the prompts, and the masks each model
    predicts for a single positive point prompt and for a box prompt.

    Args:
        loader: The torch dataloader from which samples are drawn.
        output_folder: The folder where the samples will be saved.
        model_type1: The first model to use for comparison.
            The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.
        model_type2: The second model to use for comparison.
            The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.
        n_samples: The number of samples to draw from the dataloader.
        model_type3: An optional third model to use for comparison.
            The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.
        checkpoint1: Optional checkpoint for the first model.
        checkpoint2: Optional checkpoint for the second model.
        checkpoint3: Optional checkpoint for the third model. Ignored if `model_type3` is not given.
    """
    # One positive point and one box per object; the plotting code only supports a single positive point.
    prompt_generator = PointAndBoxPromptGenerator(
        n_positive_points=1,
        n_negative_points=0,
        dilation_strength=3,
        get_point_prompts=True,
        get_box_prompts=True,
    )
    predictor1 = util.get_sam_model(model_type=model_type1, checkpoint_path=checkpoint1)
    predictor2 = util.get_sam_model(model_type=model_type2, checkpoint_path=checkpoint2)
    predictor3 = None if model_type3 is None else util.get_sam_model(
        model_type=model_type3, checkpoint_path=checkpoint3
    )

    _predict_models_with_loader(loader, n_samples, prompt_generator, predictor1, predictor2, predictor3, output_folder)

Generate samples for qualitative model comparison.

This precomputes the input for model_comparison and model_comparison_with_napari.

Arguments:
  • loader: The torch dataloader from which samples are drawn.
  • output_folder: The folder where the samples will be saved.
  • model_type1: The first model to use for comparison. The value needs to be a valid model_type for micro_sam.util.get_sam_model.
  • model_type2: The second model to use for comparison. The value needs to be a valid model_type for micro_sam.util.get_sam_model.
  • n_samples: The number of samples to draw from the dataloader.
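  • model_type3: An optional third model to use for comparison. The value needs to be a valid model_type for micro_sam.util.get_sam_model.
  • checkpoint1: Optional checkpoint for the first model.
  • checkpoint2: Optional checkpoint for the second model.
  • checkpoint3: Optional checkpoint for the third model. Only used if model_type3 is given.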
def model_comparison( output_folder: Union[str, os.PathLike], n_images_per_sample: int, min_size: int, plot_folder: Union[str, os.PathLike, NoneType] = None, point_radius: int = 4, outline_dilation: int = 0, have_model3=False) -> None:
def model_comparison(
    output_folder: Union[str, os.PathLike],
    n_images_per_sample: int,
    min_size: int,
    plot_folder: Optional[Union[str, os.PathLike]] = None,
    point_radius: int = 4,
    outline_dilation: int = 0,
    have_model3: bool = False,
) -> None:
    """Create images for a qualitative model comparison.

    Each precomputed sample is handled for both point and box prompts. The plots show the
    objects where the first model has the largest IoU advantage over the second, and vice versa.

    Args:
        output_folder: The folder with the data precomputed by `generate_data_for_model_comparison`.
        n_images_per_sample: The maximal number of objects to plot per precomputed sample,
            prompt type and model advantage.
        min_size: The min size of ground-truth objects to take into account.
        plot_folder: The folder where to save the plots. If not given the plots will be displayed.
        point_radius: The radius of the point overlay.
        outline_dilation: The dilation factor of the outline overlay.
        have_model3: Whether the precomputed data contains the masks of a third model, i.e.
            whether `model_type3` was passed to `generate_data_for_model_comparison`.
            If True, its masks are shown in an additional middle column.
    """
    # Sort for a deterministic processing order; glob returns files in arbitrary order.
    files = sorted(glob(os.path.join(output_folder, "*.h5")))
    for path in tqdm(files):
        _compare_models(
            path, n_images_per_sample, min_size, plot_folder, point_radius, outline_dilation, have_model3,
        )

Create images for a qualitative model comparison.

Arguments:
  • output_folder: The folder with the data precomputed by generate_data_for_model_comparison.
  • n_images_per_sample: The number of images to generate per precomputed sample.
  • min_size: The min size of ground-truth objects to take into account.
  • plot_folder: The folder where to save the plots. If not given the plots will be displayed.
  • point_radius: The radius of the point overlay.
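  • outline_dilation: The dilation factor of the outline overlay.
  • have_model3: Whether the precomputed data contains the masks of a third model (i.e. model_type3 was passed to generate_data_for_model_comparison). If True, these masks are shown in an additional middle column.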
def model_comparison_with_napari(output_folder: Union[str, os.PathLike], show_points: bool = True) -> None:
def model_comparison_with_napari(output_folder: Union[str, os.PathLike], show_points: bool = True) -> None:
    """Use napari to display the qualitative comparison results for two models.

    The sample files are processed in lexicographically sorted order. A separate napari
    viewer is opened for every object of every sample.

    Args:
        output_folder: The folder with the data precomputed by `generate_data_for_model_comparison`.
        show_points: Whether to show the results for point or for box prompts.
    """
    # Sort for a deterministic order; glob returns files in arbitrary order.
    files = sorted(glob(os.path.join(output_folder, "*.h5")))
    for path in files:
        print("Comparing models in", path)
        with h5py.File(path, "r") as f:
            for name, g in f.items():
                # The image dataset sits at the file root next to the per-object groups.
                if name == "image":
                    continue
                _check_group(g, show_points=show_points)

Use napari to display the qualitative comparison results for two models.

Arguments: