micro_sam.evaluation.inference
Inference with Segment Anything models and different prompt strategies.
"""Inference with Segment Anything models and different prompt strategies.
"""

import os
import pickle
import numpy as np
from tqdm import tqdm
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union

import imageio.v3 as imageio
from skimage.segmentation import relabel_sequential

import torch

from segment_anything import SamPredictor

from .. import util as util
from ..inference import batched_inference
from ..instance_segmentation import (
    mask_data_to_segmentation, get_predictor_and_decoder,
    AutomaticMaskGenerator, InstanceSegmentationWithDecoder,
)
from . import instance_segmentation
from ..prompt_generators import PointAndBoxPromptGenerator, IterativePromptGenerator


def _load_prompts(
    cached_point_prompts, save_point_prompts,
    cached_box_prompts, save_box_prompts,
    image_name
):
    """Look up the cached prompts for a single image.

    Each cache is either None (no caching / prompt type not used), the path to a pickle file that
    has not been loaded yet, or an already loaded dict that maps file names to prompts.

    Returns:
        The prompts for this image as `(points, point_labels, boxes)`, or None if no prompt type is cached.
            A prompt type without cached values is returned as an empty list.
        The (possibly freshly loaded) point prompt cache.
        The (possibly freshly loaded) box prompt cache.
    """

    def load_prompt_type(cached_prompts, save_prompts):
        # Check if we have saved prompts.
        if cached_prompts is None or save_prompts:  # we don't have cached prompts
            return cached_prompts, None

        # We have cached prompts, but they have not been loaded yet.
        # The pickle file is written by this module; only point this at prompt files you trust.
        if isinstance(cached_prompts, str):
            with open(cached_prompts, "rb") as f:
                cached_prompts = pickle.load(f)

        prompts = cached_prompts[image_name]
        return cached_prompts, prompts

    cached_point_prompts, point_prompts = load_prompt_type(cached_point_prompts, save_point_prompts)
    cached_box_prompts, box_prompts = load_prompt_type(cached_box_prompts, save_box_prompts)

    # we don't have anything cached
    if point_prompts is None and box_prompts is None:
        return None, cached_point_prompts, cached_box_prompts

    if point_prompts is None:
        input_point, input_label = [], []
    else:
        input_point, input_label = point_prompts

    if box_prompts is None:
        input_box = []
    else:
        input_box = box_prompts

    prompts = (input_point, input_label, input_box)
    return prompts, cached_point_prompts, cached_box_prompts


def _get_batched_prompts(
    gt,
    gt_ids,
    use_points,
    use_boxes,
    n_positives,
    n_negatives,
    dilation,
):
    """Generate point and/or box prompts for the objects `gt_ids` of the ground-truth `gt`.

    Returns:
        The points, point labels and boxes as numpy arrays; a prompt type the generator
        does not produce is returned as None.
    """
    # Initialize the prompt generator.
    prompt_generator = PointAndBoxPromptGenerator(
        n_positive_points=n_positives, n_negative_points=n_negatives,
        dilation_strength=dilation, get_point_prompts=use_points,
        get_box_prompts=use_boxes
    )

    # Generate the prompts.
    center_coordinates, bbox_coordinates = util.get_centers_and_bounding_boxes(gt)
    center_coordinates = [center_coordinates[gt_id] for gt_id in gt_ids]
    bbox_coordinates = [bbox_coordinates[gt_id] for gt_id in gt_ids]
    masks = util.segmentation_to_one_hot(gt.astype("int64"), gt_ids)

    points, point_labels, boxes, _ = prompt_generator(
        masks, bbox_coordinates, center_coordinates
    )

    def to_numpy(x):
        if x is None:
            return x
        return x.numpy()

    return to_numpy(points), to_numpy(point_labels), to_numpy(boxes)


def _is_missing_prompt(prompt):
    """Return whether `prompt` is the empty-list placeholder `_load_prompts` uses for an uncached prompt type."""
    return isinstance(prompt, list) and len(prompt) == 0


def _run_inference_with_prompts_for_image(
    predictor,
    image,
    gt,
    use_points,
    use_boxes,
    n_positives,
    n_negatives,
    dilation,
    batch_size,
    cached_prompts,
    embedding_path,
):
    """Segment all ground-truth objects of a single image with point and/or box prompts.

    Returns:
        The instance segmentation returned by `batched_inference`.
        A copy of the prompts `(points, point_labels, boxes)` that were used, so that they can be cached.
    """
    gt_ids = np.unique(gt)[1:]
    if cached_prompts is None:
        points, point_labels, boxes = _get_batched_prompts(
            gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation
        )
    else:
        points, point_labels, boxes = cached_prompts
        # If points and boxes are both requested but only one of them was cached, `_load_prompts`
        # returns an empty list for the other one. Generate the missing type here; otherwise
        # inference would silently run without it and the empty entry would be written to the cache.
        has_objects = len(gt_ids) > 0
        missing_points = use_points and has_objects and _is_missing_prompt(points)
        missing_boxes = use_boxes and has_objects and _is_missing_prompt(boxes)
        if missing_points or missing_boxes:
            new_points, new_point_labels, new_boxes = _get_batched_prompts(
                gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation
            )
            if missing_points:
                points, point_labels = new_points, new_point_labels
            if missing_boxes:
                boxes = new_boxes

    # Make a copy of the prompts to return them at the end.
    prompts = deepcopy((points, point_labels, boxes))

    # Use multi-masking only if we have a single positive point without box
    multimasking = False
    if not use_boxes and (n_positives == 1 and n_negatives == 0):
        multimasking = True

    instance_labels = batched_inference(
        predictor, image, batch_size,
        boxes=boxes, points=points, point_labels=point_labels,
        multimasking=multimasking, embedding_path=embedding_path,
        return_instance_segmentation=True,
    )

    return instance_labels, prompts


def precompute_all_embeddings(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Union[str, os.PathLike],
) -> None:
    """Precompute all image embeddings.

    This enables running different inference tasks in parallel afterwards.

    Args:
        predictor: The SegmentAnything predictor.
        image_paths: The image file paths.
        embedding_dir: The directory where the embeddings will be saved.
    """
    for image_path in tqdm(image_paths, desc="Precompute embeddings"):
        image_name = os.path.basename(image_path)
        im = imageio.imread(image_path)
        embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
        util.precompute_image_embeddings(predictor, im, embedding_path, ndim=2)


def _precompute_prompts(gt_path, use_points, use_boxes, n_positives, n_negatives, dilation):
    """Compute the prompts for a single ground-truth file.

    Returns:
        The file name of `gt_path` and either the boxes (box-only setting) or `(points, point_labels)`.
    """
    name = os.path.basename(gt_path)

    gt = imageio.imread(gt_path).astype("uint32")
    gt = relabel_sequential(gt)[0]
    gt_ids = np.unique(gt)[1:]

    input_point, input_label, input_box = _get_batched_prompts(
        gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation
    )

    if use_boxes and not use_points:
        return name, input_box
    return name, (input_point, input_label)


def precompute_all_prompts(
    gt_paths: List[Union[str, os.PathLike]],
    prompt_save_dir: Union[str, os.PathLike],
    prompt_settings: List[Dict[str, Any]],
) -> None:
    """Precompute all point and box prompts.

    This enables running different inference tasks in parallel afterwards.
    Box-only settings are saved to `boxes.pkl`, all other settings to
    `points-p{n_positives}-n{n_negatives}.pkl` in `prompt_save_dir`.
    Settings whose prompt file already exists are skipped.

    Args:
        gt_paths: The file paths to the ground-truth segmentations.
        prompt_save_dir: The directory where the prompt files will be saved.
        prompt_settings: The settings for which the prompts will be computed. Each setting is a dict with
            the keys "use_points", "use_boxes", "n_positives", "n_negatives" and optionally "dilation" (default 5).
    """
    os.makedirs(prompt_save_dir, exist_ok=True)

    for settings in tqdm(prompt_settings, desc="Precompute prompts"):

        use_points, use_boxes = settings["use_points"], settings["use_boxes"]
        n_positives, n_negatives = settings["n_positives"], settings["n_negatives"]
        dilation = settings.get("dilation", 5)

        # check if the prompts were already computed
        if use_boxes and not use_points:
            prompt_save_path = os.path.join(prompt_save_dir, "boxes.pkl")
        else:
            prompt_save_path = os.path.join(prompt_save_dir, f"points-p{n_positives}-n{n_negatives}.pkl")
        if os.path.exists(prompt_save_path):
            continue

        saved_prompts = {}
        for gt_path in tqdm(gt_paths, desc=f"Precompute prompts for p{n_positives}-n{n_negatives}"):
            name, prompts = _precompute_prompts(
                gt_path,
                use_points=use_points,
                use_boxes=use_boxes,
                n_positives=n_positives,
                n_negatives=n_negatives,
                dilation=dilation,
            )
            saved_prompts[name] = prompts

        with open(prompt_save_path, "wb") as f:
            pickle.dump(saved_prompts, f)


def _get_prompt_caching(prompt_save_dir, use_points, use_boxes, n_positives, n_negatives):
    """Determine for point and box prompts whether they are loaded from, or saved to, a prompt file.

    Returns:
        For points and then boxes: the cache (None, a pickle path to load lazily, or an empty dict to fill),
        whether the prompts must be saved after inference, and the pickle path.
    """

    def get_prompt_type_caching(use_type, save_name):
        if not use_type:
            return None, False, None

        prompt_save_path = os.path.join(prompt_save_dir, save_name)
        if os.path.exists(prompt_save_path):
            print("Using precomputed prompts from", prompt_save_path)
            # We delay loading the prompts, so we only have to load them once they're needed the first time.
            # This avoids loading the prompts (which are in a big pickle file) if all predictions are done already.
            cached_prompts = prompt_save_path
            save_prompts = False
        else:
            print("Saving prompts in", prompt_save_path)
            cached_prompts = {}
            save_prompts = True
        return cached_prompts, save_prompts, prompt_save_path

    # Check if prompt serialization is enabled.
    # If it is then load the prompts if they are already cached and otherwise store them.
    if prompt_save_dir is None:
        print("Prompts are not cached.")
        cached_point_prompts, cached_box_prompts = None, None
        save_point_prompts, save_box_prompts = False, False
        point_prompt_save_path, box_prompt_save_path = None, None
    else:
        cached_point_prompts, save_point_prompts, point_prompt_save_path = get_prompt_type_caching(
            use_points, f"points-p{n_positives}-n{n_negatives}.pkl"
        )
        cached_box_prompts, save_box_prompts, box_prompt_save_path = get_prompt_type_caching(
            use_boxes, "boxes.pkl"
        )

    return (cached_point_prompts, save_point_prompts, point_prompt_save_path,
            cached_box_prompts, save_box_prompts, box_prompt_save_path)


def run_inference_with_prompts(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    use_points: bool,
    use_boxes: bool,
    n_positives: int,
    n_negatives: int,
    dilation: int = 5,
    prompt_save_dir: Optional[Union[str, os.PathLike]] = None,
    batch_size: int = 512,
) -> None:
    """Run Segment Anything inference for multiple images using prompts derived from the ground-truth.

    Args:
        predictor: The SegmentAnything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
            If None, no embedding path is passed to `batched_inference`.
        prediction_dir: The directory where the predicted instance segmentations will be saved.
            Images that already have a prediction in this directory are skipped.
        use_points: Whether to use point prompts.
        use_boxes: Whether to use box prompts.
        n_positives: The number of positive point prompts that will be sampled.
        n_negatives: The number of negative point prompts that will be sampled.
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        prompt_save_dir: The directory where the prompts will be saved or are already saved.
            This enables running multiple experiments in a reproducible manner.
        batch_size: The batch size used for batched prediction.

    Raises:
        ValueError: If neither points nor boxes are used, or the number of images and ground-truths differ.
    """
    if not (use_points or use_boxes):
        raise ValueError("You need to use at least one of point or box prompts.")

    if len(image_paths) != len(gt_paths):
        raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")

    (cached_point_prompts, save_point_prompts, point_prompt_save_path,
     cached_box_prompts, save_box_prompts, box_prompt_save_path) = _get_prompt_caching(
         prompt_save_dir, use_points, use_boxes, n_positives, n_negatives
     )

    os.makedirs(prediction_dir, exist_ok=True)
    for image_path, gt_path in tqdm(
        zip(image_paths, gt_paths), total=len(image_paths), desc="Run inference with prompts"
    ):
        image_name = os.path.basename(image_path)
        label_name = os.path.basename(gt_path)

        # We skip the images that already have been segmented.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        im = imageio.imread(image_path)
        gt = imageio.imread(gt_path).astype("uint32")
        gt = relabel_sequential(gt)[0]

        # Consistent with `run_inference_with_iterative_prompting`: allow running without an embedding directory.
        if embedding_dir is None:
            embedding_path = None
        else:
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        this_prompts, cached_point_prompts, cached_box_prompts = _load_prompts(
            cached_point_prompts, save_point_prompts,
            cached_box_prompts, save_box_prompts,
            label_name
        )
        instances, this_prompts = _run_inference_with_prompts_for_image(
            predictor, im, gt, n_positives=n_positives, n_negatives=n_negatives,
            dilation=dilation, use_points=use_points, use_boxes=use_boxes,
            batch_size=batch_size, cached_prompts=this_prompts,
            embedding_path=embedding_path,
        )

        if save_point_prompts:
            cached_point_prompts[label_name] = this_prompts[:2]
        if save_box_prompts:
            cached_box_prompts[label_name] = this_prompts[-1]

        # It's important to compress here, otherwise the predictions would take up a lot of space.
        imageio.imwrite(prediction_path, instances, compression=5)

    # Save the prompts if we run experiments with prompt caching and have computed them
    # for the first time.
    if save_point_prompts:
        with open(point_prompt_save_path, "wb") as f:
            pickle.dump(cached_point_prompts, f)
    if save_box_prompts:
        with open(box_prompt_save_path, "wb") as f:
            pickle.dump(cached_box_prompts, f)


def _save_segmentation(masks, prediction_path):
    """Convert a stack of per-object masks into an instance segmentation and write it (compressed).

    `masks` is a torch tensor with a singleton second axis, as built by stacking `m["segmentation"][None]`.
    """
    # masks to segmentation
    masks = masks.cpu().numpy().squeeze(1).astype("bool")
    masks = [{"segmentation": mask, "area": mask.sum()} for mask in masks]
    segmentation = mask_data_to_segmentation(masks, with_background=True)
    imageio.imwrite(prediction_path, segmentation, compression=5)


def _get_batched_iterative_prompts(sampled_binary_gt, masks, batch_size, prompt_generator):
    """Run the iterative prompt generator over all objects in batches of `batch_size`.

    Returns:
        The concatenated next point coordinates and point labels (torch tensors).
    """
    n_samples = sampled_binary_gt.shape[0]
    n_batches = int(np.ceil(float(n_samples) / batch_size))
    next_coords, next_labels = [], []
    for batch_idx in range(n_batches):
        batch_start = batch_idx * batch_size
        batch_stop = min((batch_idx + 1) * batch_size, n_samples)

        batch_coords, batch_labels, _, _ = prompt_generator(
            sampled_binary_gt[batch_start: batch_stop], masks[batch_start: batch_stop]
        )
        next_coords.append(batch_coords)
        next_labels.append(batch_labels)

    next_coords = torch.concatenate(next_coords)
    next_labels = torch.concatenate(next_labels)

    return next_coords, next_labels


@torch.no_grad()
def _run_inference_with_iterative_prompting_for_image(
    predictor,
    image,
    gt,
    start_with_box_prompt,
    dilation,
    batch_size,
    embedding_path,
    n_iterations,
    prediction_paths,
    use_masks=False
) -> None:
    """Segment all objects of one image with iterative prompting and save the result of every iteration.

    The first iteration uses a box or a single positive point per object; each following iteration adds
    the point prompts derived by `IterativePromptGenerator` from the previous prediction and the ground-truth.
    The segmentation of iteration `i` is written to `prediction_paths[i]`.
    """
    verbose_embeddings = False

    prompt_generator = IterativePromptGenerator()

    gt_ids = np.unique(gt)[1:]

    # Use multi-masking only if we have a single positive point without box
    if start_with_box_prompt:
        use_boxes, use_points = True, False
        n_positives = 0
        multimasking = False
    else:
        use_boxes, use_points = False, True
        n_positives = 1
        multimasking = True

    points, point_labels, boxes = _get_batched_prompts(
        gt, gt_ids,
        use_points=use_points,
        use_boxes=use_boxes,
        n_positives=n_positives,
        n_negatives=0,
        dilation=dilation
    )

    sampled_binary_gt = util.segmentation_to_one_hot(gt.astype("int64"), gt_ids)

    for iteration in range(n_iterations):
        if iteration == 0:  # logits mask can not be used for the first iteration.
            logits_masks = None
        else:
            if not use_masks:  # logits mask should not be used when not desired.
                logits_masks = None

        batched_outputs = batched_inference(
            predictor=predictor,
            image=image,
            batch_size=batch_size,
            boxes=boxes,
            points=points,
            point_labels=point_labels,
            multimasking=multimasking,
            embedding_path=embedding_path,
            return_instance_segmentation=False,
            logits_masks=logits_masks,
            verbose_embeddings=verbose_embeddings,
        )

        # switching off multimasking after first iter, as next iters (with multiple prompts) don't expect multimasking
        multimasking = False

        masks = torch.stack([m["segmentation"][None] for m in batched_outputs]).to(torch.float32)

        next_coords, next_labels = _get_batched_iterative_prompts(
            sampled_binary_gt, masks, batch_size, prompt_generator
        )
        next_coords, next_labels = next_coords.detach().cpu().numpy(), next_labels.detach().cpu().numpy()

        # Accumulate the new points on top of the existing ones (axis 1 = prompts per object).
        if points is not None:
            points = np.concatenate([points, next_coords], axis=1)
        else:
            points = next_coords

        if point_labels is not None:
            point_labels = np.concatenate([point_labels, next_labels], axis=1)
        else:
            point_labels = next_labels

        if use_masks:
            logits_masks = torch.stack([m["logits"] for m in batched_outputs])

        _save_segmentation(masks, prediction_paths[iteration])


def run_inference_with_iterative_prompting(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    start_with_box_prompt: bool = True,
    dilation: int = 5,
    batch_size: int = 32,
    n_iterations: int = 8,
    use_masks: bool = False
) -> None:
    """Run Segment Anything inference for multiple images using prompts iteratively
    derived from model outputs and ground-truth.

    Args:
        predictor: The Segment Anything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
            If None, no embedding path is passed to `batched_inference`.
        prediction_dir: The directory where the predictions from Segment Anything will be saved per iteration,
            in the subfolders `iteration00`, `iteration01`, ...
        start_with_box_prompt: Whether to use the first prompt as bounding box or a single point.
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        batch_size: The batch size used for batched predictions.
        n_iterations: The number of iterations for iterative prompting.
        use_masks: Whether to make use of logits from previous prompt-based segmentation.

    Raises:
        ValueError: If the number of images and ground-truths differ.
    """
    if len(image_paths) != len(gt_paths):
        raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")

    # create all prediction folders for all intermediate iterations
    for i in range(n_iterations):
        os.makedirs(os.path.join(prediction_dir, f"iteration{i:02}"), exist_ok=True)

    if use_masks:
        print("The iterative prompting will make use of logits masks from previous iterations.")

    for image_path, gt_path in tqdm(
        zip(image_paths, gt_paths), total=len(image_paths),
        desc="Run inference with iterative prompting for all images",
    ):
        image_name = os.path.basename(image_path)

        # We skip the images that already have been segmented
        prediction_paths = [os.path.join(prediction_dir, f"iteration{i:02}", image_name) for i in range(n_iterations)]
        if all(os.path.exists(prediction_path) for prediction_path in prediction_paths):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        image = imageio.imread(image_path)
        gt = imageio.imread(gt_path).astype("uint32")
        gt = relabel_sequential(gt)[0]

        if embedding_dir is None:
            embedding_path = None
        else:
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        _run_inference_with_iterative_prompting_for_image(
            predictor, image, gt, start_with_box_prompt=start_with_box_prompt,
            dilation=dilation, batch_size=batch_size, embedding_path=embedding_path,
            n_iterations=n_iterations, prediction_paths=prediction_paths, use_masks=use_masks
        )


#
# AMG FUNCTION
#


def run_amg(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    iou_thresh_values: Optional[List[float]] = None,
    stability_score_values: Optional[List[float]] = None,
    peft_kwargs: Optional[Dict] = None,
) -> str:
    """Run automatic mask generation (AMG) with a parameter grid search.

    The grid search on the validation images and the inference on the test images are delegated to
    `instance_segmentation.run_instance_segmentation_grid_search_and_inference`.

    Args:
        checkpoint: The model checkpoint path.
        model_type: The Segment Anything model type.
        experiment_folder: The root folder of the experiment. Embeddings are saved in `embeddings`,
            predictions in `amg/inference` and grid-search results in `amg/grid_search`.
        val_image_paths: The validation image file paths (used for the grid search).
        val_gt_paths: The validation ground-truth file paths (used for the grid search).
        test_image_paths: The test image file paths.
        iou_thresh_values: The IoU threshold values for the grid search,
            passed to `instance_segmentation.default_grid_search_values_amg`.
        stability_score_values: The stability score values for the grid search,
            passed to `instance_segmentation.default_grid_search_values_amg`.
        peft_kwargs: Keyword arguments for parameter-efficient finetuning, passed to `util.get_sam_model`.

    Returns:
        The folder where the predictions are saved.
    """
    embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
    os.makedirs(embedding_folder, exist_ok=True)

    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs)
    amg = AutomaticMaskGenerator(predictor)
    amg_prefix = "amg"

    # where the predictions are saved
    prediction_folder = os.path.join(experiment_folder, amg_prefix, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # where the grid-search results are saved
    gs_result_folder = os.path.join(experiment_folder, amg_prefix, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    grid_search_values = instance_segmentation.default_grid_search_values_amg(
        iou_thresh_values=iou_thresh_values,
        stability_score_values=stability_score_values,
    )

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        amg, grid_search_values,
        val_image_paths, val_gt_paths, test_image_paths,
        embedding_dir=embedding_folder, prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
    )
    return prediction_folder


#
# INSTANCE SEGMENTATION FUNCTION
#


def run_instance_segmentation_with_decoder(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    peft_kwargs: Optional[Dict] = None,
) -> str:
    """Run instance segmentation with the additional decoder and a parameter grid search.

    The grid search on the validation images and the inference on the test images are delegated to
    `instance_segmentation.run_instance_segmentation_grid_search_and_inference`.

    Args:
        checkpoint: The model checkpoint path (must provide the predictor and the decoder).
        model_type: The Segment Anything model type.
        experiment_folder: The root folder of the experiment. Embeddings are saved in `embeddings`,
            predictions in `instance_segmentation_with_decoder/inference` and grid-search results in
            `instance_segmentation_with_decoder/grid_search`.
        val_image_paths: The validation image file paths (used for the grid search).
        val_gt_paths: The validation ground-truth file paths (used for the grid search).
        test_image_paths: The test image file paths.
        peft_kwargs: Keyword arguments for parameter-efficient finetuning, passed to `get_predictor_and_decoder`.

    Returns:
        The folder where the predictions are saved.
    """
    embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
    os.makedirs(embedding_folder, exist_ok=True)

    predictor, decoder = get_predictor_and_decoder(
        model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs,
    )
    segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
    seg_prefix = "instance_segmentation_with_decoder"

    # where the predictions are saved
    prediction_folder = os.path.join(experiment_folder, seg_prefix, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # where the grid-search results are saved
    gs_result_folder = os.path.join(experiment_folder, seg_prefix, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    grid_search_values = instance_segmentation.default_grid_search_values_instance_segmentation_with_decoder()

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter, grid_search_values,
        val_image_paths, val_gt_paths, test_image_paths,
        embedding_dir=embedding_folder, prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
    )
    return prediction_folder
def
precompute_all_embeddings( predictor: segment_anything.predictor.SamPredictor, image_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike]) -> None:
def precompute_all_embeddings(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Union[str, os.PathLike],
) -> None:
    """Compute and store the embedding of every image up front.

    With all embeddings on disk, the different inference tasks can afterwards run in parallel.

    Args:
        predictor: The SegmentAnything predictor.
        image_paths: The image file paths.
        embedding_dir: The directory where the embeddings will be saved.
    """
    for path in tqdm(image_paths, desc="Precompute embeddings"):
        file_name = os.path.basename(path)
        image = imageio.imread(path)
        stem = os.path.splitext(file_name)[0]
        target = os.path.join(embedding_dir, f"{stem}.zarr")
        util.precompute_image_embeddings(predictor, image, target, ndim=2)
Precompute all image embeddings.
This enables running different inference tasks in parallel afterwards.
Arguments:
- predictor: The SegmentAnything predictor.
- image_paths: The image file paths.
- embedding_dir: The directory where the embeddings will be saved.
def
precompute_all_prompts( gt_paths: List[Union[str, os.PathLike]], prompt_save_dir: Union[str, os.PathLike], prompt_settings: List[Dict[str, Any]]) -> None:
def precompute_all_prompts(
    gt_paths: List[Union[str, os.PathLike]],
    prompt_save_dir: Union[str, os.PathLike],
    prompt_settings: List[Dict[str, Any]],
) -> None:
    """Precompute all point and box prompts.

    This enables running different inference tasks in parallel afterwards.
    Box-only settings are saved to `boxes.pkl`, all other settings to
    `points-p{n_positives}-n{n_negatives}.pkl` in `prompt_save_dir`.
    Settings whose prompt file already exists are skipped.

    Args:
        gt_paths: The file paths to the ground-truth segmentations.
        prompt_save_dir: The directory where the prompt files will be saved.
        prompt_settings: The settings for which the prompts will be computed. Each setting is a dict with
            the keys "use_points", "use_boxes", "n_positives", "n_negatives" and optionally "dilation" (default 5).
    """
    os.makedirs(prompt_save_dir, exist_ok=True)

    for settings in tqdm(prompt_settings, desc="Precompute prompts"):

        use_points, use_boxes = settings["use_points"], settings["use_boxes"]
        n_positives, n_negatives = settings["n_positives"], settings["n_negatives"]
        dilation = settings.get("dilation", 5)

        # check if the prompts were already computed
        if use_boxes and not use_points:
            prompt_save_path = os.path.join(prompt_save_dir, "boxes.pkl")
        else:
            prompt_save_path = os.path.join(prompt_save_dir, f"points-p{n_positives}-n{n_negatives}.pkl")
        if os.path.exists(prompt_save_path):
            continue

        # Maps the ground-truth file name to its prompts.
        saved_prompts = {}
        for gt_path in tqdm(gt_paths, desc=f"Precompute prompts for p{n_positives}-n{n_negatives}"):
            name, prompts = _precompute_prompts(
                gt_path,
                use_points=use_points,
                use_boxes=use_boxes,
                n_positives=n_positives,
                n_negatives=n_negatives,
                dilation=dilation,
            )
            saved_prompts[name] = prompts

        with open(prompt_save_path, "wb") as f:
            pickle.dump(saved_prompts, f)
Precompute all point and box prompts.
This enables running different inference tasks in parallel afterwards.
Arguments:
- gt_paths: The file paths to the ground-truth segmentations.
- prompt_save_dir: The directory where the prompt files will be saved.
- prompt_settings: The settings for which the prompts will be computed.
def
run_inference_with_prompts( predictor: segment_anything.predictor.SamPredictor, image_paths: List[Union[str, os.PathLike]], gt_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike], prediction_dir: Union[str, os.PathLike], use_points: bool, use_boxes: bool, n_positives: int, n_negatives: int, dilation: int = 5, prompt_save_dir: Union[str, os.PathLike, NoneType] = None, batch_size: int = 512) -> None:
def run_inference_with_prompts(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    use_points: bool,
    use_boxes: bool,
    n_positives: int,
    n_negatives: int,
    dilation: int = 5,
    prompt_save_dir: Optional[Union[str, os.PathLike]] = None,
    batch_size: int = 512,
) -> None:
    """Run Segment Anything inference for multiple images using prompts derived from the ground-truth.

    Args:
        predictor: The SegmentAnything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
            If None, no embedding path is passed to `batched_inference`.
        prediction_dir: The directory where the predicted instance segmentations will be saved.
            Images that already have a prediction in this directory are skipped.
        use_points: Whether to use point prompts.
        use_boxes: Whether to use box prompts.
        n_positives: The number of positive point prompts that will be sampled.
        n_negatives: The number of negative point prompts that will be sampled.
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        prompt_save_dir: The directory where the prompts will be saved or are already saved.
            This enables running multiple experiments in a reproducible manner.
        batch_size: The batch size used for batched prediction.

    Raises:
        ValueError: If neither points nor boxes are used, or the number of images and ground-truths differ.
    """
    if not (use_points or use_boxes):
        raise ValueError("You need to use at least one of point or box prompts.")

    if len(image_paths) != len(gt_paths):
        raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")

    (cached_point_prompts, save_point_prompts, point_prompt_save_path,
     cached_box_prompts, save_box_prompts, box_prompt_save_path) = _get_prompt_caching(
         prompt_save_dir, use_points, use_boxes, n_positives, n_negatives
     )

    os.makedirs(prediction_dir, exist_ok=True)
    for image_path, gt_path in tqdm(
        zip(image_paths, gt_paths), total=len(image_paths), desc="Run inference with prompts"
    ):
        image_name = os.path.basename(image_path)
        label_name = os.path.basename(gt_path)

        # We skip the images that already have been segmented.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        im = imageio.imread(image_path)
        gt = imageio.imread(gt_path).astype("uint32")
        gt = relabel_sequential(gt)[0]

        # Consistent with `run_inference_with_iterative_prompting`: allow running without an embedding directory.
        if embedding_dir is None:
            embedding_path = None
        else:
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        this_prompts, cached_point_prompts, cached_box_prompts = _load_prompts(
            cached_point_prompts, save_point_prompts,
            cached_box_prompts, save_box_prompts,
            label_name
        )
        instances, this_prompts = _run_inference_with_prompts_for_image(
            predictor, im, gt, n_positives=n_positives, n_negatives=n_negatives,
            dilation=dilation, use_points=use_points, use_boxes=use_boxes,
            batch_size=batch_size, cached_prompts=this_prompts,
            embedding_path=embedding_path,
        )

        if save_point_prompts:
            cached_point_prompts[label_name] = this_prompts[:2]
        if save_box_prompts:
            cached_box_prompts[label_name] = this_prompts[-1]

        # It's important to compress here, otherwise the predictions would take up a lot of space.
        imageio.imwrite(prediction_path, instances, compression=5)

    # Save the prompts if we run experiments with prompt caching and have computed them
    # for the first time.
    if save_point_prompts:
        with open(point_prompt_save_path, "wb") as f:
            pickle.dump(cached_point_prompts, f)
    if save_box_prompts:
        with open(box_prompt_save_path, "wb") as f:
            pickle.dump(cached_box_prompts, f)
Run Segment Anything inference for multiple images using prompts derived from the ground-truth.
Arguments:
- predictor: The SegmentAnything predictor.
- image_paths: The image file paths.
- gt_paths: The ground-truth segmentation file paths.
- embedding_dir: The directory where the image embeddings will be saved or are already saved.
- prediction_dir: The directory where the predicted instance segmentations will be saved.
- use_points: Whether to use point prompts.
- use_boxes: Whether to use box prompts
- n_positives: The number of positive point prompts that will be sampled.
- n_negatives: The number of negative point prompts that will be sampled.
- dilation: The dilation factor for the radius around the ground-truth object around which points will not be sampled.
- prompt_save_dir: The directory where point prompts will be saved or are already saved. This enables running multiple experiments in a reproducible manner.
- batch_size: The batch size used for batched prediction.
def
run_inference_with_iterative_prompting( predictor: segment_anything.predictor.SamPredictor, image_paths: List[Union[str, os.PathLike]], gt_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike], prediction_dir: Union[str, os.PathLike], start_with_box_prompt: bool = True, dilation: int = 5, batch_size: int = 32, n_iterations: int = 8, use_masks: bool = False) -> None:
def run_inference_with_iterative_prompting(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    start_with_box_prompt: bool = True,
    dilation: int = 5,
    batch_size: int = 32,
    n_iterations: int = 8,
    use_masks: bool = False
) -> None:
    """Run Segment Anything inference for multiple images using prompts iteratively
    derived from model outputs and ground-truth.

    Args:
        predictor: The Segment Anything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
            If None, no embedding path is passed on to the per-image inference.
        prediction_dir: The directory where the predictions from Segment Anything will be saved per iteration,
            in the subfolders `iteration00`, `iteration01`, ...
        start_with_box_prompt: Whether to use the first prompt as bounding box or a single point.
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        batch_size: The batch size used for batched predictions.
        n_iterations: The number of iterations for iterative prompting.
        use_masks: Whether to make use of logits from previous prompt-based segmentation.

    Raises:
        ValueError: If the number of images and ground-truths differ.
    """
    if len(image_paths) != len(gt_paths):
        raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")

    # create all prediction folders for all intermediate iterations
    for i in range(n_iterations):
        os.makedirs(os.path.join(prediction_dir, f"iteration{i:02}"), exist_ok=True)

    if use_masks:
        print("The iterative prompting will make use of logits masks from previous iterations.")

    for image_path, gt_path in tqdm(
        zip(image_paths, gt_paths), total=len(image_paths),
        desc="Run inference with iterative prompting for all images",
    ):
        image_name = os.path.basename(image_path)

        # We skip the images that already have been segmented (for all iterations).
        prediction_paths = [os.path.join(prediction_dir, f"iteration{i:02}", image_name) for i in range(n_iterations)]
        if all(os.path.exists(prediction_path) for prediction_path in prediction_paths):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        image = imageio.imread(image_path)
        gt = imageio.imread(gt_path).astype("uint32")
        gt = relabel_sequential(gt)[0]

        if embedding_dir is None:
            embedding_path = None
        else:
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        _run_inference_with_iterative_prompting_for_image(
            predictor, image, gt, start_with_box_prompt=start_with_box_prompt,
            dilation=dilation, batch_size=batch_size, embedding_path=embedding_path,
            n_iterations=n_iterations, prediction_paths=prediction_paths, use_masks=use_masks
        )
Run Segment Anything inference for multiple images using prompts iteratively derived from model outputs and ground-truth.
Arguments:
- predictor: The Segment Anything predictor.
- image_paths: The image file paths.
- gt_paths: The ground-truth segmentation file paths.
- embedding_dir: The directory where the image embeddings will be saved or are already saved.
- prediction_dir: The directory where the predictions from Segment Anything will be saved per iteration.
- start_with_box_prompt: Whether the first prompt is a bounding box (True) or a single point (False).
- dilation: The dilation factor that sets the radius around the ground-truth object within which points will not be sampled.
- batch_size: The batch size used for batched predictions.
- n_iterations: The number of iterations for iterative prompting.
- use_masks: Whether to make use of logits from previous prompt-based segmentation.
def run_amg(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    iou_thresh_values: Optional[List[float]] = None,
    stability_score_values: Optional[List[float]] = None,
    peft_kwargs: Optional[Dict] = None,
) -> str:
    """Run automatic mask generation (AMG) with a grid search and subsequent inference.

    Outputs are organized below `experiment_folder`:
    `embeddings/` for the precomputed embeddings, `amg/inference/` for the predictions and
    `amg/grid_search/` for the grid-search results.

    Args:
        checkpoint: The path to the model checkpoint.
        model_type: The Segment Anything model type.
        experiment_folder: The root folder for all outputs of this experiment.
        val_image_paths: The image file paths used for the grid search.
        val_gt_paths: The ground-truth file paths used for the grid search.
        test_image_paths: The image file paths used for inference.
        iou_thresh_values: Optional IoU threshold values for the grid search,
            forwarded to `instance_segmentation.default_grid_search_values_amg`.
        stability_score_values: Optional stability score values for the grid search,
            forwarded to `instance_segmentation.default_grid_search_values_amg`.
        peft_kwargs: Optional keyword arguments forwarded to `util.get_sam_model`.

    Returns:
        The folder where the predictions are saved.
    """
    # Where the precomputed embeddings are saved.
    embedding_folder = os.path.join(experiment_folder, "embeddings")
    os.makedirs(embedding_folder, exist_ok=True)

    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs)
    amg = AutomaticMaskGenerator(predictor)
    amg_prefix = "amg"

    # Where the predictions are saved.
    prediction_folder = os.path.join(experiment_folder, amg_prefix, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # Where the grid-search results are saved.
    gs_result_folder = os.path.join(experiment_folder, amg_prefix, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    grid_search_values = instance_segmentation.default_grid_search_values_amg(
        iou_thresh_values=iou_thresh_values,
        stability_score_values=stability_score_values,
    )

    # Pass the output folders by keyword, matching run_instance_segmentation_with_decoder,
    # so the call does not silently depend on the callee's parameter order.
    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        amg, grid_search_values,
        val_image_paths, val_gt_paths, test_image_paths,
        embedding_dir=embedding_folder, prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
    )
    return prediction_folder
def run_instance_segmentation_with_decoder(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    peft_kwargs: Optional[Dict] = None,
) -> str:
    """Run instance segmentation with the decoder, including a grid search and subsequent inference.

    Outputs are organized below `experiment_folder`:
    `embeddings/` for the precomputed embeddings,
    `instance_segmentation_with_decoder/inference/` for the predictions and
    `instance_segmentation_with_decoder/grid_search/` for the grid-search results.

    Args:
        checkpoint: The path to the model checkpoint.
        model_type: The Segment Anything model type.
        experiment_folder: The root folder for all outputs of this experiment.
        val_image_paths: The image file paths used for the grid search.
        val_gt_paths: The ground-truth file paths used for the grid search.
        test_image_paths: The image file paths used for inference.
        peft_kwargs: Optional keyword arguments forwarded to `get_predictor_and_decoder`.

    Returns:
        The folder where the predictions are saved.
    """
    # Precomputed image embeddings live in a folder shared across methods.
    embeddings_dir = os.path.join(experiment_folder, "embeddings")
    os.makedirs(embeddings_dir, exist_ok=True)

    sam_predictor, sam_decoder = get_predictor_and_decoder(
        model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs,
    )
    instance_segmenter = InstanceSegmentationWithDecoder(sam_predictor, sam_decoder)

    # Method-specific outputs: predictions first, then grid-search results.
    method_root = os.path.join(experiment_folder, "instance_segmentation_with_decoder")
    inference_dir = os.path.join(method_root, "inference")
    grid_search_dir = os.path.join(method_root, "grid_search")
    for output_dir in (inference_dir, grid_search_dir):
        os.makedirs(output_dir, exist_ok=True)

    search_space = instance_segmentation.default_grid_search_values_instance_segmentation_with_decoder()

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        instance_segmenter,
        search_space,
        val_image_paths,
        val_gt_paths,
        test_image_paths,
        embedding_dir=embeddings_dir,
        prediction_dir=inference_dir,
        result_dir=grid_search_dir,
    )
    return inference_dir