micro_sam.evaluation.inference
Inference with Segment Anything models and different prompt strategies.
"""Inference with Segment Anything models and different prompt strategies.
"""

import os
import pickle
import numpy as np
from tqdm import tqdm
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union, Tuple

import imageio.v3 as imageio
from skimage.segmentation import relabel_sequential

import torch

from segment_anything import SamPredictor

from .. import util as util
from ..inference import batched_inference
from ..instance_segmentation import (
    mask_data_to_segmentation, get_predictor_and_decoder,
    AutomaticMaskGenerator, InstanceSegmentationWithDecoder,
    TiledAutomaticMaskGenerator, TiledInstanceSegmentationWithDecoder,
)
from . import instance_segmentation
from ..prompt_generators import PointAndBoxPromptGenerator, IterativePromptGenerator


def _get_object_ids(gt):
    """Return the sorted non-zero label ids of a label image.

    The background id 0 is filtered out explicitly instead of dropping the first
    unique value, so that no object is lost when the image contains no background.
    """
    ids = np.unique(gt)
    return ids[ids != 0]


def _load_prompts(
    cached_point_prompts, save_point_prompts, cached_box_prompts, save_box_prompts, image_name
):
    """Look up the cached prompts for `image_name`.

    Returns a tuple `(prompts, cached_point_prompts, cached_box_prompts)`. `prompts` is None if
    nothing is cached, otherwise `(points, point_labels, boxes)`. The caches are returned as well,
    because a cache that was passed as a file path is replaced by its loaded content.
    """

    def load_prompt_type(cached_prompts, save_prompts):
        # Nothing to load: either caching is disabled or the prompts are computed (and saved) in this run.
        if cached_prompts is None or save_prompts:
            return cached_prompts, None

        # The cache is still a file path, i.e. it has not been loaded yet.
        # NOTE: pickle must only be used with trusted prompt files.
        if isinstance(cached_prompts, str):
            with open(cached_prompts, "rb") as f:
                cached_prompts = pickle.load(f)

        prompts = cached_prompts[image_name]
        return cached_prompts, prompts

    cached_point_prompts, point_prompts = load_prompt_type(cached_point_prompts, save_point_prompts)
    cached_box_prompts, box_prompts = load_prompt_type(cached_box_prompts, save_box_prompts)

    # We don't have anything cached.
    if point_prompts is None and box_prompts is None:
        return None, cached_point_prompts, cached_box_prompts

    if point_prompts is None:
        input_point, input_label = [], []
    else:
        input_point, input_label = point_prompts

    input_box = [] if box_prompts is None else box_prompts

    prompts = (input_point, input_label, input_box)
    return prompts, cached_point_prompts, cached_box_prompts


def _get_batched_prompts(gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation):
    """Derive point and / or box prompts for the objects `gt_ids` from the ground-truth `gt`.

    Returns `(points, point_labels, boxes)`; entries the prompt generator returns as None stay None,
    all others are converted to numpy arrays.
    """
    # Initialize the prompt generator.
    prompt_generator = PointAndBoxPromptGenerator(
        n_positive_points=n_positives, n_negative_points=n_negatives,
        dilation_strength=dilation, get_point_prompts=use_points,
        get_box_prompts=use_boxes
    )

    # Generate the prompts.
    center_coordinates, bbox_coordinates = util.get_centers_and_bounding_boxes(gt)
    center_coordinates = [center_coordinates[gt_id] for gt_id in gt_ids]
    bbox_coordinates = [bbox_coordinates[gt_id] for gt_id in gt_ids]
    masks = util.segmentation_to_one_hot(gt.astype("int64"), gt_ids)

    points, point_labels, boxes, _ = prompt_generator(
        masks, bbox_coordinates, center_coordinates
    )

    def to_numpy(x):
        if x is None:
            return x
        return x.numpy()

    return to_numpy(points), to_numpy(point_labels), to_numpy(boxes)


def _run_inference_with_prompts_for_image(
    predictor,
    image,
    gt,
    use_points,
    use_boxes,
    n_positives,
    n_negatives,
    dilation,
    batch_size,
    cached_prompts,
    embedding_path,
):
    """Segment a single image with prompts that are cached or derived from `gt`.

    Returns the instance segmentation and a copy of the prompts that were used.
    """
    if cached_prompts is None:
        gt_ids = _get_object_ids(gt)
        points, point_labels, boxes = _get_batched_prompts(
            gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation
        )
    else:
        points, point_labels, boxes = cached_prompts

    # Make a copy of the prompts to return them at the end.
    prompts = deepcopy((points, point_labels, boxes))

    # Use multi-masking only if we have a single positive point without box.
    multimasking = (not use_boxes) and n_positives == 1 and n_negatives == 0

    instance_labels = batched_inference(
        predictor, image, batch_size,
        boxes=boxes, points=points, point_labels=point_labels,
        multimasking=multimasking, embedding_path=embedding_path,
        return_instance_segmentation=True,
    )

    return instance_labels, prompts


def precompute_all_embeddings(
    predictor: SamPredictor, image_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike],
) -> None:
    """Precompute all image embeddings.

    To enable running different inference tasks in parallel afterwards.

    Args:
        predictor: The SegmentAnything predictor.
        image_paths: The image file paths.
        embedding_dir: The directory where the embeddings will be saved.
    """
    for image_path in tqdm(image_paths, desc="Precompute embeddings"):
        image_name = os.path.basename(image_path)
        im = imageio.imread(image_path)
        embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
        util.precompute_image_embeddings(predictor, im, embedding_path, ndim=2)


def _precompute_prompts(gt_path, use_points, use_boxes, n_positives, n_negatives, dilation):
    """Compute the prompts for one ground-truth file.

    Returns `(file_name, boxes)` for box-only settings and `(file_name, (points, point_labels))` otherwise.
    """
    name = os.path.basename(gt_path)

    gt = imageio.imread(gt_path).astype("uint32")
    gt = relabel_sequential(gt)[0]
    gt_ids = _get_object_ids(gt)

    input_point, input_label, input_box = _get_batched_prompts(
        gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation
    )

    if use_boxes and not use_points:
        return name, input_box
    return name, (input_point, input_label)


def precompute_all_prompts(
    gt_paths: List[Union[str, os.PathLike]],
    prompt_save_dir: Union[str, os.PathLike],
    prompt_settings: List[Dict[str, Any]],
) -> None:
    """Precompute all point prompts.

    To enable running different inference tasks in parallel afterwards.

    Args:
        gt_paths: The file paths to the ground-truth segmentations.
        prompt_save_dir: The directory where the prompt files will be saved.
        prompt_settings: The settings for which the prompts will be computed.
    """
    os.makedirs(prompt_save_dir, exist_ok=True)

    for settings in tqdm(prompt_settings, desc="Precompute prompts"):

        use_points, use_boxes = settings["use_points"], settings["use_boxes"]
        n_positives, n_negatives = settings["n_positives"], settings["n_negatives"]
        dilation = settings.get("dilation", 5)

        # Check if the prompts were already computed.
        if use_boxes and not use_points:
            prompt_save_path = os.path.join(prompt_save_dir, "boxes.pkl")
        else:
            prompt_save_path = os.path.join(prompt_save_dir, f"points-p{n_positives}-n{n_negatives}.pkl")
        if os.path.exists(prompt_save_path):
            continue

        saved_prompts = {}
        for gt_path in tqdm(gt_paths, desc=f"Precompute prompts for p{n_positives}-n{n_negatives}"):
            name, prompts = _precompute_prompts(
                gt_path,
                use_points=use_points,
                use_boxes=use_boxes,
                n_positives=n_positives,
                n_negatives=n_negatives,
                dilation=dilation,
            )
            saved_prompts[name] = prompts

        with open(prompt_save_path, "wb") as f:
            pickle.dump(saved_prompts, f)


def _get_prompt_caching(prompt_save_dir, use_points, use_boxes, n_positives, n_negatives):
    """Determine, per prompt type, whether prompts are loaded from or saved to `prompt_save_dir`.

    Returns `(cached_point_prompts, save_point_prompts, point_prompt_save_path,
    cached_box_prompts, save_box_prompts, box_prompt_save_path)`.
    """

    def get_prompt_type_caching(use_type, save_name):
        if not use_type:
            return None, False, None

        prompt_save_path = os.path.join(prompt_save_dir, save_name)
        if os.path.exists(prompt_save_path):
            print("Using precomputed prompts from", prompt_save_path)
            # We delay loading the prompts, so we only have to load them once they're needed the first time.
            # This avoids loading the prompts (which are in a big pickle file) if all predictions are done already.
            cached_prompts = prompt_save_path
            save_prompts = False
        else:
            print("Saving prompts in", prompt_save_path)
            cached_prompts = {}
            save_prompts = True
        return cached_prompts, save_prompts, prompt_save_path

    # Check if prompt serialization is enabled.
    # If it is then load the prompts if they are already cached and otherwise store them.
    if prompt_save_dir is None:
        print("Prompts are not cached.")
        cached_point_prompts, cached_box_prompts = None, None
        save_point_prompts, save_box_prompts = False, False
        point_prompt_save_path, box_prompt_save_path = None, None
    else:
        cached_point_prompts, save_point_prompts, point_prompt_save_path = get_prompt_type_caching(
            use_points, f"points-p{n_positives}-n{n_negatives}.pkl"
        )
        cached_box_prompts, save_box_prompts, box_prompt_save_path = get_prompt_type_caching(
            use_boxes, "boxes.pkl"
        )

    return (cached_point_prompts, save_point_prompts, point_prompt_save_path,
            cached_box_prompts, save_box_prompts, box_prompt_save_path)


def run_inference_with_prompts(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    use_points: bool,
    use_boxes: bool,
    n_positives: int,
    n_negatives: int,
    dilation: int = 5,
    prompt_save_dir: Optional[Union[str, os.PathLike]] = None,
    batch_size: int = 512,
) -> None:
    """Run segment anything inference for multiple images using prompts derived from groundtruth.

    Args:
        predictor: The SegmentAnything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
            If None, no embedding path is passed on to the batched inference.
        prediction_dir: The directory where the predictions will be saved.
            Images that already have a prediction in this directory are skipped.
        use_points: Whether to use point prompts.
        use_boxes: Whether to use box prompts.
        n_positives: The number of positive point prompts that will be sampled.
        n_negatives: The number of negative point prompts that will be sampled.
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        prompt_save_dir: The directory where point prompts will be saved or are already saved.
            This enables running multiple experiments in a reproducible manner.
        batch_size: The batch size used for batched prediction.

    Raises:
        ValueError: If neither point nor box prompts are used, or if the number
            of images and ground-truth files differs.
    """
    if not (use_points or use_boxes):
        raise ValueError("You need to use at least one of point or box prompts.")

    if len(image_paths) != len(gt_paths):
        raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")

    (cached_point_prompts, save_point_prompts, point_prompt_save_path,
     cached_box_prompts, save_box_prompts, box_prompt_save_path) = _get_prompt_caching(
         prompt_save_dir, use_points, use_boxes, n_positives, n_negatives
    )

    os.makedirs(prediction_dir, exist_ok=True)
    for image_path, gt_path in tqdm(
        zip(image_paths, gt_paths), total=len(image_paths), desc="Run inference with prompts"
    ):
        image_name = os.path.basename(image_path)
        label_name = os.path.basename(gt_path)

        # We skip the images that already have been segmented.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        im = imageio.imread(image_path)
        gt = imageio.imread(gt_path).astype("uint32")
        gt = relabel_sequential(gt)[0]

        # Same convention as in 'run_inference_with_iterative_prompting': no directory, no embedding path.
        if embedding_dir is None:
            embedding_path = None
        else:
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        this_prompts, cached_point_prompts, cached_box_prompts = _load_prompts(
            cached_point_prompts, save_point_prompts,
            cached_box_prompts, save_box_prompts,
            label_name
        )
        instances, this_prompts = _run_inference_with_prompts_for_image(
            predictor, im, gt, n_positives=n_positives, n_negatives=n_negatives,
            dilation=dilation, use_points=use_points, use_boxes=use_boxes,
            batch_size=batch_size, cached_prompts=this_prompts,
            embedding_path=embedding_path,
        )

        if save_point_prompts:
            cached_point_prompts[label_name] = this_prompts[:2]
        if save_box_prompts:
            cached_box_prompts[label_name] = this_prompts[-1]

        # It's important to compress here, otherwise the predictions would take up a lot of space.
        imageio.imwrite(prediction_path, instances, compression=5)

    # Save the prompts if we run experiments with prompt caching and have computed them
    # for the first time.
    if save_point_prompts:
        with open(point_prompt_save_path, "wb") as f:
            pickle.dump(cached_point_prompts, f)
    if save_box_prompts:
        with open(box_prompt_save_path, "wb") as f:
            pickle.dump(cached_box_prompts, f)


def _save_segmentation(masks, prediction_path):
    """Convert a batch of masks (tensor with a singleton axis 1) to a label image and write it to disk."""
    # masks to segmentation
    masks = masks.cpu().numpy().squeeze(1).astype("bool")
    masks = [{"segmentation": mask, "area": mask.sum()} for mask in masks]
    segmentation = mask_data_to_segmentation(masks, with_background=True)
    imageio.imwrite(prediction_path, segmentation, compression=5)


def _get_batched_iterative_prompts(sampled_binary_gt, masks, batch_size, prompt_generator):
    """Run the iterative prompt generator batch-wise and concatenate the new points and labels."""
    n_samples = sampled_binary_gt.shape[0]
    n_batches = int(np.ceil(float(n_samples) / batch_size))
    next_coords, next_labels = [], []
    for batch_idx in range(n_batches):
        batch_start = batch_idx * batch_size
        batch_stop = min((batch_idx + 1) * batch_size, n_samples)

        batch_coords, batch_labels, _, _ = prompt_generator(
            sampled_binary_gt[batch_start: batch_stop], masks[batch_start: batch_stop]
        )
        next_coords.append(batch_coords)
        next_labels.append(batch_labels)

    # 'torch.cat' instead of its alias 'torch.concatenate', which is only available in torch >= 2.0.
    next_coords = torch.cat(next_coords)
    next_labels = torch.cat(next_labels)

    return next_coords, next_labels


@torch.no_grad()
def _run_inference_with_iterative_prompting_for_image(
    predictor,
    image,
    gt,
    start_with_box_prompt,
    dilation,
    batch_size,
    embedding_path,
    n_iterations,
    prediction_paths,
    use_masks=False
) -> None:
    """Run iterative prompting for a single image and save the segmentation of each iteration.

    The first iteration uses either a box or a single positive point per object. After each
    iteration the prompt generator derives additional points from the ground-truth and the
    current prediction; these are appended to the prompts of the next iteration.
    """
    prompt_generator = IterativePromptGenerator()

    gt_ids = _get_object_ids(gt)

    # Use multi-masking only if we have a single positive point without box.
    if start_with_box_prompt:
        use_boxes, use_points = True, False
        n_positives = 0
        multimasking = False
    else:
        use_boxes, use_points = False, True
        n_positives = 1
        multimasking = True

    points, point_labels, boxes = _get_batched_prompts(
        gt, gt_ids,
        use_points=use_points,
        use_boxes=use_boxes,
        n_positives=n_positives,
        n_negatives=0,
        dilation=dilation
    )

    sampled_binary_gt = util.segmentation_to_one_hot(gt.astype("int64"), gt_ids)

    # The logits are not available in the first iteration and are only
    # updated (at the end of each iteration) if 'use_masks' is set.
    logits_masks = None

    for iteration in range(n_iterations):
        batched_outputs = batched_inference(
            predictor=predictor,
            image=image,
            batch_size=batch_size,
            boxes=boxes,
            points=points,
            point_labels=point_labels,
            multimasking=multimasking,
            embedding_path=embedding_path,
            return_instance_segmentation=False,
            logits_masks=logits_masks,
            verbose_embeddings=False,
        )

        # switching off multimasking after first iter, as next iters (with multiple prompts) don't expect multimasking
        multimasking = False

        masks = torch.stack([m["segmentation"][None] for m in batched_outputs]).to(torch.float32)

        next_coords, next_labels = _get_batched_iterative_prompts(
            sampled_binary_gt, masks, batch_size, prompt_generator
        )
        next_coords, next_labels = next_coords.detach().cpu().numpy(), next_labels.detach().cpu().numpy()

        if points is not None:
            points = np.concatenate([points, next_coords], axis=1)
        else:
            points = next_coords

        if point_labels is not None:
            point_labels = np.concatenate([point_labels, next_labels], axis=1)
        else:
            point_labels = next_labels

        if use_masks:
            logits_masks = torch.stack([m["logits"] for m in batched_outputs])

        _save_segmentation(masks, prediction_paths[iteration])


def run_inference_with_iterative_prompting(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Union[str, os.PathLike],
    prediction_dir: Union[str, os.PathLike],
    start_with_box_prompt: bool = True,
    dilation: int = 5,
    batch_size: int = 32,
    n_iterations: int = 8,
    use_masks: bool = False
) -> None:
    """Run Segment Anything inference for multiple images using prompts iteratively
    derived from model outputs and ground-truth.

    Args:
        predictor: The Segment Anything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
            If None, no embedding path is passed on to the batched inference.
        prediction_dir: The directory where the predictions from Segment Anything will be saved per iteration.
        start_with_box_prompt: Whether to use the first prompt as bounding box or a single point.
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        batch_size: The batch size used for batched predictions.
        n_iterations: The number of iterations for iterative prompting.
        use_masks: Whether to make use of logits from previous prompt-based segmentation.

    Raises:
        ValueError: If the number of images and ground-truth files differs.
    """
    if len(image_paths) != len(gt_paths):
        raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")

    # create all prediction folders for all intermediate iterations
    for i in range(n_iterations):
        os.makedirs(os.path.join(prediction_dir, f"iteration{i:02}"), exist_ok=True)

    if use_masks:
        print("The iterative prompting will make use of logits masks from previous iterations.")

    for image_path, gt_path in tqdm(
        zip(image_paths, gt_paths), total=len(image_paths),
        desc="Run inference with iterative prompting for all images",
    ):
        image_name = os.path.basename(image_path)

        # We skip the images that already have been segmented
        prediction_paths = [os.path.join(prediction_dir, f"iteration{i:02}", image_name) for i in range(n_iterations)]
        if all(os.path.exists(prediction_path) for prediction_path in prediction_paths):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        image = imageio.imread(image_path)
        gt = imageio.imread(gt_path).astype("uint32")
        gt = relabel_sequential(gt)[0]

        if embedding_dir is None:
            embedding_path = None
        else:
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        _run_inference_with_iterative_prompting_for_image(
            predictor, image, gt, start_with_box_prompt=start_with_box_prompt,
            dilation=dilation, batch_size=batch_size, embedding_path=embedding_path,
            n_iterations=n_iterations, prediction_paths=prediction_paths, use_masks=use_masks
        )


def _use_tiling(tiling_window_params) -> bool:
    """Validate the tiling window parameters and return whether tiling should be used.

    Empty or missing parameters disable tiling. Otherwise the parameters must be
    a dictionary containing the keys 'tile_shape' and 'halo'.
    """
    if not tiling_window_params:
        return False

    if not isinstance(tiling_window_params, dict):
        raise RuntimeError("The tiling window parameters are expected to be provided as a dictionary of params.")

    if "tile_shape" not in tiling_window_params:
        raise RuntimeError("'tile_shape' parameter is missing from the provided parameters.")

    if "halo" not in tiling_window_params:
        raise RuntimeError("'halo' parameter is missing from the provided parameters.")

    return True


#
# AMG FUNCTION
#


def run_amg(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    iou_thresh_values: Optional[List[float]] = None,
    stability_score_values: Optional[List[float]] = None,
    peft_kwargs: Optional[Dict] = None,
    cache_embeddings: bool = False,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> str:
    """Run Segment Anything inference for multiple images using automatic mask generation (AMG).

    Args:
        checkpoint: The filepath to model checkpoints.
        model_type: The segment anything model choice.
        experiment_folder: The directory where the relevant files are saved.
        val_image_paths: The list of filepaths of input images for grid-search.
        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
        iou_thresh_values: Optional choice of values for grid search of `iou_thresh` parameter.
        stability_score_values: Optional choice of values for grid search of `stability_score` parameter.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        cache_embeddings: Whether to cache embeddings in experiment folder.
        tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.

    Returns:
        Filepath where the predictions have been saved.

    Raises:
        RuntimeError: If `tiling_window_params` is given but is not a dictionary
            or lacks the 'tile_shape' or 'halo' key.
    """

    if cache_embeddings:
        embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
        os.makedirs(embedding_folder, exist_ok=True)
    else:
        embedding_folder = None

    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs)

    # Get the AMG class.
    amg_class = TiledAutomaticMaskGenerator if _use_tiling(tiling_window_params) else AutomaticMaskGenerator

    amg = amg_class(predictor)
    amg_prefix = "amg"

    # where the predictions are saved
    prediction_folder = os.path.join(experiment_folder, amg_prefix, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # where the grid-search results are saved
    gs_result_folder = os.path.join(experiment_folder, amg_prefix, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    grid_search_values = instance_segmentation.default_grid_search_values_amg(
        iou_thresh_values=iou_thresh_values,
        stability_score_values=stability_score_values,
    )

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter=amg,
        grid_search_values=grid_search_values,
        val_image_paths=val_image_paths,
        val_gt_paths=val_gt_paths,
        test_image_paths=test_image_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
        experiment_folder=experiment_folder,
        tiling_window_params=tiling_window_params,
    )
    return prediction_folder


#
# INSTANCE SEGMENTATION FUNCTION
#


def run_instance_segmentation_with_decoder(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    peft_kwargs: Optional[Dict] = None,
    cache_embeddings: bool = False,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> str:
    """Run Segment Anything inference for multiple images using additional automatic instance segmentation (AIS).

    Args:
        checkpoint: The filepath to model checkpoints.
        model_type: The segment anything model choice.
        experiment_folder: The directory where the relevant files are saved.
        val_image_paths: The list of filepaths of input images for grid-search.
        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        cache_embeddings: Whether to cache embeddings in experiment folder.
        tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.

    Returns:
        Filepath where the predictions have been saved.

    Raises:
        RuntimeError: If `tiling_window_params` is given but is not a dictionary
            or lacks the 'tile_shape' or 'halo' key.
    """

    if cache_embeddings:
        embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
        os.makedirs(embedding_folder, exist_ok=True)
    else:
        embedding_folder = None

    predictor, decoder = get_predictor_and_decoder(
        model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs,
    )

    # Get the AIS class.
    if _use_tiling(tiling_window_params):
        ais_class = TiledInstanceSegmentationWithDecoder
    else:
        ais_class = InstanceSegmentationWithDecoder

    segmenter = ais_class(predictor, decoder)
    seg_prefix = "instance_segmentation_with_decoder"

    # where the predictions are saved
    prediction_folder = os.path.join(experiment_folder, seg_prefix, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # where the grid-search results are saved
    gs_result_folder = os.path.join(experiment_folder, seg_prefix, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    grid_search_values = instance_segmentation.default_grid_search_values_instance_segmentation_with_decoder()

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter=segmenter,
        grid_search_values=grid_search_values,
        val_image_paths=val_image_paths,
        val_gt_paths=val_gt_paths,
        test_image_paths=test_image_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
        experiment_folder=experiment_folder,
        tiling_window_params=tiling_window_params,
    )
    return prediction_folder
def precompute_all_embeddings(
    predictor: SamPredictor, image_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike],
) -> None:
    """Compute the image embeddings for all images ahead of time.

    Having the embeddings on disk allows running different inference tasks in parallel afterwards.

    Args:
        predictor: The SegmentAnything predictor.
        image_paths: The image file paths.
        embedding_dir: The directory where the embeddings will be saved.
            Each image gets its own file, named after the image file with the extension replaced by '.zarr'.
    """
    for path in tqdm(image_paths, desc="Precompute embeddings"):
        # Derive the embedding file name from the image file name without its extension.
        stem, _ = os.path.splitext(os.path.basename(path))
        image = imageio.imread(path)
        util.precompute_image_embeddings(
            predictor, image, os.path.join(embedding_dir, f"{stem}.zarr"), ndim=2
        )
Precompute all image embeddings.
To enable running different inference tasks in parallel afterwards.
Arguments:
- predictor: The SegmentAnything predictor.
- image_paths: The image file paths.
- embedding_dir: The directory where the embeddings will be saved.
def precompute_all_prompts(
    gt_paths: List[Union[str, os.PathLike]],
    prompt_save_dir: Union[str, os.PathLike],
    prompt_settings: List[Dict[str, Any]],
) -> None:
    """Compute the prompts for all ground-truth files ahead of time and pickle them.

    Having the prompts on disk allows running different inference tasks in parallel afterwards.
    One file is written per setting: 'boxes.pkl' for box-only settings and
    'points-p{n_positives}-n{n_negatives}.pkl' otherwise. Settings whose file already exists are skipped.

    Args:
        gt_paths: The file paths to the ground-truth segmentations.
        prompt_save_dir: The directory where the prompt files will be saved.
        prompt_settings: The settings for which the prompts will be computed.
            Each entry needs the keys 'use_points', 'use_boxes', 'n_positives' and 'n_negatives';
            'dilation' is optional and defaults to 5.
    """
    os.makedirs(prompt_save_dir, exist_ok=True)

    for settings in tqdm(prompt_settings, desc="Precompute prompts"):
        use_points = settings["use_points"]
        use_boxes = settings["use_boxes"]
        n_positives = settings["n_positives"]
        n_negatives = settings["n_negatives"]
        dilation = settings.get("dilation", 5)

        # Nothing to do if the prompts for this setting were computed before.
        boxes_only = use_boxes and not use_points
        file_name = "boxes.pkl" if boxes_only else f"points-p{n_positives}-n{n_negatives}.pkl"
        prompt_save_path = os.path.join(prompt_save_dir, file_name)
        if os.path.exists(prompt_save_path):
            continue

        # Map the name of each ground-truth file to its prompts.
        saved_prompts = {}
        for gt_path in tqdm(gt_paths, desc=f"Precompute prompts for p{n_positives}-n{n_negatives}"):
            name, prompts = _precompute_prompts(
                gt_path,
                use_points=use_points,
                use_boxes=use_boxes,
                n_positives=n_positives,
                n_negatives=n_negatives,
                dilation=dilation,
            )
            saved_prompts[name] = prompts

        with open(prompt_save_path, "wb") as f:
            pickle.dump(saved_prompts, f)
Precompute all point prompts.
To enable running different inference tasks in parallel afterwards.
Arguments:
- gt_paths: The file paths to the ground-truth segmentations.
- prompt_save_dir: The directory where the prompt files will be saved.
- prompt_settings: The settings for which the prompts will be computed.
def run_inference_with_prompts(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    use_points: bool,
    use_boxes: bool,
    n_positives: int,
    n_negatives: int,
    dilation: int = 5,
    prompt_save_dir: Optional[Union[str, os.PathLike]] = None,
    batch_size: int = 512,
) -> None:
    """Run segment anything inference for multiple images using prompts derived from groundtruth.

    Args:
        predictor: The SegmentAnything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
            If None, no embedding path is passed on to the inference for a single image.
        prediction_dir: The directory where the predictions will be saved.
            Images that already have a prediction in this directory are skipped.
        use_points: Whether to use point prompts.
        use_boxes: Whether to use box prompts.
        n_positives: The number of positive point prompts that will be sampled.
        n_negatives: The number of negative point prompts that will be sampled.
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        prompt_save_dir: The directory where point prompts will be saved or are already saved.
            This enables running multiple experiments in a reproducible manner.
        batch_size: The batch size used for batched prediction.

    Raises:
        ValueError: If neither point nor box prompts are used, or if the number
            of images and ground-truth files differs.
    """
    if not (use_points or use_boxes):
        raise ValueError("You need to use at least one of point or box prompts.")

    if len(image_paths) != len(gt_paths):
        raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")

    (cached_point_prompts, save_point_prompts, point_prompt_save_path,
     cached_box_prompts, save_box_prompts, box_prompt_save_path) = _get_prompt_caching(
         prompt_save_dir, use_points, use_boxes, n_positives, n_negatives
    )

    os.makedirs(prediction_dir, exist_ok=True)
    for image_path, gt_path in tqdm(
        zip(image_paths, gt_paths), total=len(image_paths), desc="Run inference with prompts"
    ):
        image_name = os.path.basename(image_path)
        label_name = os.path.basename(gt_path)

        # We skip the images that already have been segmented.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        im = imageio.imread(image_path)
        gt = imageio.imread(gt_path).astype("uint32")
        gt = relabel_sequential(gt)[0]

        # Without an embedding directory there is no embedding file to point to;
        # previously this case failed in 'os.path.join'.
        if embedding_dir is None:
            embedding_path = None
        else:
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        this_prompts, cached_point_prompts, cached_box_prompts = _load_prompts(
            cached_point_prompts, save_point_prompts,
            cached_box_prompts, save_box_prompts,
            label_name
        )
        instances, this_prompts = _run_inference_with_prompts_for_image(
            predictor, im, gt, n_positives=n_positives, n_negatives=n_negatives,
            dilation=dilation, use_points=use_points, use_boxes=use_boxes,
            batch_size=batch_size, cached_prompts=this_prompts,
            embedding_path=embedding_path,
        )

        # The returned prompts are (points, point_labels, boxes).
        if save_point_prompts:
            cached_point_prompts[label_name] = this_prompts[:2]
        if save_box_prompts:
            cached_box_prompts[label_name] = this_prompts[-1]

        # It's important to compress here, otherwise the predictions would take up a lot of space.
        imageio.imwrite(prediction_path, instances, compression=5)

    # Save the prompts if we run experiments with prompt caching and have computed them
    # for the first time.
    if save_point_prompts:
        with open(point_prompt_save_path, "wb") as f:
            pickle.dump(cached_point_prompts, f)
    if save_box_prompts:
        with open(box_prompt_save_path, "wb") as f:
            pickle.dump(cached_box_prompts, f)
Run segment anything inference for multiple images using prompts derived from groundtruth.
Arguments:
- predictor: The SegmentAnything predictor.
- image_paths: The image file paths.
- gt_paths: The ground-truth segmentation file paths.
- embedding_dir: The directory where the image embeddings will be saved or are already saved.
- prediction_dir: The directory where the predicted segmentations will be saved.
- use_points: Whether to use point prompts.
- use_boxes: Whether to use box prompts
- n_positives: The number of positive point prompts that will be sampled.
- n_negatives: The number of negative point prompts that will be sampled.
- dilation: The dilation factor for the radius around the ground-truth object around which points will not be sampled.
- prompt_save_dir: The directory where point prompts will be saved or are already saved. This enables running multiple experiments in a reproducible manner.
- batch_size: The batch size used for batched prediction.
def run_inference_with_iterative_prompting(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Union[str, os.PathLike],
    prediction_dir: Union[str, os.PathLike],
    start_with_box_prompt: bool = True,
    dilation: int = 5,
    batch_size: int = 32,
    n_iterations: int = 8,
    use_masks: bool = False
) -> None:
    """Run Segment Anything inference for multiple images using prompts iteratively
    derived from model outputs and ground-truth.

    One sub-folder 'iterationXX' is created in `prediction_dir` per iteration. An image is skipped
    if a file with its name already exists in every one of these sub-folders.

    Args:
        predictor: The Segment Anything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
            If None, no embedding path is passed on for the individual images.
        prediction_dir: The directory where the predictions from Segment Anything will be saved per iteration.
        start_with_box_prompt: Whether to use the first prompt as bounding box or a single point.
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        batch_size: The batch size used for batched predictions.
        n_iterations: The number of iterations for iterative prompting.
        use_masks: Whether to make use of logits from previous prompt-based segmentation.

    Raises:
        ValueError: If the number of images and ground-truth files differ.
    """
    if len(image_paths) != len(gt_paths):
        raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")

    # Create the prediction folders for all intermediate iterations up-front.
    iteration_folders = [os.path.join(prediction_dir, f"iteration{i:02}") for i in range(n_iterations)]
    for iteration_folder in iteration_folders:
        os.makedirs(iteration_folder, exist_ok=True)

    if use_masks:
        print("The iterative prompting will make use of logits masks from previous iterations.")

    path_pairs = tqdm(
        zip(image_paths, gt_paths), total=len(image_paths),
        desc="Run inference with iterative prompting for all images",
    )
    for image_path, gt_path in path_pairs:
        image_name = os.path.basename(image_path)

        # Nothing to do for images that have a prediction for every iteration already.
        prediction_paths = [os.path.join(iteration_folder, image_name) for iteration_folder in iteration_folders]
        if all(map(os.path.exists, prediction_paths)):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        image = imageio.imread(image_path)
        gt = relabel_sequential(imageio.imread(gt_path).astype("uint32"))[0]

        # The embeddings of each image are stored in a zarr file named after the image.
        embedding_path = (
            None if embedding_dir is None
            else os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
        )

        _run_inference_with_iterative_prompting_for_image(
            predictor, image, gt, start_with_box_prompt=start_with_box_prompt,
            dilation=dilation, batch_size=batch_size, embedding_path=embedding_path,
            n_iterations=n_iterations, prediction_paths=prediction_paths, use_masks=use_masks
        )
Run Segment Anything inference for multiple images using prompts iteratively derived from model outputs and ground-truth.
Arguments:
- predictor: The Segment Anything predictor.
- image_paths: The image file paths.
- gt_paths: The ground-truth segmentation file paths.
- embedding_dir: The directory where the image embeddings will be saved or are already saved.
- prediction_dir: The directory where the predictions from Segment Anything will be saved per iteration.
- start_with_box_prompt: Whether to use the first prompt as bounding box or a single point
- dilation: The dilation factor for the radius around the ground-truth object around which points will not be sampled.
- batch_size: The batch size used for batched predictions.
- n_iterations: The number of iterations for iterative prompting.
- use_masks: Whether to make use of logits from previous prompt-based segmentation.
def run_amg(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    iou_thresh_values: Optional[List[float]] = None,
    stability_score_values: Optional[List[float]] = None,
    peft_kwargs: Optional[Dict] = None,
    cache_embeddings: bool = False,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> str:
    """Run Segment Anything inference for multiple images using automatic mask generation (AMG).

    The predictions are saved in '<experiment_folder>/amg/inference' and the grid-search
    results in '<experiment_folder>/amg/grid_search'.

    Args:
        checkpoint: The filepath to model checkpoints.
        model_type: The segment anything model choice.
        experiment_folder: The directory where the relevant files are saved.
        val_image_paths: The list of filepaths of input images for grid-search.
        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
        iou_thresh_values: Optional choice of values for grid search of `iou_thresh` parameter.
        stability_score_values: Optional choice of values for grid search of `stability_score` parameter.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        cache_embeddings: Whether to cache embeddings in experiment folder.
        tiling_window_params: The parameters to decide whether to use tiling window operation.
            If given (and not empty), it must be a dictionary with the keys 'tile_shape' and 'halo'.

    Returns:
        Filepath where the predictions have been saved.

    Raises:
        RuntimeError: If `tiling_window_params` is given but is not a dictionary
            or does not contain the keys 'tile_shape' and 'halo'.
    """
    # Validate the tiling parameters first: this check is cheap, so a malformed configuration
    # should fail before any folder is created or the model checkpoint is loaded.
    use_tiling = bool(tiling_window_params)
    if use_tiling:
        if not isinstance(tiling_window_params, dict):
            raise RuntimeError("The tiling window parameters are expected to be provided as a dictionary of params.")

        if "tile_shape" not in tiling_window_params:
            raise RuntimeError("'tile_shape' parameter is missing from the provided parameters.")

        if "halo" not in tiling_window_params:
            raise RuntimeError("'halo' parameter is missing from the provided parameters.")

    if cache_embeddings:
        embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
        os.makedirs(embedding_folder, exist_ok=True)
    else:
        embedding_folder = None

    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs)

    # Get the AMG class.
    amg_class = TiledAutomaticMaskGenerator if use_tiling else AutomaticMaskGenerator
    amg = amg_class(predictor)
    amg_prefix = "amg"

    # where the predictions are saved
    prediction_folder = os.path.join(experiment_folder, amg_prefix, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # where the grid-search results are saved
    gs_result_folder = os.path.join(experiment_folder, amg_prefix, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    grid_search_values = instance_segmentation.default_grid_search_values_amg(
        iou_thresh_values=iou_thresh_values,
        stability_score_values=stability_score_values,
    )

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter=amg,
        grid_search_values=grid_search_values,
        val_image_paths=val_image_paths,
        val_gt_paths=val_gt_paths,
        test_image_paths=test_image_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
        experiment_folder=experiment_folder,
        tiling_window_params=tiling_window_params,
    )
    return prediction_folder
Run Segment Anything inference for multiple images using automatic mask generation (AMG).
Arguments:
- checkpoint: The filepath to model checkpoints.
- model_type: The segment anything model choice.
- experiment_folder: The directory where the relevant files are saved.
- val_image_paths: The list of filepaths of input images for grid-search.
- val_gt_paths: The list of filepaths of corresponding labels for grid-search.
- test_image_paths: The list of filepaths of input images for automatic instance segmentation.
- iou_thresh_values: Optional choice of values for grid search of `iou_thresh` parameter.
- stability_score_values: Optional choice of values for grid search of `stability_score` parameter.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
- cache_embeddings: Whether to cache embeddings in experiment folder.
- tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
Returns:
Filepath where the predictions have been saved.
def run_instance_segmentation_with_decoder(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    peft_kwargs: Optional[Dict] = None,
    cache_embeddings: bool = False,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> str:
    """Run Segment Anything inference for multiple images using additional automatic instance segmentation (AIS).

    The predictions are saved in '<experiment_folder>/instance_segmentation_with_decoder/inference'
    and the grid-search results in '<experiment_folder>/instance_segmentation_with_decoder/grid_search'.

    Args:
        checkpoint: The filepath to model checkpoints.
        model_type: The segment anything model choice.
        experiment_folder: The directory where the relevant files are saved.
        val_image_paths: The list of filepaths of input images for grid-search.
        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        cache_embeddings: Whether to cache embeddings in experiment folder.
        tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
            If given (and not empty), it must be a dictionary with the keys 'tile_shape' and 'halo'.

    Returns:
        Filepath where the predictions have been saved.

    Raises:
        RuntimeError: If `tiling_window_params` is given but is not a dictionary
            or does not contain the keys 'tile_shape' and 'halo'.
    """
    # Validate the tiling parameters first: this check is cheap, so a malformed configuration
    # should fail before any folder is created or the model checkpoint is loaded.
    use_tiling = bool(tiling_window_params)
    if use_tiling:
        if not isinstance(tiling_window_params, dict):
            raise RuntimeError("The tiling window parameters are expected to be provided as a dictionary of params.")

        if "tile_shape" not in tiling_window_params:
            raise RuntimeError("'tile_shape' parameter is missing from the provided parameters.")

        if "halo" not in tiling_window_params:
            raise RuntimeError("'halo' parameter is missing from the provided parameters.")

    if cache_embeddings:
        embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
        os.makedirs(embedding_folder, exist_ok=True)
    else:
        embedding_folder = None

    predictor, decoder = get_predictor_and_decoder(
        model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs,
    )

    # Get the AIS class.
    ais_class = TiledInstanceSegmentationWithDecoder if use_tiling else InstanceSegmentationWithDecoder
    segmenter = ais_class(predictor, decoder)
    seg_prefix = "instance_segmentation_with_decoder"

    # where the predictions are saved
    prediction_folder = os.path.join(experiment_folder, seg_prefix, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # where the grid-search results are saved
    gs_result_folder = os.path.join(experiment_folder, seg_prefix, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    grid_search_values = instance_segmentation.default_grid_search_values_instance_segmentation_with_decoder()

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter=segmenter,
        grid_search_values=grid_search_values,
        val_image_paths=val_image_paths,
        val_gt_paths=val_gt_paths,
        test_image_paths=test_image_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
        experiment_folder=experiment_folder,
        tiling_window_params=tiling_window_params,
    )
    return prediction_folder
Run Segment Anything inference for multiple images using additional automatic instance segmentation (AIS).
Arguments:
- checkpoint: The filepath to model checkpoints.
- model_type: The segment anything model choice.
- experiment_folder: The directory where the relevant files are saved.
- val_image_paths: The list of filepaths of input images for grid-search.
- val_gt_paths: The list of filepaths of corresponding labels for grid-search.
- test_image_paths: The list of filepaths of input images for automatic instance segmentation.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
- cache_embeddings: Whether to cache embeddings in experiment folder.
- tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
Returns:
Filepath where the predictions have been saved.