micro_sam.evaluation.model_comparison
Functionality for qualitative comparison of Segment Anything models on microscopy data.
1"""Functionality for qualitative comparison of Segment Anything models on microscopy data. 2""" 3 4import os 5from glob import glob 6from tqdm import tqdm 7from pathlib import Path 8from functools import partial 9from typing import Optional, Union 10 11import h5py 12import numpy as np 13import pandas as pd 14import matplotlib.pyplot as plt 15import skimage.draw as draw 16from skimage import exposure 17from scipy.ndimage import binary_dilation 18from skimage.segmentation import relabel_sequential, find_boundaries 19 20import torch 21 22from .. import util 23from ..prompt_generators import PointAndBoxPromptGenerator 24from ..prompt_based_segmentation import segment_from_box, segment_from_points 25 26 27# 28# Compute all required data for the model comparison 29# 30 31 32def _predict_models_with_loader(loader, n_samples, prompt_generator, predictor1, predictor2, predictor3, output_folder): 33 i = 0 34 os.makedirs(output_folder, exist_ok=True) 35 36 for x, y in tqdm(loader, total=n_samples): 37 out_path = os.path.join(output_folder, f"sample_{i}.h5") 38 39 im = x.numpy().squeeze() 40 if im.ndim == 3 and im.shape[0] == 3: 41 im = im.transpose((1, 2, 0)) 42 43 gt = y.numpy().squeeze().astype("uint32") 44 gt = relabel_sequential(gt)[0] 45 46 emb1 = util.precompute_image_embeddings(predictor1, im, ndim=2) 47 util.set_precomputed(predictor1, emb1) 48 49 emb2 = util.precompute_image_embeddings(predictor2, im, ndim=2) 50 util.set_precomputed(predictor2, emb2) 51 52 if predictor3 is not None: 53 emb3 = util.precompute_image_embeddings(predictor3, im, ndim=2) 54 util.set_precomputed(predictor3, emb3) 55 56 with h5py.File(out_path, "a") as f: 57 f.create_dataset("image", data=im, compression="gzip") 58 59 gt_ids = np.unique(gt)[1:] 60 centers, boxes = util.get_centers_and_bounding_boxes(gt) 61 centers = [centers[gt_id] for gt_id in gt_ids] 62 boxes = [boxes[gt_id] for gt_id in gt_ids] 63 64 object_masks = util.segmentation_to_one_hot(gt.astype("int64"), gt_ids) 65 coords, labels, boxes, _ = prompt_generator( 66 segmentation=object_masks, 67 bbox_coordinates=boxes, 68 center_coordinates=centers, 69 ) 70 71 for idx, gt_id in tqdm(enumerate(gt_ids), total=len(gt_ids)): 72 73 # Box prompts: 74 # Reorder the coordinates so that they match the normal python convention. 75 box = boxes[idx][[1, 0, 3, 2]] 76 mask1_box = segment_from_box(predictor1, box) 77 mask2_box = segment_from_box(predictor2, box) 78 mask1_box, mask2_box = mask1_box.squeeze(), mask2_box.squeeze() 79 80 if predictor3 is not None: 81 mask3_box = segment_from_box(predictor3, box) 82 mask3_box = mask3_box.squeeze() 83 84 # Point prompts: 85 # Reorder the coordinates so that they match the normal python convention. 
86 point_coords, point_labels = np.array(coords[idx])[:, ::-1], np.array(labels[idx]) 87 mask1_points = segment_from_points(predictor1, point_coords, point_labels) 88 mask2_points = segment_from_points(predictor2, point_coords, point_labels) 89 mask1_points, mask2_points = mask1_points.squeeze(), mask2_points.squeeze() 90 91 if predictor3 is not None: 92 mask3_points = segment_from_points(predictor3, point_coords, point_labels) 93 mask3_points = mask3_points.squeeze() 94 95 gt_mask = gt == gt_id 96 with h5py.File(out_path, "a") as f: 97 g = f.create_group(str(gt_id)) 98 g.attrs["point_coords"] = point_coords 99 g.attrs["point_labels"] = point_labels 100 g.attrs["box"] = box 101 102 g.create_dataset("gt_mask", data=gt_mask, compression="gzip") 103 g.create_dataset("box/mask1", data=mask1_box.astype("uint8"), compression="gzip") 104 g.create_dataset("box/mask2", data=mask2_box.astype("uint8"), compression="gzip") 105 g.create_dataset("points/mask1", data=mask1_points.astype("uint8"), compression="gzip") 106 g.create_dataset("points/mask2", data=mask2_points.astype("uint8"), compression="gzip") 107 108 if predictor3 is not None: 109 g.create_dataset("box/mask3", data=mask3_box.astype("uint8"), compression="gzip") 110 g.create_dataset("points/mask3", data=mask3_points.astype("uint8"), compression="gzip") 111 112 i += 1 113 if i >= n_samples: 114 return 115 116 117def generate_data_for_model_comparison( 118 loader: torch.utils.data.DataLoader, 119 output_folder: Union[str, os.PathLike], 120 model_type1: str, 121 model_type2: str, 122 n_samples: int, 123 model_type3: Optional[str] = None, 124 checkpoint1: Optional[Union[str, os.PathLike]] = None, 125 checkpoint2: Optional[Union[str, os.PathLike]] = None, 126 checkpoint3: Optional[Union[str, os.PathLike]] = None, 127) -> None: 128 """Generate samples for qualitative model comparison. 129 130 This precomputes the input for `model_comparison` and `model_comparison_with_napari`. 131 132 Args: 133 loader: The torch dataloader from which samples are drawn. 134 output_folder: The folder where the samples will be saved. 135 model_type1: The first model to use for comparison. 136 The value needs to be a valid model_type for `micro_sam.util.get_sam_model`. 137 model_type2: The second model to use for comparison. 138 The value needs to be a valid model_type for `micro_sam.util.get_sam_model`. 139 n_samples: The number of samples to draw from the dataloader. 140 checkpoint1: Optional checkpoint for the first model. 141 checkpoint2: Optional checkpoint for the second model. 
142 """ 143 prompt_generator = PointAndBoxPromptGenerator( 144 n_positive_points=1, 145 n_negative_points=0, 146 dilation_strength=3, 147 get_point_prompts=True, 148 get_box_prompts=True, 149 ) 150 predictor1 = util.get_sam_model(model_type=model_type1, checkpoint_path=checkpoint1) 151 predictor2 = util.get_sam_model(model_type=model_type2, checkpoint_path=checkpoint2) 152 153 if model_type3 is not None: 154 predictor3 = util.get_sam_model(model_type=model_type3, checkpoint_path=checkpoint3) 155 else: 156 predictor3 = None 157 158 _predict_models_with_loader(loader, n_samples, prompt_generator, predictor1, predictor2, predictor3, output_folder) 159 160 161# 162# Visual evaluation according to metrics 163# 164 165 166def _evaluate_samples(f, prefix, min_size): 167 eval_result = { 168 "gt_id": [], 169 "score1": [], 170 "score2": [], 171 } 172 for name, group in f.items(): 173 if name == "image": 174 continue 175 176 gt_mask = group["gt_mask"][:] 177 178 size = gt_mask.sum() 179 if size < min_size: 180 continue 181 182 m1 = group[f"{prefix}/mask1"][:] 183 m2 = group[f"{prefix}/mask2"][:] 184 185 score1 = util.compute_iou(gt_mask, m1) 186 score2 = util.compute_iou(gt_mask, m2) 187 188 eval_result["gt_id"].append(name) 189 eval_result["score1"].append(score1) 190 eval_result["score2"].append(score2) 191 192 eval_result = pd.DataFrame.from_dict(eval_result) 193 eval_result["advantage1"] = eval_result["score1"] - eval_result["score2"] 194 eval_result["advantage2"] = eval_result["score2"] - eval_result["score1"] 195 return eval_result 196 197 198def _overlay_mask(image, mask, alpha=0.6): 199 assert image.ndim in (2, 3) 200 # overlay the mask 201 if image.ndim == 2: 202 overlay = np.stack([image, image, image]).transpose((1, 2, 0)) 203 else: 204 overlay = image 205 assert overlay.shape[-1] == 3 206 mask_overlay = np.zeros_like(overlay) 207 mask_overlay[mask == 1] = [255, 0, 0] 208 alpha = alpha 209 overlay = alpha * overlay + (1.0 - alpha) * mask_overlay 210 return overlay.astype("uint8") 211 212 213def _enhance_image(im, do_norm=True): 214 # apply CLAHE to improve the image quality 215 if do_norm: 216 im -= im.min(axis=(0, 1), keepdims=True) 217 im /= (im.max(axis=(0, 1), keepdims=True) + 1e-6) 218 im = exposure.equalize_adapthist(im) 219 im *= 255 220 return im 221 222 223def _overlay_outline(im, mask, outline_dilation): 224 outline = find_boundaries(mask) 225 if outline_dilation > 0: 226 outline = binary_dilation(outline, iterations=outline_dilation) 227 overlay = im.copy() 228 overlay[outline] = [255, 255, 0] 229 return overlay 230 231 232def _overlay_box(im, prompt, outline_dilation): 233 start, end = prompt 234 rr, cc = draw.rectangle_perimeter(start, end=end, shape=im.shape[:2]) 235 236 box_outline = np.zeros(im.shape[:2], dtype="bool") 237 box_outline[rr, cc] = 1 238 if outline_dilation > 0: 239 box_outline = binary_dilation(box_outline, iterations=outline_dilation) 240 241 overlay = im.copy() 242 overlay[box_outline] = [0, 255, 255] 243 244 return overlay 245 246 247# NOTE: we currently only support a single point 248def _overlay_points(im, prompt, radius): 249 coords, labels = prompt 250 # make sure we have a single positive prompt, other options are 251 # currently not supported 252 assert coords.shape[0] == labels.shape[0] == 1 253 assert labels[0] == 1 254 255 rr, cc = draw.disk(coords[0], radius, shape=im.shape[:2]) 256 overlay = im.copy() 257 draw.set_color(overlay, (rr, cc), [0, 255, 255], alpha=1.0) 258 259 return overlay 260 261 262def _compare_eval( 263 f, eval_result, 
advantage_column, n_images_per_sample, prefix, 264 sample_name, plot_folder, point_radius, outline_dilation, have_model3, 265): 266 result = eval_result.sort_values(advantage_column, ascending=False).iloc[:n_images_per_sample] 267 n_rows = result.shape[0] 268 269 image = f["image"][:] 270 is_box_prompt = prefix == "box" 271 overlay_prompts = partial(_overlay_box, outline_dilation=outline_dilation) if is_box_prompt else\ 272 partial(_overlay_points, radius=point_radius) 273 274 def make_square(bb, shape): 275 box_shape = [b.stop - b.start for b in bb] 276 longest_side = max(box_shape) 277 padding = [(longest_side - sh) // 2 for sh in box_shape] 278 bb = tuple( 279 slice(max(b.start - pad, 0), min(b.stop + pad, sh)) for b, pad, sh in zip(bb, padding, shape) 280 ) 281 return bb 282 283 def plot_ax(axis, i, row): 284 g = f[row.gt_id] 285 286 gt = g["gt_mask"][:] 287 mask1 = g[f"{prefix}/mask1"][:] 288 mask2 = g[f"{prefix}/mask2"][:] 289 290 # The mask3 is just for comparison purpose, we just plot the crops as it is. 291 if have_model3: 292 mask3 = g[f"{prefix}/mask3"][:] 293 294 fg_mask = (gt + mask1 + mask2) > 0 295 # if this is a box prompt we dilate the mask so that the bounding box 296 # can be seen 297 if is_box_prompt: 298 fg_mask = binary_dilation(fg_mask, iterations=5) 299 bb = np.where(fg_mask) 300 bb = tuple( 301 slice(int(b.min()), int(b.max() + 1)) for b in bb 302 ) 303 bb = make_square(bb, fg_mask.shape) 304 305 offset = np.array([b.start for b in bb]) 306 if is_box_prompt: 307 prompt = g.attrs["box"] 308 prompt = np.array( 309 [prompt[:2], prompt[2:]] 310 ) - offset 311 else: 312 prompt = (g.attrs["point_coords"] - offset, g.attrs["point_labels"]) 313 314 im = _enhance_image(image[bb]) 315 gt, mask1, mask2 = gt[bb], mask1[bb], mask2[bb] 316 317 if have_model3: 318 mask3 = mask3[bb] 319 320 im1 = _overlay_mask(im, mask1) 321 im1 = _overlay_outline(im1, gt, outline_dilation) 322 im1 = overlay_prompts(im1, prompt) 323 ax = axis[0] if i is None else axis[i, 0] 324 ax.axis("off") 325 ax.imshow(im1) 326 327 # We put the third set of comparsion point in between 328 # so that the comparison looks -> default, generalist, specialist 329 if have_model3: 330 im3 = _overlay_mask(im, mask3) 331 im3 = _overlay_outline(im3, gt, outline_dilation) 332 im3 = overlay_prompts(im3, prompt) 333 ax = axis[1] if i is None else axis[i, 1] 334 ax.axis("off") 335 ax.imshow(im3) 336 337 nexax = 2 338 else: 339 nexax = 1 340 341 im2 = _overlay_mask(im, mask2) 342 im2 = _overlay_outline(im2, gt, outline_dilation) 343 im2 = overlay_prompts(im2, prompt) 344 ax = axis[nexax] if i is None else axis[i, nexax] 345 ax.axis("off") 346 ax.imshow(im2) 347 348 cols = 3 if have_model3 else 2 349 if plot_folder is None: 350 fig, axis = plt.subplots(n_rows, cols) 351 for i, (_, row) in enumerate(result.iterrows()): 352 plot_ax(axis, i, row) 353 plt.show() 354 else: 355 for i, (_, row) in enumerate(result.iterrows()): 356 fig, axis = plt.subplots(1, cols) 357 plot_ax(axis, None, row) 358 plt.subplots_adjust(wspace=0.05, hspace=0) 359 plt.savefig(os.path.join(plot_folder, f"{sample_name}_{i}.svg"), bbox_inches="tight") 360 plt.close() 361 362 363def _compare_prompts( 364 f, prefix, n_images_per_sample, min_size, sample_name, plot_folder, 365 point_radius, outline_dilation, have_model3, 366): 367 box_eval = _evaluate_samples(f, prefix, min_size) 368 if plot_folder is None: 369 plot_folder1, plot_folder2 = None, None 370 else: 371 plot_folder1 = os.path.join(plot_folder, "advantage1") 372 plot_folder2 = 
os.path.join(plot_folder, "advantage2") 373 os.makedirs(plot_folder1, exist_ok=True) 374 os.makedirs(plot_folder2, exist_ok=True) 375 _compare_eval( 376 f, box_eval, "advantage1", n_images_per_sample, prefix, sample_name, plot_folder1, 377 point_radius, outline_dilation, have_model3, 378 ) 379 _compare_eval( 380 f, box_eval, "advantage2", n_images_per_sample, prefix, sample_name, plot_folder2, 381 point_radius, outline_dilation, have_model3, 382 ) 383 384 385def _compare_models( 386 path, n_images_per_sample, min_size, plot_folder, point_radius, outline_dilation, have_model3, 387): 388 sample_name = Path(path).stem 389 with h5py.File(path, "r") as f: 390 if plot_folder is None: 391 plot_folder_points, plot_folder_box = None, None 392 else: 393 plot_folder_points = os.path.join(plot_folder, "points") 394 plot_folder_box = os.path.join(plot_folder, "box") 395 _compare_prompts( 396 f, "points", n_images_per_sample, min_size, sample_name, plot_folder_points, 397 point_radius, outline_dilation, have_model3, 398 ) 399 _compare_prompts( 400 f, "box", n_images_per_sample, min_size, sample_name, plot_folder_box, 401 point_radius, outline_dilation, have_model3, 402 ) 403 404 405def model_comparison( 406 output_folder: Union[str, os.PathLike], 407 n_images_per_sample: int, 408 min_size: int, 409 plot_folder: Optional[Union[str, os.PathLike]] = None, 410 point_radius: int = 4, 411 outline_dilation: int = 0, 412 have_model3=False, 413) -> None: 414 """Create images for a qualitative model comparision. 415 416 Args: 417 output_folder: The folder with the data precomputed by `generate_data_for_model_comparison`. 418 n_images_per_sample: The number of images to generate per precomputed sample. 419 min_size: The min size of ground-truth objects to take into account. 420 plot_folder: The folder where to save the plots. If not given the plots will be displayed. 421 point_radius: The radius of the point overlay. 422 outline_dilation: The dilation factor of the outline overlay. 423 """ 424 files = glob(os.path.join(output_folder, "*.h5")) 425 for path in tqdm(files): 426 _compare_models( 427 path, n_images_per_sample, min_size, plot_folder, point_radius, outline_dilation, have_model3, 428 ) 429 430 431# 432# Quick visual evaluation with napari 433# 434 435 436def _check_group(g, show_points): 437 import napari 438 439 image = g["image"][:] 440 gt = g["gt_mask"][:] 441 if show_points: 442 m1 = g["points/mask1"][:] 443 m2 = g["points/mask2"][:] 444 points = g.attrs["point_coords"] 445 else: 446 m1 = g["box/mask1"][:] 447 m2 = g["box/mask2"][:] 448 box = g.attrs["box"] 449 box = np.array([ 450 [box[0], box[1]], [box[2], box[3]] 451 ]) 452 453 v = napari.Viewer() 454 v.add_image(image) 455 v.add_labels(gt) 456 v.add_labels(m1) 457 v.add_labels(m2) 458 if show_points: 459 # TODO use point labels for coloring 460 v.add_points( 461 points, 462 edge_color="#00FF00", 463 symbol="o", 464 face_color="transparent", 465 edge_width=0.5, 466 size=12, 467 ) 468 else: 469 v.add_shapes( 470 box, face_color="transparent", edge_color="green", edge_width=4, 471 ) 472 napari.run() 473 474 475def model_comparison_with_napari(output_folder: Union[str, os.PathLike], show_points: bool = True) -> None: 476 """Use napari to display the qualtiative comparison results for two models. 477 478 Args: 479 output_folder: The folder with the data precomputed by `generate_data_for_model_comparison`. 480 show_points: Whether to show the results for point or for box prompts. 
481 """ 482 files = glob(os.path.join(output_folder, "*.h5")) 483 for path in files: 484 print("Comparing models in", path) 485 with h5py.File(path, "r") as f: 486 for name, g in f.items(): 487 if name == "image": 488 continue 489 _check_group(g, show_points=show_points)
def generate_data_for_model_comparison(loader: torch.utils.data.DataLoader, output_folder: Union[str, os.PathLike], model_type1: str, model_type2: str, n_samples: int, model_type3: Optional[str] = None, checkpoint1: Optional[Union[str, os.PathLike]] = None, checkpoint2: Optional[Union[str, os.PathLike]] = None, checkpoint3: Optional[Union[str, os.PathLike]] = None) -> None:
Generate samples for qualitative model comparison.

This precomputes the input for `model_comparison` and `model_comparison_with_napari`.

Arguments:

- loader: The torch dataloader from which samples are drawn.
- output_folder: The folder where the samples will be saved.
- model_type1: The first model to use for comparison. The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.
- model_type2: The second model to use for comparison. The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.
- n_samples: The number of samples to draw from the dataloader.
- model_type3: Optional third model to use for comparison. The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.
- checkpoint1: Optional checkpoint for the first model.
- checkpoint2: Optional checkpoint for the second model.
- checkpoint3: Optional checkpoint for the third model.
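A minimal usage sketch. It assumes `loader` is an existing torch dataloader that yields image and label pairs (for instance one built with torch_em); the model types shown are example values, any valid model_type for `micro_sam.util.get_sam_model` works.

from micro_sam.evaluation.model_comparison import generate_data_for_model_comparison

generate_data_for_model_comparison(
    loader,                       # assumed: yields (image, label) pairs
    output_folder="comparison_data",
    model_type1="vit_b",          # e.g. vanilla SAM
    model_type2="vit_b_lm",       # e.g. a microscopy generalist model
    n_samples=10,
)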
def model_comparison(output_folder: Union[str, os.PathLike], n_images_per_sample: int, min_size: int, plot_folder: Optional[Union[str, os.PathLike]] = None, point_radius: int = 4, outline_dilation: int = 0, have_model3: bool = False) -> None:
Create images for a qualitative model comparison.

Arguments:

- output_folder: The folder with the data precomputed by `generate_data_for_model_comparison`.
- n_images_per_sample: The number of images to generate per precomputed sample.
- min_size: The minimal size of ground-truth objects to take into account.
- plot_folder: The folder where the plots will be saved. If not given, the plots will be displayed instead.
- point_radius: The radius of the point overlay.
- outline_dilation: The dilation factor of the outline overlay.
- have_model3: Whether the precomputed data contains results for a third model.
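A usage sketch, assuming the data was precomputed into `comparison_data` as above. For each sample and prompt type, the objects with the largest score advantage of one model over the other are plotted.

from micro_sam.evaluation.model_comparison import model_comparison

model_comparison(
    output_folder="comparison_data",
    n_images_per_sample=5,           # plot the 5 objects with the largest advantage per model
    min_size=100,                    # skip ground-truth objects smaller than 100 pixels
    plot_folder="comparison_plots",  # omit to display the figures interactively instead
)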
def model_comparison_with_napari(output_folder: Union[str, os.PathLike], show_points: bool = True) -> None:
Use napari to display the qualitative comparison results for two models.

Arguments:

- output_folder: The folder with the data precomputed by `generate_data_for_model_comparison`.
- show_points: Whether to show the results for point prompts (True) or for box prompts (False).
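A usage sketch, again assuming precomputed data in `comparison_data`. Note that one napari viewer is opened per ground-truth object; close the viewer to advance to the next object.

from micro_sam.evaluation.model_comparison import model_comparison_with_napari

model_comparison_with_napari("comparison_data", show_points=True)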