micro_sam.evaluation.model_comparison
Functionality for qualitative comparison of Segment Anything models on microscopy data.
"""Functionality for qualitative comparison of Segment Anything models on microscopy data.
"""

import os
from glob import glob
from tqdm import tqdm
from pathlib import Path
from functools import partial
from typing import Optional, Union, Dict, Any

import h5py
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import skimage.draw as draw
from skimage import exposure
from scipy.ndimage import binary_dilation
from skimage.segmentation import find_boundaries

from bioimage_cpp.segmentation import relabel_sequential

import torch

from .. import util
from ..prompt_generators import PointAndBoxPromptGenerator
from ..prompt_based_segmentation import segment_from_box, segment_from_points


#
# Compute all required data for the model comparison
#


def _predict_models_with_loader(loader, n_samples, prompt_generator, predictor1, predictor2, predictor3, output_folder):
    """Segment every ground-truth object with box and point prompts and store the results.

    For each drawn sample a file `sample_{i}.h5` is written to `output_folder`. It contains the
    dataset "image" and one group per ground-truth object (named by the object id). Each group holds
    the prompts as attributes ("point_coords", "point_labels", "box"), the dataset "gt_mask" and the
    predicted masks "box/mask{k}" and "points/mask{k}" for model k.

    Args:
        loader: Iterable yielding (image, label) tensor pairs.
        n_samples: The maximal number of samples to process. Nothing is written if this is <= 0.
        prompt_generator: Callable deriving point and box prompts from the ground-truth objects.
        predictor1: The first predictor.
        predictor2: The second predictor.
        predictor3: Optional third predictor. If None, no "mask3" datasets are written.
        output_folder: The folder where the sample files are written.
    """
    os.makedirs(output_folder, exist_ok=True)
    if n_samples <= 0:
        return

    predictors = [predictor1, predictor2] if predictor3 is None else [predictor1, predictor2, predictor3]

    for i, (x, y) in enumerate(tqdm(loader, total=n_samples)):
        out_path = os.path.join(output_folder, f"sample_{i}.h5")

        im = x.numpy().squeeze()
        # Move a leading channel axis of size 3 to the end.
        if im.ndim == 3 and im.shape[0] == 3:
            im = im.transpose((1, 2, 0))

        gt = y.numpy().squeeze().astype("uint32")
        gt = relabel_sequential(gt)[0]

        for predictor in predictors:
            embeddings = util.precompute_image_embeddings(predictor, im, ndim=2)
            util.set_precomputed(predictor, embeddings)

        # Exclude the background id 0 explicitly. Dropping the first unique value would
        # discard a real object if the ground truth contains no background pixels.
        gt_ids = np.unique(gt)
        gt_ids = gt_ids[gt_ids != 0]
        centers, boxes = util.get_centers_and_bounding_boxes(gt)
        centers = [centers[gt_id] for gt_id in gt_ids]
        boxes = [boxes[gt_id] for gt_id in gt_ids]

        object_masks = util.segmentation_to_one_hot(gt.astype("int64"), gt_ids)
        coords, labels, boxes, _ = prompt_generator(
            segmentation=object_masks,
            bbox_coordinates=boxes,
            center_coordinates=centers,
        )

        # Open the output file once per sample instead of once per object.
        with h5py.File(out_path, "a") as f:
            f.create_dataset("image", data=im, compression="gzip")

            for idx, gt_id in tqdm(enumerate(gt_ids), total=len(gt_ids)):

                # Box prompts:
                # Reorder the coordinates so that they match the normal python convention.
                box = boxes[idx][[1, 0, 3, 2]]
                box_masks = [segment_from_box(predictor, box).squeeze() for predictor in predictors]

                # Point prompts:
                # Reorder the coordinates so that they match the normal python convention.
                point_coords, point_labels = np.array(coords[idx])[:, ::-1], np.array(labels[idx])
                point_masks = [
                    segment_from_points(predictor, point_coords, point_labels).squeeze()
                    for predictor in predictors
                ]

                g = f.create_group(str(gt_id))
                g.attrs["point_coords"] = point_coords
                g.attrs["point_labels"] = point_labels
                g.attrs["box"] = box

                g.create_dataset("gt_mask", data=gt == gt_id, compression="gzip")
                for model_id, (box_mask, point_mask) in enumerate(zip(box_masks, point_masks), start=1):
                    g.create_dataset(f"box/mask{model_id}", data=box_mask.astype("uint8"), compression="gzip")
                    g.create_dataset(
                        f"points/mask{model_id}", data=point_mask.astype("uint8"), compression="gzip"
                    )

        # Stop right after the last requested sample, so no extra batch is drawn from the loader.
        if i + 1 >= n_samples:
            break


def generate_data_for_model_comparison(
    loader: torch.utils.data.DataLoader,
    output_folder: Union[str, os.PathLike],
    model_type1: str,
    model_type2: str,
    n_samples: int,
    model_type3: Optional[str] = None,
    checkpoint1: Optional[Union[str, os.PathLike]] = None,
    checkpoint2: Optional[Union[str, os.PathLike]] = None,
    checkpoint3: Optional[Union[str, os.PathLike]] = None,
    peft_kwargs1: Optional[Dict[str, Any]] = None,
    peft_kwargs2: Optional[Dict[str, Any]] = None,
    peft_kwargs3: Optional[Dict[str, Any]] = None,
) -> None:
    """Generate samples for qualitative model comparison.

    This precomputes the input for `model_comparison` and `model_comparison_with_napari`.

    Args:
        loader: The torch dataloader from which samples are drawn.
        output_folder: The folder where the samples will be saved.
        model_type1: The first model to use for comparison.
            The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.
        model_type2: The second model to use for comparison.
            The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.
        n_samples: The number of samples to draw from the dataloader.
        model_type3: Optional third model to use for comparison.
            If given, it needs to be a valid model_type for `micro_sam.util.get_sam_model`.
        checkpoint1: Optional checkpoint for the first model.
        checkpoint2: Optional checkpoint for the second model.
        checkpoint3: Optional checkpoint for the third model.
        peft_kwargs1: Optional `peft_kwargs` passed to `micro_sam.util.get_sam_model` for the first model.
        peft_kwargs2: Optional `peft_kwargs` passed to `micro_sam.util.get_sam_model` for the second model.
        peft_kwargs3: Optional `peft_kwargs` passed to `micro_sam.util.get_sam_model` for the third model.
    """
    # One positive point and no negative points per object, plus a box prompt.
    prompt_generator = PointAndBoxPromptGenerator(
        n_positive_points=1,
        n_negative_points=0,
        dilation_strength=3,
        get_point_prompts=True,
        get_box_prompts=True,
    )

    predictor1 = util.get_sam_model(model_type=model_type1, checkpoint_path=checkpoint1, peft_kwargs=peft_kwargs1)
    predictor2 = util.get_sam_model(model_type=model_type2, checkpoint_path=checkpoint2, peft_kwargs=peft_kwargs2)

    if model_type3 is not None:
        predictor3 = util.get_sam_model(model_type=model_type3, checkpoint_path=checkpoint3, peft_kwargs=peft_kwargs3)
    else:
        predictor3 = None

    _predict_models_with_loader(loader, n_samples, prompt_generator, predictor1, predictor2, predictor3, output_folder)


#
# Visual evaluation according to metrics
#


def _evaluate_samples(f, prefix, min_size):
    """Score the masks of models 1 and 2 against the ground truth for all objects of one sample.

    Args:
        f: The open HDF5 file of one sample.
        prefix: The prompt type, "box" or "points".
        min_size: Objects whose ground-truth mask has fewer pixels are skipped.

    Returns:
        A DataFrame with the columns "gt_id", "score1", "score2" (computed with `util.compute_iou`),
        "advantage1" (score1 - score2) and "advantage2" (score2 - score1).
    """
    eval_result = {
        "gt_id": [],
        "score1": [],
        "score2": [],
    }
    for name, group in f.items():
        if name == "image":
            continue

        gt_mask = group["gt_mask"][:]

        size = gt_mask.sum()
        if size < min_size:
            continue

        m1 = group[f"{prefix}/mask1"][:]
        m2 = group[f"{prefix}/mask2"][:]

        score1 = util.compute_iou(gt_mask, m1)
        score2 = util.compute_iou(gt_mask, m2)

        eval_result["gt_id"].append(name)
        eval_result["score1"].append(score1)
        eval_result["score2"].append(score2)

    eval_result = pd.DataFrame.from_dict(eval_result)
    eval_result["advantage1"] = eval_result["score1"] - eval_result["score2"]
    eval_result["advantage2"] = eval_result["score2"] - eval_result["score1"]
    return eval_result


def _overlay_mask(image, mask, alpha=0.6):
    """Blend a red overlay of `mask` (pixels equal to 1) into a 2d grayscale or (H, W, 3) image.

    Returns:
        The blended image as uint8 with three channels.
    """
    assert image.ndim in (2, 3)
    # Convert grayscale images to three channels so that a colored mask can be blended in.
    if image.ndim == 2:
        overlay = np.stack([image, image, image]).transpose((1, 2, 0))
    else:
        overlay = image
    assert overlay.shape[-1] == 3
    mask_overlay = np.zeros_like(overlay)
    mask_overlay[mask == 1] = [255, 0, 0]
    overlay = alpha * overlay + (1.0 - alpha) * mask_overlay
    return overlay.astype("uint8")


def _enhance_image(im, do_norm=True):
    """Apply CLAHE to improve the image quality and scale the result by 255.

    The input array is never modified. Callers pass views into the full image, so
    normalizing in place would corrupt other crops of the same image.

    Args:
        im: The image (crop) to enhance.
        do_norm: Whether to min-max normalize each channel to [0, 1] before applying CLAHE.

    Returns:
        A new array with the enhanced image.
    """
    if do_norm:
        # astype creates a float copy. This avoids writing through a view into the caller's image
        # and the failure of in-place true division on integer arrays.
        im = im.astype("float64")
        im -= im.min(axis=(0, 1), keepdims=True)
        im /= (im.max(axis=(0, 1), keepdims=True) + 1e-6)
    im = exposure.equalize_adapthist(im)
    im *= 255
    return im


def _overlay_outline(im, mask, outline_dilation):
    """Return a copy of the three-channel image `im` with the boundaries of `mask` drawn in yellow."""
    outline = find_boundaries(mask)
    if outline_dilation > 0:
        outline = binary_dilation(outline, iterations=outline_dilation)
    overlay = im.copy()
    overlay[outline] = [255, 255, 0]
    return overlay


def _overlay_box(im, prompt, outline_dilation):
    """Return a copy of the three-channel image `im` with the box `prompt` = (start, end) drawn in cyan."""
    start, end = prompt
    rr, cc = draw.rectangle_perimeter(start, end=end, shape=im.shape[:2])

    box_outline = np.zeros(im.shape[:2], dtype="bool")
    box_outline[rr, cc] = 1
    if outline_dilation > 0:
        box_outline = binary_dilation(box_outline, iterations=outline_dilation)

    overlay = im.copy()
    overlay[box_outline] = [0, 255, 255]

    return overlay


# NOTE: we currently only support a single point
def _overlay_points(im, prompt, radius):
    """Return a copy of the three-channel image `im` with the single positive point of `prompt` drawn as a disk."""
    coords, labels = prompt
    # make sure we have a single positive prompt, other options are
    # currently not supported
    assert coords.shape[0] == labels.shape[0] == 1
    assert labels[0] == 1

    rr, cc = draw.disk(coords[0], radius, shape=im.shape[:2])
    overlay = im.copy()
    draw.set_color(overlay, (rr, cc), [0, 255, 255], alpha=1.0)

    return overlay


def _compare_eval(
    f, eval_result, advantage_column, n_images_per_sample, prefix, sample_name,
    plot_folder, point_radius, outline_dilation, have_model3, enhance_image,
):
    """Plot the objects with the largest `advantage_column` values side by side for all models.

    Each row shows one object, cropped around the ground truth and the masks of models 1 and 2.
    The columns are model 1, (model 3,) model 2. If `plot_folder` is None the figure is shown,
    otherwise one svg per object is saved to `plot_folder`.
    """
    result = eval_result.sort_values(advantage_column, ascending=False).iloc[:n_images_per_sample]
    n_rows = result.shape[0]
    # Nothing to plot, e.g. if all objects are smaller than min_size.
    if n_rows == 0:
        return

    image = f["image"][:]
    is_box_prompt = prefix == "box"
    overlay_prompts = partial(_overlay_box, outline_dilation=outline_dilation) if is_box_prompt else\
        partial(_overlay_points, radius=point_radius)

    def make_square(bb, shape):
        # Pad the shorter sides of the bounding box towards a square crop, clipped to the image shape.
        box_shape = [b.stop - b.start for b in bb]
        longest_side = max(box_shape)
        padding = [(longest_side - sh) // 2 for sh in box_shape]
        bb = tuple(
            slice(max(b.start - pad, 0), min(b.stop + pad, sh)) for b, pad, sh in zip(bb, padding, shape)
        )
        return bb

    def plot_row(axes, row):
        # Draw the panels of one object into the 1d sequence of matplotlib axes `axes`.
        g = f[row.gt_id]

        gt = g["gt_mask"][:]
        mask1 = g[f"{prefix}/mask1"][:]
        mask2 = g[f"{prefix}/mask2"][:]

        # The mask3 is just for comparison purpose, we just plot the crops as it is.
        mask3 = g[f"{prefix}/mask3"][:] if have_model3 else None

        fg_mask = (gt + mask1 + mask2) > 0
        # if this is a box prompt we dilate the mask so that the bounding box
        # can be seen
        if is_box_prompt:
            fg_mask = binary_dilation(fg_mask, iterations=5)
        bb = np.where(fg_mask)
        bb = tuple(
            slice(int(b.min()), int(b.max() + 1)) for b in bb
        )
        bb = make_square(bb, fg_mask.shape)

        # Shift the prompt coordinates into the coordinate system of the crop.
        offset = np.array([b.start for b in bb])
        if is_box_prompt:
            prompt = g.attrs["box"]
            prompt = np.array(
                [prompt[:2], prompt[2:]]
            ) - offset
        else:
            prompt = (g.attrs["point_coords"] - offset, g.attrs["point_labels"])

        im = _enhance_image(image[bb]) if enhance_image else image[bb]
        gt = gt[bb]

        # We put the third set of comparison masks in between
        # so that the comparison looks -> default, generalist, specialist
        panel_masks = [mask1[bb]]
        if have_model3:
            panel_masks.append(mask3[bb])
        panel_masks.append(mask2[bb])

        for ax, mask in zip(axes, panel_masks):
            panel = _overlay_mask(im, mask)
            panel = _overlay_outline(panel, gt, outline_dilation)
            panel = overlay_prompts(panel, prompt)
            ax.axis("off")
            ax.imshow(panel)

    cols = 3 if have_model3 else 2
    if plot_folder is None:
        # squeeze=False keeps the axes 2d, also when only a single row is plotted.
        fig, axis = plt.subplots(n_rows, cols, squeeze=False)
        for i, (_, row) in enumerate(result.iterrows()):
            plot_row(axis[i], row)
        plt.show()
    else:
        for i, (_, row) in enumerate(result.iterrows()):
            fig, axis = plt.subplots(1, cols)
            plot_row(axis, row)
            plt.subplots_adjust(wspace=0.05, hspace=0)
            plt.savefig(os.path.join(plot_folder, f"{sample_name}_{i}.svg"), bbox_inches="tight")
            plt.close()


def _compare_prompts(
    f, prefix, n_images_per_sample, min_size, sample_name, plot_folder,
    point_radius, outline_dilation, have_model3, enhance_image,
):
    """Plot the objects where model 1, resp. model 2, has the largest advantage for one prompt type.

    If `plot_folder` is given, the plots go to its subfolders "advantage1" and "advantage2".
    """
    scores = _evaluate_samples(f, prefix, min_size)
    if plot_folder is None:
        plot_folder1, plot_folder2 = None, None
    else:
        plot_folder1 = os.path.join(plot_folder, "advantage1")
        plot_folder2 = os.path.join(plot_folder, "advantage2")
        os.makedirs(plot_folder1, exist_ok=True)
        os.makedirs(plot_folder2, exist_ok=True)
    _compare_eval(
        f, scores, "advantage1", n_images_per_sample, prefix, sample_name, plot_folder1,
        point_radius, outline_dilation, have_model3, enhance_image,
    )
    _compare_eval(
        f, scores, "advantage2", n_images_per_sample, prefix, sample_name, plot_folder2,
        point_radius, outline_dilation, have_model3, enhance_image,
    )


def _compare_models(
    path, n_images_per_sample, min_size, plot_folder, point_radius, outline_dilation, have_model3, enhance_image,
):
    """Create the comparison plots for point and box prompts of one precomputed sample file."""
    sample_name = Path(path).stem
    with h5py.File(path, "r") as f:
        if plot_folder is None:
            plot_folder_points, plot_folder_box = None, None
        else:
            plot_folder_points = os.path.join(plot_folder, "points")
            plot_folder_box = os.path.join(plot_folder, "box")
        _compare_prompts(
            f, "points", n_images_per_sample, min_size, sample_name, plot_folder_points,
            point_radius, outline_dilation, have_model3, enhance_image,
        )
        _compare_prompts(
            f, "box", n_images_per_sample, min_size, sample_name, plot_folder_box,
            point_radius, outline_dilation, have_model3, enhance_image,
        )


def model_comparison(
    output_folder: Union[str, os.PathLike],
    n_images_per_sample: int,
    min_size: int,
    plot_folder: Optional[Union[str, os.PathLike]] = None,
    point_radius: int = 4,
    outline_dilation: int = 0,
    have_model3: bool = False,
    enhance_image: bool = True,
) -> None:
    """Create images for a qualitative model comparison.

    Args:
        output_folder: The folder with the data precomputed by `generate_data_for_model_comparison`.
        n_images_per_sample: The number of images to generate per precomputed sample.
        min_size: The min size of ground-truth objects to take into account.
        plot_folder: The folder where to save the plots. If not given the plots will be displayed.
        point_radius: The radius of the point overlay.
        outline_dilation: The dilation factor of the outline overlay.
        have_model3: Whether the precomputed data contains masks of a third model.
            If True, they are shown in an additional middle column.
        enhance_image: Whether to enhance the input image.
    """
    files = glob(os.path.join(output_folder, "*.h5"))
    for path in tqdm(files):
        _compare_models(
            path, n_images_per_sample, min_size, plot_folder, point_radius,
            outline_dilation, have_model3, enhance_image,
        )


#
# Quick visual evaluation with napari
#


def _check_group(g, show_points, image=None):
    """Show the image, the ground truth, the masks of models 1 and 2 and the prompt of one object in napari.

    Args:
        g: The HDF5 group of one object, as written by `generate_data_for_model_comparison`.
        show_points: Whether to show the results for point prompts (otherwise box prompts).
        image: The image of the sample. If not given, it is read from the "image" dataset at the root
            of the file that contains `g`; the object groups do not store the image themselves.
    """
    import napari

    if image is None:
        image = g.file["image"][:]
    gt = g["gt_mask"][:]
    if show_points:
        m1 = g["points/mask1"][:]
        m2 = g["points/mask2"][:]
        points = g.attrs["point_coords"]
    else:
        m1 = g["box/mask1"][:]
        m2 = g["box/mask2"][:]
        box = g.attrs["box"]
        box = np.array([
            [box[0], box[1]], [box[2], box[3]]
        ])

    v = napari.Viewer()
    v.add_image(image)
    v.add_labels(gt)
    v.add_labels(m1)
    v.add_labels(m2)
    if show_points:
        # TODO use point labels for coloring
        v.add_points(
            points,
            edge_color="#00FF00",
            symbol="o",
            face_color="transparent",
            edge_width=0.5,
            size=12,
        )
    else:
        v.add_shapes(
            box, face_color="transparent", edge_color="green", edge_width=4,
        )
    napari.run()


def model_comparison_with_napari(output_folder: Union[str, os.PathLike], show_points: bool = True) -> None:
    """Use napari to display the qualitative comparison results for two models.

    Args:
        output_folder: The folder with the data precomputed by `generate_data_for_model_comparison`.
        show_points: Whether to show the results for point or for box prompts.
    """
    files = glob(os.path.join(output_folder, "*.h5"))
    for path in files:
        print("Comparing models in", path)
        with h5py.File(path, "r") as f:
            # The image is stored once per sample file, not inside the per-object groups.
            image = f["image"][:]
            for name, g in f.items():
                if name == "image":
                    continue
                _check_group(g, show_points=show_points, image=image)
def
generate_data_for_model_comparison( loader: torch.utils.data.dataloader.DataLoader, output_folder: Union[str, os.PathLike], model_type1: str, model_type2: str, n_samples: int, model_type3: Optional[str] = None, checkpoint1: Union[str, os.PathLike, NoneType] = None, checkpoint2: Union[str, os.PathLike, NoneType] = None, checkpoint3: Union[str, os.PathLike, NoneType] = None, peft_kwargs1: Optional[Dict[str, Any]] = None, peft_kwargs2: Optional[Dict[str, Any]] = None, peft_kwargs3: Optional[Dict[str, Any]] = None) -> None:
def generate_data_for_model_comparison(
    loader: torch.utils.data.DataLoader,
    output_folder: Union[str, os.PathLike],
    model_type1: str,
    model_type2: str,
    n_samples: int,
    model_type3: Optional[str] = None,
    checkpoint1: Optional[Union[str, os.PathLike]] = None,
    checkpoint2: Optional[Union[str, os.PathLike]] = None,
    checkpoint3: Optional[Union[str, os.PathLike]] = None,
    peft_kwargs1: Optional[Dict[str, Any]] = None,
    peft_kwargs2: Optional[Dict[str, Any]] = None,
    peft_kwargs3: Optional[Dict[str, Any]] = None,
) -> None:
    """Precompute the data used by `model_comparison` and `model_comparison_with_napari`.

    Draws up to `n_samples` samples from `loader`, segments every ground-truth object with
    box and point prompts using two (or three) models and stores the results in `output_folder`.

    Args:
        loader: The torch dataloader from which samples are drawn.
        output_folder: The folder where the samples will be saved.
        model_type1: The first model to use for comparison.
            The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.
        model_type2: The second model to use for comparison.
            The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.
        n_samples: The number of samples to draw from the dataloader.
        model_type3: Optional third model to use for comparison. No third model is loaded if this is None.
        checkpoint1: Optional checkpoint for the first model.
        checkpoint2: Optional checkpoint for the second model.
        checkpoint3: Optional checkpoint for the third model.
        peft_kwargs1: Optional `peft_kwargs` passed to `micro_sam.util.get_sam_model` for the first model.
        peft_kwargs2: Optional `peft_kwargs` passed to `micro_sam.util.get_sam_model` for the second model.
        peft_kwargs3: Optional `peft_kwargs` passed to `micro_sam.util.get_sam_model` for the third model.
    """
    # Per object: a single positive point, no negative points, and a box prompt.
    generator = PointAndBoxPromptGenerator(
        get_point_prompts=True,
        get_box_prompts=True,
        n_positive_points=1,
        n_negative_points=0,
        dilation_strength=3,
    )

    def _load(model_type, checkpoint, extra_kwargs):
        return util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=extra_kwargs)

    first = _load(model_type1, checkpoint1, peft_kwargs1)
    second = _load(model_type2, checkpoint2, peft_kwargs2)
    third = None if model_type3 is None else _load(model_type3, checkpoint3, peft_kwargs3)

    _predict_models_with_loader(loader, n_samples, generator, first, second, third, output_folder)
Generate samples for qualitative model comparison.
This precomputes the input for model_comparison and model_comparison_with_napari.
Arguments:
- loader: The torch dataloader from which samples are drawn.
- output_folder: The folder where the samples will be saved.
- model_type1: The first model to use for comparison.
The value needs to be a valid model_type for micro_sam.util.get_sam_model.
- model_type2: The second model to use for comparison.
The value needs to be a valid model_type for micro_sam.util.get_sam_model.
- n_samples: The number of samples to draw from the dataloader.
- checkpoint1: Optional checkpoint for the first model.
- checkpoint2: Optional checkpoint for the second model.
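- checkpoint3: Optional checkpoint for the third model.
- model_type3: Optional third model to use for comparison. If given, it needs to be a valid model_type for micro_sam.util.get_sam_model.
- peft_kwargs1, peft_kwargs2, peft_kwargs3: Optional peft_kwargs passed to micro_sam.util.get_sam_model for the first, second and third model.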
def
model_comparison( output_folder: Union[str, os.PathLike], n_images_per_sample: int, min_size: int, plot_folder: Union[str, os.PathLike, NoneType] = None, point_radius: int = 4, outline_dilation: int = 0, have_model3=False, enhance_image=True) -> None:
def model_comparison(
    output_folder: Union[str, os.PathLike],
    n_images_per_sample: int,
    min_size: int,
    plot_folder: Optional[Union[str, os.PathLike]] = None,
    point_radius: int = 4,
    outline_dilation: int = 0,
    have_model3=False,
    enhance_image=True,
) -> None:
    """Plot a qualitative comparison for every precomputed sample file in `output_folder`.

    Args:
        output_folder: The folder with the data precomputed by `generate_data_for_model_comparison`.
        n_images_per_sample: The number of images to generate per precomputed sample.
        min_size: The min size of ground-truth objects to take into account.
        plot_folder: The folder where to save the plots. If not given the plots will be displayed.
        point_radius: The radius of the point overlay.
        outline_dilation: The dilation factor of the outline overlay.
        have_model3: Whether the precomputed data also contains masks of a third model to plot.
        enhance_image: Whether to enhance the input image.
    """
    sample_pattern = os.path.join(output_folder, "*.h5")
    for sample_path in tqdm(glob(sample_pattern)):
        _compare_models(
            sample_path,
            n_images_per_sample,
            min_size,
            plot_folder,
            point_radius,
            outline_dilation,
            have_model3,
            enhance_image,
        )
Create images for a qualitative model comparison.
Arguments:
- output_folder: The folder with the data precomputed by generate_data_for_model_comparison.
- n_images_per_sample: The number of images to generate per precomputed sample.
- min_size: The min size of ground-truth objects to take into account.
- plot_folder: The folder where to save the plots. If not given the plots will be displayed.
- point_radius: The radius of the point overlay.
- outline_dilation: The dilation factor of the outline overlay.
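- have_model3: Whether the precomputed data contains masks of a third model, which are then shown in an additional middle column.
- enhance_image: Whether to enhance the input image.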
def
model_comparison_with_napari(output_folder: Union[str, os.PathLike], show_points: bool = True) -> None:
def model_comparison_with_napari(output_folder: Union[str, os.PathLike], show_points: bool = True) -> None:
    """Open napari for every object of every precomputed sample to inspect the model comparison.

    Args:
        output_folder: The folder with the data precomputed by `generate_data_for_model_comparison`.
        show_points: Whether to show the results for point prompts (True) or for box prompts (False).
    """
    for sample_path in glob(os.path.join(output_folder, "*.h5")):
        print("Comparing models in", sample_path)
        with h5py.File(sample_path, "r") as sample_file:
            # Every entry except the "image" dataset is the group of one ground-truth object.
            for key, group in sample_file.items():
                if key != "image":
                    _check_group(group, show_points=show_points)
Use napari to display the qualitative comparison results for two models.
Arguments:
- output_folder: The folder with the data precomputed by generate_data_for_model_comparison.
- show_points: Whether to show the results for point or for box prompts.