micro_sam.evaluation.model_comparison
Functionality for qualitative comparison of Segment Anything models on microscopy data.
1"""Functionality for qualitative comparison of Segment Anything models on microscopy data. 2""" 3 4import os 5from glob import glob 6from tqdm import tqdm 7from pathlib import Path 8from functools import partial 9from typing import Optional, Union, Dict, Any 10 11import h5py 12import numpy as np 13import pandas as pd 14import matplotlib.pyplot as plt 15import skimage.draw as draw 16from skimage import exposure 17from scipy.ndimage import binary_dilation 18from skimage.segmentation import relabel_sequential, find_boundaries 19 20import torch 21 22from .. import util 23from ..prompt_generators import PointAndBoxPromptGenerator 24from ..prompt_based_segmentation import segment_from_box, segment_from_points 25 26 27# 28# Compute all required data for the model comparison 29# 30 31 32def _predict_models_with_loader(loader, n_samples, prompt_generator, predictor1, predictor2, predictor3, output_folder): 33 i = 0 34 os.makedirs(output_folder, exist_ok=True) 35 36 for x, y in tqdm(loader, total=n_samples): 37 out_path = os.path.join(output_folder, f"sample_{i}.h5") 38 39 im = x.numpy().squeeze() 40 if im.ndim == 3 and im.shape[0] == 3: 41 im = im.transpose((1, 2, 0)) 42 43 gt = y.numpy().squeeze().astype("uint32") 44 gt = relabel_sequential(gt)[0] 45 46 emb1 = util.precompute_image_embeddings(predictor1, im, ndim=2) 47 util.set_precomputed(predictor1, emb1) 48 49 emb2 = util.precompute_image_embeddings(predictor2, im, ndim=2) 50 util.set_precomputed(predictor2, emb2) 51 52 if predictor3 is not None: 53 emb3 = util.precompute_image_embeddings(predictor3, im, ndim=2) 54 util.set_precomputed(predictor3, emb3) 55 56 with h5py.File(out_path, "a") as f: 57 f.create_dataset("image", data=im, compression="gzip") 58 59 gt_ids = np.unique(gt)[1:] 60 centers, boxes = util.get_centers_and_bounding_boxes(gt) 61 centers = [centers[gt_id] for gt_id in gt_ids] 62 boxes = [boxes[gt_id] for gt_id in gt_ids] 63 64 object_masks = util.segmentation_to_one_hot(gt.astype("int64"), gt_ids) 65 
coords, labels, boxes, _ = prompt_generator( 66 segmentation=object_masks, 67 bbox_coordinates=boxes, 68 center_coordinates=centers, 69 ) 70 71 for idx, gt_id in tqdm(enumerate(gt_ids), total=len(gt_ids)): 72 73 # Box prompts: 74 # Reorder the coordinates so that they match the normal python convention. 75 box = boxes[idx][[1, 0, 3, 2]] 76 mask1_box = segment_from_box(predictor1, box) 77 mask2_box = segment_from_box(predictor2, box) 78 mask1_box, mask2_box = mask1_box.squeeze(), mask2_box.squeeze() 79 80 if predictor3 is not None: 81 mask3_box = segment_from_box(predictor3, box) 82 mask3_box = mask3_box.squeeze() 83 84 # Point prompts: 85 # Reorder the coordinates so that they match the normal python convention. 86 point_coords, point_labels = np.array(coords[idx])[:, ::-1], np.array(labels[idx]) 87 mask1_points = segment_from_points(predictor1, point_coords, point_labels) 88 mask2_points = segment_from_points(predictor2, point_coords, point_labels) 89 mask1_points, mask2_points = mask1_points.squeeze(), mask2_points.squeeze() 90 91 if predictor3 is not None: 92 mask3_points = segment_from_points(predictor3, point_coords, point_labels) 93 mask3_points = mask3_points.squeeze() 94 95 gt_mask = gt == gt_id 96 with h5py.File(out_path, "a") as f: 97 g = f.create_group(str(gt_id)) 98 g.attrs["point_coords"] = point_coords 99 g.attrs["point_labels"] = point_labels 100 g.attrs["box"] = box 101 102 g.create_dataset("gt_mask", data=gt_mask, compression="gzip") 103 g.create_dataset("box/mask1", data=mask1_box.astype("uint8"), compression="gzip") 104 g.create_dataset("box/mask2", data=mask2_box.astype("uint8"), compression="gzip") 105 g.create_dataset("points/mask1", data=mask1_points.astype("uint8"), compression="gzip") 106 g.create_dataset("points/mask2", data=mask2_points.astype("uint8"), compression="gzip") 107 108 if predictor3 is not None: 109 g.create_dataset("box/mask3", data=mask3_box.astype("uint8"), compression="gzip") 110 g.create_dataset("points/mask3", 
data=mask3_points.astype("uint8"), compression="gzip") 111 112 i += 1 113 if i >= n_samples: 114 return 115 116 117def generate_data_for_model_comparison( 118 loader: torch.utils.data.DataLoader, 119 output_folder: Union[str, os.PathLike], 120 model_type1: str, 121 model_type2: str, 122 n_samples: int, 123 model_type3: Optional[str] = None, 124 checkpoint1: Optional[Union[str, os.PathLike]] = None, 125 checkpoint2: Optional[Union[str, os.PathLike]] = None, 126 checkpoint3: Optional[Union[str, os.PathLike]] = None, 127 peft_kwargs1: Optional[Dict[str, Any]] = None, 128 peft_kwargs2: Optional[Dict[str, Any]] = None, 129 peft_kwargs3: Optional[Dict[str, Any]] = None, 130) -> None: 131 """Generate samples for qualitative model comparison. 132 133 This precomputes the input for `model_comparison` and `model_comparison_with_napari`. 134 135 Args: 136 loader: The torch dataloader from which samples are drawn. 137 output_folder: The folder where the samples will be saved. 138 model_type1: The first model to use for comparison. 139 The value needs to be a valid model_type for `micro_sam.util.get_sam_model`. 140 model_type2: The second model to use for comparison. 141 The value needs to be a valid model_type for `micro_sam.util.get_sam_model`. 142 n_samples: The number of samples to draw from the dataloader. 143 checkpoint1: Optional checkpoint for the first model. 144 checkpoint2: Optional checkpoint for the second model. 145 checkpoint3: Optional checkpoint for the third model. 
146 """ 147 prompt_generator = PointAndBoxPromptGenerator( 148 n_positive_points=1, 149 n_negative_points=0, 150 dilation_strength=3, 151 get_point_prompts=True, 152 get_box_prompts=True, 153 ) 154 155 predictor1 = util.get_sam_model(model_type=model_type1, checkpoint_path=checkpoint1, peft_kwargs=peft_kwargs1) 156 predictor2 = util.get_sam_model(model_type=model_type2, checkpoint_path=checkpoint2, peft_kwargs=peft_kwargs2) 157 158 if model_type3 is not None: 159 predictor3 = util.get_sam_model(model_type=model_type3, checkpoint_path=checkpoint3, peft_kwargs=peft_kwargs3) 160 else: 161 predictor3 = None 162 163 _predict_models_with_loader(loader, n_samples, prompt_generator, predictor1, predictor2, predictor3, output_folder) 164 165 166# 167# Visual evaluation according to metrics 168# 169 170 171def _evaluate_samples(f, prefix, min_size): 172 eval_result = { 173 "gt_id": [], 174 "score1": [], 175 "score2": [], 176 } 177 for name, group in f.items(): 178 if name == "image": 179 continue 180 181 gt_mask = group["gt_mask"][:] 182 183 size = gt_mask.sum() 184 if size < min_size: 185 continue 186 187 m1 = group[f"{prefix}/mask1"][:] 188 m2 = group[f"{prefix}/mask2"][:] 189 190 score1 = util.compute_iou(gt_mask, m1) 191 score2 = util.compute_iou(gt_mask, m2) 192 193 eval_result["gt_id"].append(name) 194 eval_result["score1"].append(score1) 195 eval_result["score2"].append(score2) 196 197 eval_result = pd.DataFrame.from_dict(eval_result) 198 eval_result["advantage1"] = eval_result["score1"] - eval_result["score2"] 199 eval_result["advantage2"] = eval_result["score2"] - eval_result["score1"] 200 return eval_result 201 202 203def _overlay_mask(image, mask, alpha=0.6): 204 assert image.ndim in (2, 3) 205 # overlay the mask 206 if image.ndim == 2: 207 overlay = np.stack([image, image, image]).transpose((1, 2, 0)) 208 else: 209 overlay = image 210 assert overlay.shape[-1] == 3 211 mask_overlay = np.zeros_like(overlay) 212 mask_overlay[mask == 1] = [255, 0, 0] 213 alpha = 
alpha 214 overlay = alpha * overlay + (1.0 - alpha) * mask_overlay 215 return overlay.astype("uint8") 216 217 218def _enhance_image(im, do_norm=True): 219 # apply CLAHE to improve the image quality 220 if do_norm: 221 im -= im.min(axis=(0, 1), keepdims=True) 222 im /= (im.max(axis=(0, 1), keepdims=True) + 1e-6) 223 im = exposure.equalize_adapthist(im) 224 im *= 255 225 return im 226 227 228def _overlay_outline(im, mask, outline_dilation): 229 outline = find_boundaries(mask) 230 if outline_dilation > 0: 231 outline = binary_dilation(outline, iterations=outline_dilation) 232 overlay = im.copy() 233 overlay[outline] = [255, 255, 0] 234 return overlay 235 236 237def _overlay_box(im, prompt, outline_dilation): 238 start, end = prompt 239 rr, cc = draw.rectangle_perimeter(start, end=end, shape=im.shape[:2]) 240 241 box_outline = np.zeros(im.shape[:2], dtype="bool") 242 box_outline[rr, cc] = 1 243 if outline_dilation > 0: 244 box_outline = binary_dilation(box_outline, iterations=outline_dilation) 245 246 overlay = im.copy() 247 overlay[box_outline] = [0, 255, 255] 248 249 return overlay 250 251 252# NOTE: we currently only support a single point 253def _overlay_points(im, prompt, radius): 254 coords, labels = prompt 255 # make sure we have a single positive prompt, other options are 256 # currently not supported 257 assert coords.shape[0] == labels.shape[0] == 1 258 assert labels[0] == 1 259 260 rr, cc = draw.disk(coords[0], radius, shape=im.shape[:2]) 261 overlay = im.copy() 262 draw.set_color(overlay, (rr, cc), [0, 255, 255], alpha=1.0) 263 264 return overlay 265 266 267def _compare_eval( 268 f, eval_result, advantage_column, n_images_per_sample, prefix, sample_name, 269 plot_folder, point_radius, outline_dilation, have_model3, enhance_image, 270): 271 result = eval_result.sort_values(advantage_column, ascending=False).iloc[:n_images_per_sample] 272 n_rows = result.shape[0] 273 274 image = f["image"][:] 275 is_box_prompt = prefix == "box" 276 overlay_prompts = 
partial(_overlay_box, outline_dilation=outline_dilation) if is_box_prompt else\ 277 partial(_overlay_points, radius=point_radius) 278 279 def make_square(bb, shape): 280 box_shape = [b.stop - b.start for b in bb] 281 longest_side = max(box_shape) 282 padding = [(longest_side - sh) // 2 for sh in box_shape] 283 bb = tuple( 284 slice(max(b.start - pad, 0), min(b.stop + pad, sh)) for b, pad, sh in zip(bb, padding, shape) 285 ) 286 return bb 287 288 def plot_ax(axis, i, row): 289 g = f[row.gt_id] 290 291 gt = g["gt_mask"][:] 292 mask1 = g[f"{prefix}/mask1"][:] 293 mask2 = g[f"{prefix}/mask2"][:] 294 295 # The mask3 is just for comparison purpose, we just plot the crops as it is. 296 if have_model3: 297 mask3 = g[f"{prefix}/mask3"][:] 298 299 fg_mask = (gt + mask1 + mask2) > 0 300 # if this is a box prompt we dilate the mask so that the bounding box 301 # can be seen 302 if is_box_prompt: 303 fg_mask = binary_dilation(fg_mask, iterations=5) 304 bb = np.where(fg_mask) 305 bb = tuple( 306 slice(int(b.min()), int(b.max() + 1)) for b in bb 307 ) 308 bb = make_square(bb, fg_mask.shape) 309 310 offset = np.array([b.start for b in bb]) 311 if is_box_prompt: 312 prompt = g.attrs["box"] 313 prompt = np.array( 314 [prompt[:2], prompt[2:]] 315 ) - offset 316 else: 317 prompt = (g.attrs["point_coords"] - offset, g.attrs["point_labels"]) 318 319 if enhance_image: 320 im = _enhance_image(image[bb]) 321 else: 322 im = image[bb] 323 324 gt, mask1, mask2 = gt[bb], mask1[bb], mask2[bb] 325 326 if have_model3: 327 mask3 = mask3[bb] 328 329 im1 = _overlay_mask(im, mask1) 330 im1 = _overlay_outline(im1, gt, outline_dilation) 331 im1 = overlay_prompts(im1, prompt) 332 ax = axis[0] if i is None else axis[i, 0] 333 ax.axis("off") 334 ax.imshow(im1) 335 336 # We put the third set of comparsion point in between 337 # so that the comparison looks -> default, generalist, specialist 338 if have_model3: 339 im3 = _overlay_mask(im, mask3) 340 im3 = _overlay_outline(im3, gt, outline_dilation) 341 im3 
= overlay_prompts(im3, prompt) 342 ax = axis[1] if i is None else axis[i, 1] 343 ax.axis("off") 344 ax.imshow(im3) 345 346 nexax = 2 347 else: 348 nexax = 1 349 350 im2 = _overlay_mask(im, mask2) 351 im2 = _overlay_outline(im2, gt, outline_dilation) 352 im2 = overlay_prompts(im2, prompt) 353 ax = axis[nexax] if i is None else axis[i, nexax] 354 ax.axis("off") 355 ax.imshow(im2) 356 357 cols = 3 if have_model3 else 2 358 if plot_folder is None: 359 fig, axis = plt.subplots(n_rows, cols) 360 for i, (_, row) in enumerate(result.iterrows()): 361 plot_ax(axis, i, row) 362 plt.show() 363 else: 364 for i, (_, row) in enumerate(result.iterrows()): 365 fig, axis = plt.subplots(1, cols) 366 plot_ax(axis, None, row) 367 plt.subplots_adjust(wspace=0.05, hspace=0) 368 plt.savefig(os.path.join(plot_folder, f"{sample_name}_{i}.svg"), bbox_inches="tight") 369 plt.close() 370 371 372def _compare_prompts( 373 f, prefix, n_images_per_sample, min_size, sample_name, plot_folder, 374 point_radius, outline_dilation, have_model3, enhance_image, 375): 376 box_eval = _evaluate_samples(f, prefix, min_size) 377 if plot_folder is None: 378 plot_folder1, plot_folder2 = None, None 379 else: 380 plot_folder1 = os.path.join(plot_folder, "advantage1") 381 plot_folder2 = os.path.join(plot_folder, "advantage2") 382 os.makedirs(plot_folder1, exist_ok=True) 383 os.makedirs(plot_folder2, exist_ok=True) 384 _compare_eval( 385 f, box_eval, "advantage1", n_images_per_sample, prefix, sample_name, plot_folder1, 386 point_radius, outline_dilation, have_model3, enhance_image, 387 ) 388 _compare_eval( 389 f, box_eval, "advantage2", n_images_per_sample, prefix, sample_name, plot_folder2, 390 point_radius, outline_dilation, have_model3, enhance_image, 391 ) 392 393 394def _compare_models( 395 path, n_images_per_sample, min_size, plot_folder, point_radius, outline_dilation, have_model3, enhance_image, 396): 397 sample_name = Path(path).stem 398 with h5py.File(path, "r") as f: 399 if plot_folder is None: 400 
plot_folder_points, plot_folder_box = None, None 401 else: 402 plot_folder_points = os.path.join(plot_folder, "points") 403 plot_folder_box = os.path.join(plot_folder, "box") 404 _compare_prompts( 405 f, "points", n_images_per_sample, min_size, sample_name, plot_folder_points, 406 point_radius, outline_dilation, have_model3, enhance_image, 407 ) 408 _compare_prompts( 409 f, "box", n_images_per_sample, min_size, sample_name, plot_folder_box, 410 point_radius, outline_dilation, have_model3, enhance_image, 411 ) 412 413 414def model_comparison( 415 output_folder: Union[str, os.PathLike], 416 n_images_per_sample: int, 417 min_size: int, 418 plot_folder: Optional[Union[str, os.PathLike]] = None, 419 point_radius: int = 4, 420 outline_dilation: int = 0, 421 have_model3=False, 422 enhance_image=True, 423) -> None: 424 """Create images for a qualitative model comparision. 425 426 Args: 427 output_folder: The folder with the data precomputed by `generate_data_for_model_comparison`. 428 n_images_per_sample: The number of images to generate per precomputed sample. 429 min_size: The min size of ground-truth objects to take into account. 430 plot_folder: The folder where to save the plots. If not given the plots will be displayed. 431 point_radius: The radius of the point overlay. 432 outline_dilation: The dilation factor of the outline overlay. 433 enhance_image: Whether to enhance the input image. 
434 """ 435 files = glob(os.path.join(output_folder, "*.h5")) 436 for path in tqdm(files): 437 _compare_models( 438 path, n_images_per_sample, min_size, plot_folder, point_radius, 439 outline_dilation, have_model3, enhance_image, 440 ) 441 442 443# 444# Quick visual evaluation with napari 445# 446 447 448def _check_group(g, show_points): 449 import napari 450 451 image = g["image"][:] 452 gt = g["gt_mask"][:] 453 if show_points: 454 m1 = g["points/mask1"][:] 455 m2 = g["points/mask2"][:] 456 points = g.attrs["point_coords"] 457 else: 458 m1 = g["box/mask1"][:] 459 m2 = g["box/mask2"][:] 460 box = g.attrs["box"] 461 box = np.array([ 462 [box[0], box[1]], [box[2], box[3]] 463 ]) 464 465 v = napari.Viewer() 466 v.add_image(image) 467 v.add_labels(gt) 468 v.add_labels(m1) 469 v.add_labels(m2) 470 if show_points: 471 # TODO use point labels for coloring 472 v.add_points( 473 points, 474 edge_color="#00FF00", 475 symbol="o", 476 face_color="transparent", 477 edge_width=0.5, 478 size=12, 479 ) 480 else: 481 v.add_shapes( 482 box, face_color="transparent", edge_color="green", edge_width=4, 483 ) 484 napari.run() 485 486 487def model_comparison_with_napari(output_folder: Union[str, os.PathLike], show_points: bool = True) -> None: 488 """Use napari to display the qualtiative comparison results for two models. 489 490 Args: 491 output_folder: The folder with the data precomputed by `generate_data_for_model_comparison`. 492 show_points: Whether to show the results for point or for box prompts. 493 """ 494 files = glob(os.path.join(output_folder, "*.h5")) 495 for path in files: 496 print("Comparing models in", path) 497 with h5py.File(path, "r") as f: 498 for name, g in f.items(): 499 if name == "image": 500 continue 501 _check_group(g, show_points=show_points)
def
generate_data_for_model_comparison( loader: torch.utils.data.dataloader.DataLoader, output_folder: Union[str, os.PathLike], model_type1: str, model_type2: str, n_samples: int, model_type3: Optional[str] = None, checkpoint1: Union[str, os.PathLike, NoneType] = None, checkpoint2: Union[str, os.PathLike, NoneType] = None, checkpoint3: Union[str, os.PathLike, NoneType] = None, peft_kwargs1: Optional[Dict[str, Any]] = None, peft_kwargs2: Optional[Dict[str, Any]] = None, peft_kwargs3: Optional[Dict[str, Any]] = None) -> None:
def generate_data_for_model_comparison(
    loader: torch.utils.data.DataLoader,
    output_folder: Union[str, os.PathLike],
    model_type1: str,
    model_type2: str,
    n_samples: int,
    model_type3: Optional[str] = None,
    checkpoint1: Optional[Union[str, os.PathLike]] = None,
    checkpoint2: Optional[Union[str, os.PathLike]] = None,
    checkpoint3: Optional[Union[str, os.PathLike]] = None,
    peft_kwargs1: Optional[Dict[str, Any]] = None,
    peft_kwargs2: Optional[Dict[str, Any]] = None,
    peft_kwargs3: Optional[Dict[str, Any]] = None,
) -> None:
    """Generate samples for qualitative model comparison.

    This precomputes the input for `model_comparison` and `model_comparison_with_napari`.

    Args:
        loader: The torch dataloader from which samples are drawn.
        output_folder: The folder where the samples will be saved.
        model_type1: The first model to use for comparison.
            The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.
        model_type2: The second model to use for comparison.
            The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.
        n_samples: The number of samples to draw from the dataloader.
        model_type3: Optional third model to use for comparison.
            The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.
        checkpoint1: Optional checkpoint for the first model.
        checkpoint2: Optional checkpoint for the second model.
        checkpoint3: Optional checkpoint for the third model.
        peft_kwargs1: Optional PEFT keyword arguments for the first model, passed to `micro_sam.util.get_sam_model`.
        peft_kwargs2: Optional PEFT keyword arguments for the second model, passed to `micro_sam.util.get_sam_model`.
        peft_kwargs3: Optional PEFT keyword arguments for the third model, passed to `micro_sam.util.get_sam_model`.
    """
    # A single positive point and the bounding box are derived for each ground-truth object.
    prompt_generator = PointAndBoxPromptGenerator(
        n_positive_points=1,
        n_negative_points=0,
        dilation_strength=3,
        get_point_prompts=True,
        get_box_prompts=True,
    )

    predictor1 = util.get_sam_model(model_type=model_type1, checkpoint_path=checkpoint1, peft_kwargs=peft_kwargs1)
    predictor2 = util.get_sam_model(model_type=model_type2, checkpoint_path=checkpoint2, peft_kwargs=peft_kwargs2)

    if model_type3 is not None:
        predictor3 = util.get_sam_model(model_type=model_type3, checkpoint_path=checkpoint3, peft_kwargs=peft_kwargs3)
    else:
        predictor3 = None

    _predict_models_with_loader(loader, n_samples, prompt_generator, predictor1, predictor2, predictor3, output_folder)
Generate samples for qualitative model comparison.
This precomputes the input for `model_comparison` and `model_comparison_with_napari`.
Arguments:
- loader: The torch dataloader from which samples are drawn.
- output_folder: The folder where the samples will be saved.
- model_type1: The first model to use for comparison.
The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.
- model_type2: The second model to use for comparison.
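The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.
- n_samples: The number of samples to draw from the dataloader.
- model_type3: Optional third model to use for comparison.
The value needs to be a valid model_type for `micro_sam.util.get_sam_model`.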
- checkpoint1: Optional checkpoint for the first model.
- checkpoint2: Optional checkpoint for the second model.
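- checkpoint3: Optional checkpoint for the third model.
- peft_kwargs1: Optional PEFT keyword arguments for the first model, passed to `micro_sam.util.get_sam_model`.
- peft_kwargs2: Optional PEFT keyword arguments for the second model, passed to `micro_sam.util.get_sam_model`.
- peft_kwargs3: Optional PEFT keyword arguments for the third model, passed to `micro_sam.util.get_sam_model`.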
def
model_comparison( output_folder: Union[str, os.PathLike], n_images_per_sample: int, min_size: int, plot_folder: Union[str, os.PathLike, NoneType] = None, point_radius: int = 4, outline_dilation: int = 0, have_model3=False, enhance_image=True) -> None:
def model_comparison(
    output_folder: Union[str, os.PathLike],
    n_images_per_sample: int,
    min_size: int,
    plot_folder: Optional[Union[str, os.PathLike]] = None,
    point_radius: int = 4,
    outline_dilation: int = 0,
    have_model3: bool = False,
    enhance_image: bool = True,
) -> None:
    """Create images for a qualitative model comparison.

    Args:
        output_folder: The folder with the data precomputed by `generate_data_for_model_comparison`.
        n_images_per_sample: The number of images to generate per precomputed sample.
        min_size: The min size of ground-truth objects to take into account.
        plot_folder: The folder where to save the plots. If not given the plots will be displayed.
        point_radius: The radius of the point overlay.
        outline_dilation: The dilation factor of the outline overlay.
        have_model3: Whether the precomputed data contains the results of a third model.
        enhance_image: Whether to enhance the input image.
    """
    # Sort for a deterministic processing order across file systems.
    files = sorted(glob(os.path.join(output_folder, "*.h5")))
    for path in tqdm(files):
        _compare_models(
            path, n_images_per_sample, min_size, plot_folder, point_radius,
            outline_dilation, have_model3, enhance_image,
        )
Create images for a qualitative model comparison.
Arguments:
- output_folder: The folder with the data precomputed by `generate_data_for_model_comparison`.
- n_images_per_sample: The number of images to generate per precomputed sample.
- min_size: The min size of ground-truth objects to take into account.
- plot_folder: The folder where to save the plots. If not given the plots will be displayed.
- point_radius: The radius of the point overlay.
- outline_dilation: The dilation factor of the outline overlay.
- have_model3: Whether the precomputed data contains the results of a third model.
- enhance_image: Whether to enhance the input image.
def
model_comparison_with_napari(output_folder: Union[str, os.PathLike], show_points: bool = True) -> None:
def model_comparison_with_napari(output_folder: Union[str, os.PathLike], show_points: bool = True) -> None:
    """Use napari to display the qualitative comparison results for two models.

    Args:
        output_folder: The folder with the data precomputed by `generate_data_for_model_comparison`.
        show_points: Whether to show the results for point or for box prompts.
    """
    # Sort for a deterministic order in which the samples are shown.
    files = sorted(glob(os.path.join(output_folder, "*.h5")))
    for path in files:
        print("Comparing models in", path)
        with h5py.File(path, "r") as f:
            for name, g in f.items():
                # The root "image" dataset is shared by all objects; every other entry is an object group.
                if name == "image":
                    continue
                _check_group(g, show_points=show_points)
Use napari to display the qualitative comparison results for two models.
Arguments:
- output_folder: The folder with the data precomputed by `generate_data_for_model_comparison`.
- show_points: Whether to show the results for point or for box prompts.