micro_sam.evaluation.inference
Inference with Segment Anything models and different prompt strategies.
"""Inference with Segment Anything models and different prompt strategies.
"""

import os
import pickle
import numpy as np
from tqdm import tqdm
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union, Tuple

import imageio.v3 as imageio
from skimage.segmentation import relabel_sequential

import torch

from segment_anything import SamPredictor

from .. import util as util
from ..inference import batched_inference
from ..instance_segmentation import (
    get_predictor_and_decoder,
    AutomaticMaskGenerator, InstanceSegmentationWithDecoder,
    TiledAutomaticMaskGenerator, TiledInstanceSegmentationWithDecoder,
    AutomaticPromptGenerator,
)
from . import instance_segmentation
from ..prompt_generators import PointAndBoxPromptGenerator, IterativePromptGenerator


def _get_embedding_path(embedding_dir, image_name):
    """Return the embedding path for an image, or None if embeddings are not cached on disc."""
    if embedding_dir is None:
        return None
    return os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")


def _load_prompts(
    cached_point_prompts, save_point_prompts, cached_box_prompts, save_box_prompts, image_name
):
    """Look up the cached prompts for one image, loading the cache files lazily.

    Args:
        cached_point_prompts: The point prompt cache; either None, the path to a pickle file
            that has not been loaded yet, or the already loaded dictionary.
        save_point_prompts: Whether the point prompts are computed (and saved) in this run,
            in which case nothing is read from the cache.
        cached_box_prompts: The box prompt cache, same conventions as `cached_point_prompts`.
        save_box_prompts: Whether the box prompts are computed (and saved) in this run.
        image_name: The key under which the prompts of the image are stored in the caches.

    Returns:
        The prompts as tuple `(points, point_labels, boxes)`, or None if nothing is cached.
        The point prompt cache (loaded from file if that was necessary).
        The box prompt cache (loaded from file if that was necessary).
    """

    def load_prompt_type(cached_prompts, save_prompts):
        # Check if we have saved prompts.
        if cached_prompts is None or save_prompts:  # we don't have cached prompts
            return cached_prompts, None

        # We have cached prompts, but they have not been loaded yet.
        # NOTE: pickle can execute arbitrary code on load, so the prompt files must be trusted.
        if isinstance(cached_prompts, str):
            with open(cached_prompts, "rb") as f:
                cached_prompts = pickle.load(f)

        prompts = cached_prompts[image_name]
        return cached_prompts, prompts

    cached_point_prompts, point_prompts = load_prompt_type(cached_point_prompts, save_point_prompts)
    cached_box_prompts, box_prompts = load_prompt_type(cached_box_prompts, save_box_prompts)

    # we don't have anything cached
    if point_prompts is None and box_prompts is None:
        return None, cached_point_prompts, cached_box_prompts

    if point_prompts is None:
        input_point, input_label = [], []
    else:
        input_point, input_label = point_prompts

    input_box = [] if box_prompts is None else box_prompts

    prompts = (input_point, input_label, input_box)
    return prompts, cached_point_prompts, cached_box_prompts


def _get_batched_prompts(gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation):
    """Derive point and / or box prompts from the ground-truth for the given object ids.

    Args:
        gt: The ground-truth segmentation.
        gt_ids: The ids of the objects in `gt` for which prompts are generated.
        use_points: Whether to generate point prompts.
        use_boxes: Whether to generate box prompts.
        n_positives: The number of positive points per object.
        n_negatives: The number of negative points per object.
        dilation: The dilation strength passed to the prompt generator.

    Returns:
        The points, point labels and boxes, each converted via `.numpy()`
        or None if the prompt generator returned None for it.
    """
    # Initialize the prompt generator.
    prompt_generator = PointAndBoxPromptGenerator(
        n_positive_points=n_positives, n_negative_points=n_negatives,
        dilation_strength=dilation, get_point_prompts=use_points,
        get_box_prompts=use_boxes
    )

    # Generate the prompts.
    center_coordinates, bbox_coordinates = util.get_centers_and_bounding_boxes(gt)
    center_coordinates = [center_coordinates[gt_id] for gt_id in gt_ids]
    bbox_coordinates = [bbox_coordinates[gt_id] for gt_id in gt_ids]
    masks = util.segmentation_to_one_hot(gt.astype("int64"), gt_ids)

    points, point_labels, boxes, _ = prompt_generator(
        masks, bbox_coordinates, center_coordinates
    )

    def to_numpy(x):
        return x if x is None else x.numpy()

    return to_numpy(points), to_numpy(point_labels), to_numpy(boxes)


def _run_inference_with_prompts_for_image(
    predictor,
    image,
    gt,
    use_points,
    use_boxes,
    n_positives,
    n_negatives,
    dilation,
    batch_size,
    cached_prompts,
    embedding_path,
):
    """Segment a single image with prompts that are cached or derived from the ground-truth.

    Args:
        predictor: The Segment Anything predictor.
        image: The image data.
        gt: The ground-truth segmentation for the image.
        use_points: Whether to use point prompts.
        use_boxes: Whether to use box prompts.
        n_positives: The number of positive points per object.
        n_negatives: The number of negative points per object.
        dilation: The dilation strength used for sampling the prompts.
        batch_size: The batch size for batched inference.
        cached_prompts: The precomputed prompts `(points, point_labels, boxes)`,
            or None to derive them from `gt`.
        embedding_path: The path for the image embeddings, may be None.

    Returns:
        The instance segmentation returned by `batched_inference`.
        A copy of the prompts `(points, point_labels, boxes)` that were used.
    """
    # The first unique value is dropped.
    # NOTE(review): this presumably removes the background id 0; confirm that it is always present.
    gt_ids = np.unique(gt)[1:]
    if cached_prompts is None:
        points, point_labels, boxes = _get_batched_prompts(
            gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation
        )
    else:
        points, point_labels, boxes = cached_prompts

    # Make a copy of the prompts to return them at the end.
    prompts = deepcopy((points, point_labels, boxes))

    # Use multi-masking only if we have a single positive point without box.
    multimasking = (not use_boxes) and n_positives == 1 and n_negatives == 0

    instance_labels = batched_inference(
        predictor, image, batch_size,
        boxes=boxes, points=points, point_labels=point_labels,
        multimasking=multimasking, embedding_path=embedding_path,
        return_instance_segmentation=True,
    )

    return instance_labels, prompts


def precompute_all_embeddings(
    predictor: SamPredictor, image_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike],
) -> None:
    """Precompute all image embeddings.

    To enable running different inference tasks in parallel afterwards.

    Args:
        predictor: The Segment Anything predictor.
        image_paths: The image file paths.
        embedding_dir: The directory where the embeddings will be saved.
    """
    for image_path in tqdm(image_paths, desc="Precompute embeddings"):
        image_name = os.path.basename(image_path)
        im = imageio.imread(image_path)
        embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
        util.precompute_image_embeddings(predictor, im, embedding_path, ndim=2)


def _precompute_prompts(gt_path, use_points, use_boxes, n_positives, n_negatives, dilation):
    """Compute the prompts for a single ground-truth file.

    Returns:
        The file name of the ground-truth.
        The boxes if only box prompts are used, otherwise the tuple `(points, point_labels)`.
    """
    name = os.path.basename(gt_path)

    gt = imageio.imread(gt_path).astype("uint32")
    gt = relabel_sequential(gt)[0]
    gt_ids = np.unique(gt)[1:]

    input_point, input_label, input_box = _get_batched_prompts(
        gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation
    )

    if use_boxes and not use_points:
        return name, input_box
    return name, (input_point, input_label)


def precompute_all_prompts(
    gt_paths: List[Union[str, os.PathLike]],
    prompt_save_dir: Union[str, os.PathLike],
    prompt_settings: List[Dict[str, Any]],
) -> None:
    """Precompute all point prompts.

    To enable running different inference tasks in parallel afterwards.

    Args:
        gt_paths: The file paths to the ground-truth segmentations.
        prompt_save_dir: The directory where the prompt files will be saved.
        prompt_settings: The settings for which the prompts will be computed.
            Each setting needs the keys 'use_points', 'use_boxes', 'n_positives' and 'n_negatives'
            and may contain 'dilation' (default: 5).
    """
    os.makedirs(prompt_save_dir, exist_ok=True)

    for settings in tqdm(prompt_settings, desc="Precompute prompts"):

        use_points, use_boxes = settings["use_points"], settings["use_boxes"]
        n_positives, n_negatives = settings["n_positives"], settings["n_negatives"]
        dilation = settings.get("dilation", 5)

        # check if the prompts were already computed
        if use_boxes and not use_points:
            prompt_save_path = os.path.join(prompt_save_dir, "boxes.pkl")
        else:
            prompt_save_path = os.path.join(prompt_save_dir, f"points-p{n_positives}-n{n_negatives}.pkl")
        if os.path.exists(prompt_save_path):
            continue

        saved_prompts = {}
        for gt_path in tqdm(gt_paths, desc=f"Precompute prompts for p{n_positives}-n{n_negatives}"):
            name, prompts = _precompute_prompts(
                gt_path,
                use_points=use_points,
                use_boxes=use_boxes,
                n_positives=n_positives,
                n_negatives=n_negatives,
                dilation=dilation,
            )
            saved_prompts[name] = prompts

        with open(prompt_save_path, "wb") as f:
            pickle.dump(saved_prompts, f)


def _get_prompt_caching(prompt_save_dir, use_points, use_boxes, n_positives, n_negatives):
    """Determine for point and box prompts whether they are loaded from or saved to a cache file.

    Returns:
        For the point prompts and then for the box prompts three values each:
        the cache (None if unused, the file path if the file exists, an empty dict otherwise),
        whether the prompts have to be saved, and the path to the cache file (None if unused).
    """

    def get_prompt_type_caching(use_type, save_name):
        if not use_type:
            return None, False, None

        prompt_save_path = os.path.join(prompt_save_dir, save_name)
        if os.path.exists(prompt_save_path):
            print("Using precomputed prompts from", prompt_save_path)
            # We delay loading the prompts, so we only have to load them once they're needed the first time.
            # This avoids loading the prompts (which are in a big pickle file) if all predictions are done already.
            cached_prompts = prompt_save_path
            save_prompts = False
        else:
            print("Saving prompts in", prompt_save_path)
            cached_prompts = {}
            save_prompts = True
        return cached_prompts, save_prompts, prompt_save_path

    # Check if prompt serialization is enabled.
    # If it is then load the prompts if they are already cached and otherwise store them.
    if prompt_save_dir is None:
        print("Prompts are not cached.")
        cached_point_prompts, cached_box_prompts = None, None
        save_point_prompts, save_box_prompts = False, False
        point_prompt_save_path, box_prompt_save_path = None, None
    else:
        cached_point_prompts, save_point_prompts, point_prompt_save_path = get_prompt_type_caching(
            use_points, f"points-p{n_positives}-n{n_negatives}.pkl"
        )
        cached_box_prompts, save_box_prompts, box_prompt_save_path = get_prompt_type_caching(
            use_boxes, "boxes.pkl"
        )

    return (cached_point_prompts, save_point_prompts, point_prompt_save_path,
            cached_box_prompts, save_box_prompts, box_prompt_save_path)


def run_inference_with_prompts(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    use_points: bool,
    use_boxes: bool,
    n_positives: int,
    n_negatives: int,
    dilation: int = 5,
    prompt_save_dir: Optional[Union[str, os.PathLike]] = None,
    batch_size: int = 512,
) -> None:
    """Run segment anything inference for multiple images using prompts derived from groundtruth.

    Images for which a prediction already exists in `prediction_dir` are skipped.

    Args:
        predictor: The Segment Anything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
            May be None, in which case no embedding path is passed on to the inference.
        prediction_dir: The directory where the predictions will be saved.
        use_points: Whether to use point prompts.
        use_boxes: Whether to use box prompts.
        n_positives: The number of positive point prompts that will be sampled.
        n_negatives: The number of negative point prompts that will be sampled.
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        prompt_save_dir: The directory where point prompts will be saved or are already saved.
            This enables running multiple experiments in a reproducible manner.
        batch_size: The batch size used for batched prediction.

    Raises:
        ValueError: If neither point nor box prompts are used or if the number
            of images and ground-truth files differs.
    """
    if not (use_points or use_boxes):
        raise ValueError("You need to use at least one of point or box prompts.")

    if len(image_paths) != len(gt_paths):
        raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")

    (cached_point_prompts, save_point_prompts, point_prompt_save_path,
     cached_box_prompts, save_box_prompts, box_prompt_save_path) = _get_prompt_caching(
         prompt_save_dir, use_points, use_boxes, n_positives, n_negatives
    )

    os.makedirs(prediction_dir, exist_ok=True)
    for image_path, gt_path in tqdm(
        zip(image_paths, gt_paths), total=len(image_paths), desc="Run inference with prompts"
    ):
        image_name = os.path.basename(image_path)
        label_name = os.path.basename(gt_path)

        # We skip the images that already have been segmented.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        im = imageio.imread(image_path)
        gt = imageio.imread(gt_path).astype("uint32")
        gt = relabel_sequential(gt)[0]

        # Same handling of a missing embedding directory as in the iterative prompting inference.
        embedding_path = _get_embedding_path(embedding_dir, image_name)
        this_prompts, cached_point_prompts, cached_box_prompts = _load_prompts(
            cached_point_prompts, save_point_prompts,
            cached_box_prompts, save_box_prompts,
            label_name
        )
        instances, this_prompts = _run_inference_with_prompts_for_image(
            predictor, im, gt, n_positives=n_positives, n_negatives=n_negatives,
            dilation=dilation, use_points=use_points, use_boxes=use_boxes,
            batch_size=batch_size, cached_prompts=this_prompts,
            embedding_path=embedding_path,
        )

        # The prompts are stored as (points, point_labels, boxes).
        if save_point_prompts:
            cached_point_prompts[label_name] = this_prompts[:2]
        if save_box_prompts:
            cached_box_prompts[label_name] = this_prompts[-1]

        # It's important to compress here, otherwise the predictions would take up a lot of space.
        imageio.imwrite(prediction_path, instances, compression=5)

    # Save the prompts if we run experiments with prompt caching and have computed them
    # for the first time.
    if save_point_prompts:
        with open(point_prompt_save_path, "wb") as f:
            pickle.dump(cached_point_prompts, f)
    if save_box_prompts:
        with open(box_prompt_save_path, "wb") as f:
            pickle.dump(cached_box_prompts, f)


def _save_segmentation(masks, prediction_path):
    """Convert the predicted masks to a segmentation and write it to `prediction_path`."""
    # masks to segmentation
    masks = masks.cpu().numpy().squeeze(1).astype("bool")
    masks = [{"segmentation": mask, "area": mask.sum()} for mask in masks]
    segmentation = util.mask_data_to_segmentation(masks)
    imageio.imwrite(prediction_path, segmentation, compression=5)


def _get_batched_iterative_prompts(sampled_binary_gt, masks, batch_size, prompt_generator):
    """Run the iterative prompt generator batch-wise and concatenate the new points and labels."""
    n_samples = sampled_binary_gt.shape[0]
    n_batches = int(np.ceil(float(n_samples) / batch_size))
    next_coords, next_labels = [], []
    for batch_idx in range(n_batches):
        batch_start = batch_idx * batch_size
        batch_stop = min((batch_idx + 1) * batch_size, n_samples)

        batch_coords, batch_labels, _, _ = prompt_generator(
            sampled_binary_gt[batch_start: batch_stop], masks[batch_start: batch_stop]
        )
        next_coords.append(batch_coords)
        next_labels.append(batch_labels)

    next_coords = torch.concatenate(next_coords)
    next_labels = torch.concatenate(next_labels)

    return next_coords, next_labels


@torch.no_grad()
def _run_inference_with_iterative_prompting_for_image(
    predictor,
    image,
    gt,
    start_with_box_prompt,
    dilation,
    batch_size,
    embedding_path,
    n_iterations,
    prediction_paths,
    use_masks=False
) -> None:
    """Run iterative prompting for a single image and save the segmentation of each iteration.

    Args:
        predictor: The Segment Anything predictor.
        image: The image data.
        gt: The ground-truth segmentation for the image.
        start_with_box_prompt: Whether the first prompt is a box (True) or a single positive point (False).
        dilation: The dilation strength used for sampling the initial prompts.
        batch_size: The batch size for inference and prompt generation.
        embedding_path: The path for the image embeddings, may be None.
        n_iterations: The number of prompting iterations.
        prediction_paths: The file paths for the predictions; it is indexed by the iteration.
        use_masks: Whether the logits of the previous iteration are passed to the next one.
    """
    prompt_generator = IterativePromptGenerator()

    gt_ids = np.unique(gt)[1:]

    # Use multi-masking only if we have a single positive point without box
    if start_with_box_prompt:
        use_boxes, use_points = True, False
        n_positives = 0
        multimasking = False
    else:
        use_boxes, use_points = False, True
        n_positives = 1
        multimasking = True

    points, point_labels, boxes = _get_batched_prompts(
        gt, gt_ids,
        use_points=use_points,
        use_boxes=use_boxes,
        n_positives=n_positives,
        n_negatives=0,
        dilation=dilation
    )

    sampled_binary_gt = util.segmentation_to_one_hot(gt.astype("int64"), gt_ids)

    for iteration in range(n_iterations):
        # The logits mask can not be used for the first iteration
        # and should not be used when it is not desired.
        if iteration == 0 or not use_masks:
            logits_masks = None

        batched_outputs = batched_inference(
            predictor=predictor,
            image=image,
            batch_size=batch_size,
            boxes=boxes,
            points=points,
            point_labels=point_labels,
            multimasking=multimasking,
            embedding_path=embedding_path,
            return_instance_segmentation=False,
            logits_masks=logits_masks,
            verbose_embeddings=False,
        )

        # switching off multimasking after first iter, as next iters (with multiple prompts) don't expect multimasking
        multimasking = False

        masks = torch.stack([m["segmentation"][None] for m in batched_outputs]).to(torch.float32)

        next_coords, next_labels = _get_batched_iterative_prompts(
            sampled_binary_gt, masks, batch_size, prompt_generator
        )
        next_coords, next_labels = next_coords.detach().cpu().numpy(), next_labels.detach().cpu().numpy()

        # Append the new prompts to the previous ones (there are none yet after starting from boxes only).
        if points is not None:
            points = np.concatenate([points, next_coords], axis=1)
        else:
            points = next_coords

        if point_labels is not None:
            point_labels = np.concatenate([point_labels, next_labels], axis=1)
        else:
            point_labels = next_labels

        if use_masks:
            logits_masks = torch.stack([m["logits"] for m in batched_outputs])

        _save_segmentation(masks, prediction_paths[iteration])


def run_inference_with_iterative_prompting(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    start_with_box_prompt: bool = True,
    dilation: int = 5,
    batch_size: int = 32,
    n_iterations: int = 8,
    use_masks: bool = False
) -> None:
    """Run Segment Anything inference for multiple images using prompts iteratively
    derived from model outputs and ground-truth.

    Images for which the predictions of all iterations already exist are skipped.

    Args:
        predictor: The Segment Anything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
            May be None, in which case no embedding path is passed on to the inference.
        prediction_dir: The directory where the predictions from Segment Anything will be saved per iteration.
        start_with_box_prompt: Whether to use the first prompt as bounding box or a single point
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        batch_size: The batch size used for batched predictions.
        n_iterations: The number of iterations for iterative prompting.
        use_masks: Whether to make use of logits from previous prompt-based segmentation.

    Raises:
        ValueError: If the number of images and ground-truth files differs.
    """
    if len(image_paths) != len(gt_paths):
        raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")

    # create all prediction folders for all intermediate iterations
    for i in range(n_iterations):
        os.makedirs(os.path.join(prediction_dir, f"iteration{i:02}"), exist_ok=True)

    if use_masks:
        print("The iterative prompting will make use of logits masks from previous iterations.")

    for image_path, gt_path in tqdm(
        zip(image_paths, gt_paths), total=len(image_paths),
        desc="Run inference with iterative prompting for all images",
    ):
        image_name = os.path.basename(image_path)

        # We skip the images that already have been segmented
        prediction_paths = [os.path.join(prediction_dir, f"iteration{i:02}", image_name) for i in range(n_iterations)]
        if all(os.path.exists(prediction_path) for prediction_path in prediction_paths):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        image = imageio.imread(image_path)
        gt = imageio.imread(gt_path).astype("uint32")
        gt = relabel_sequential(gt)[0]

        embedding_path = _get_embedding_path(embedding_dir, image_name)

        _run_inference_with_iterative_prompting_for_image(
            predictor, image, gt, start_with_box_prompt=start_with_box_prompt,
            dilation=dilation, batch_size=batch_size, embedding_path=embedding_path,
            n_iterations=n_iterations, prediction_paths=prediction_paths, use_masks=use_masks
        )


#
# SHARED HELPERS FOR THE AUTOMATIC SEGMENTATION FUNCTIONS
#


def _get_embedding_folder(experiment_folder, cache_embeddings):
    """Create and return the embedding folder of the experiment, or return None if caching is disabled."""
    if not cache_embeddings:
        return None
    embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
    os.makedirs(embedding_folder, exist_ok=True)
    return embedding_folder


def _validate_tiling_window_params(tiling_window_params):
    """Check that the tiling parameters are a dictionary that contains 'tile_shape' and 'halo'."""
    if not isinstance(tiling_window_params, dict):
        raise RuntimeError("The tiling window parameters are expected to be provided as a dictionary of params.")

    if "tile_shape" not in tiling_window_params:
        raise RuntimeError("'tile_shape' parameter is missing from the provided parameters.")

    if "halo" not in tiling_window_params:
        raise RuntimeError("'halo' parameter is missing from the provided parameters.")


def _create_result_folders(experiment_folder, prefix):
    """Create and return the folders for the predictions and for the grid-search results."""
    # where the predictions are saved
    prediction_folder = os.path.join(experiment_folder, prefix, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # where the grid-search results are saved
    gs_result_folder = os.path.join(experiment_folder, prefix, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    return prediction_folder, gs_result_folder


#
# AMG FUNCTION
#


def run_amg(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    iou_thresh_values: Optional[List[float]] = None,
    stability_score_values: Optional[List[float]] = None,
    peft_kwargs: Optional[Dict] = None,
    cache_embeddings: bool = False,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> str:
    """Run Segment Anything inference for multiple images using automatic mask generation (AMG).

    Args:
        checkpoint: The filepath to model checkpoints.
        model_type: The segment anything model choice.
        experiment_folder: The directory where the relevant files are saved.
        val_image_paths: The list of filepaths of input images for grid-search.
        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
        iou_thresh_values: Optional choice of values for grid search of `iou_thresh` parameter.
        stability_score_values: Optional choice of values for grid search of `stability_score` parameter.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        cache_embeddings: Whether to cache embeddings in experiment folder.
        tiling_window_params: The parameters to decide whether to use tiling window operation for AMG.
            If given, it must be a dictionary with the keys 'tile_shape' and 'halo'.

    Returns:
        Filepath where the predictions have been saved.

    Raises:
        RuntimeError: If `tiling_window_params` is given but is not a dictionary
            or misses the 'tile_shape' or 'halo' key.
    """
    embedding_folder = _get_embedding_folder(experiment_folder, cache_embeddings)

    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs)

    # Get the AMG class.
    if tiling_window_params:
        _validate_tiling_window_params(tiling_window_params)
        amg_class = TiledAutomaticMaskGenerator
    else:
        amg_class = AutomaticMaskGenerator

    amg = amg_class(predictor)
    prediction_folder, gs_result_folder = _create_result_folders(experiment_folder, "amg")

    grid_search_values = instance_segmentation.default_grid_search_values_amg(
        iou_thresh_values=iou_thresh_values,
        stability_score_values=stability_score_values,
    )

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter=amg,
        grid_search_values=grid_search_values,
        val_image_paths=val_image_paths,
        val_gt_paths=val_gt_paths,
        test_image_paths=test_image_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
        experiment_folder=experiment_folder,
        tiling_window_params=tiling_window_params,
    )
    return prediction_folder


def run_apg(
    checkpoint: Optional[Union[str, os.PathLike]],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    peft_kwargs: Optional[Dict] = None,
    cache_embeddings: bool = False,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> str:
    """Run Segment Anything inference for multiple images using automatic prompt generation (APG).

    Args:
        checkpoint: The filepath to model checkpoints.
        model_type: The segment anything model choice.
        experiment_folder: The directory where the relevant files are saved.
        val_image_paths: The list of filepaths of input images for grid-search.
        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        cache_embeddings: Whether to cache embeddings in experiment folder.
        tiling_window_params: The parameters for tiling window operation.
            Tiling is not implemented for APG, so this must be None or empty.

    Returns:
        Filepath where the predictions have been saved.

    Raises:
        NotImplementedError: If `tiling_window_params` is given.
    """
    # Fail before creating folders and loading the model, tiling is not available for APG.
    if tiling_window_params:
        raise NotImplementedError("Tiling window operation is not implemented for APG.")

    embedding_folder = _get_embedding_folder(experiment_folder, cache_embeddings)

    predictor, decoder = get_predictor_and_decoder(
        model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs,
    )

    segmenter = AutomaticPromptGenerator(predictor, decoder)
    prediction_folder, gs_result_folder = _create_result_folders(experiment_folder, "apg")

    grid_search_values = instance_segmentation.default_grid_search_values_apg()

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter=segmenter,
        grid_search_values=grid_search_values,
        val_image_paths=val_image_paths,
        val_gt_paths=val_gt_paths,
        test_image_paths=test_image_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
        experiment_folder=experiment_folder,
        tiling_window_params=tiling_window_params,
    )
    return prediction_folder


#
# INSTANCE SEGMENTATION FUNCTION
#


def run_instance_segmentation_with_decoder(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    peft_kwargs: Optional[Dict] = None,
    cache_embeddings: bool = False,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> str:
    """Run Segment Anything inference for multiple images using additional automatic instance segmentation (AIS).

    Args:
        checkpoint: The filepath to model checkpoints.
        model_type: The segment anything model choice.
        experiment_folder: The directory where the relevant files are saved.
        val_image_paths: The list of filepaths of input images for grid-search.
        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        cache_embeddings: Whether to cache embeddings in experiment folder.
        tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
            If given, it must be a dictionary with the keys 'tile_shape' and 'halo'.

    Returns:
        Filepath where the predictions have been saved.

    Raises:
        RuntimeError: If `tiling_window_params` is given but is not a dictionary
            or misses the 'tile_shape' or 'halo' key.
    """
    embedding_folder = _get_embedding_folder(experiment_folder, cache_embeddings)

    predictor, decoder = get_predictor_and_decoder(
        model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs,
    )

    # Get the AIS class.
    if tiling_window_params:
        _validate_tiling_window_params(tiling_window_params)
        ais_class = TiledInstanceSegmentationWithDecoder
    else:
        ais_class = InstanceSegmentationWithDecoder

    segmenter = ais_class(predictor, decoder)
    prediction_folder, gs_result_folder = _create_result_folders(
        experiment_folder, "instance_segmentation_with_decoder"
    )

    grid_search_values = instance_segmentation.default_grid_search_values_instance_segmentation_with_decoder()

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter=segmenter,
        grid_search_values=grid_search_values,
        val_image_paths=val_image_paths,
        val_gt_paths=val_gt_paths,
        test_image_paths=test_image_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
        experiment_folder=experiment_folder,
        tiling_window_params=tiling_window_params,
    )
    return prediction_folder
def precompute_all_embeddings(
    predictor: SamPredictor, image_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike],
) -> None:
    """Compute the image embeddings for all given images ahead of time.

    With the embeddings already available, different inference tasks
    can be run in parallel afterwards.

    Args:
        predictor: The Segment Anything predictor.
        image_paths: The file paths of the images.
        embedding_dir: The directory the embeddings are saved to. The embeddings of an image
            are stored under the image's file name with the extension replaced by '.zarr'.
    """
    progress = tqdm(image_paths, desc="Precompute embeddings")
    for path in progress:
        file_name = os.path.basename(path)
        image = imageio.imread(path)

        # Drop the image extension and use the zarr extension for the embeddings instead.
        stem, _ = os.path.splitext(file_name)
        save_path = os.path.join(embedding_dir, f"{stem}.zarr")

        util.precompute_image_embeddings(predictor, image, save_path, ndim=2)
Precompute all image embeddings.
To enable running different inference tasks in parallel afterwards.
Arguments:
- predictor: The Segment Anything predictor.
- image_paths: The image file paths.
- embedding_dir: The directory where the embeddings will be saved.
def precompute_all_prompts(
    gt_paths: List[Union[str, os.PathLike]],
    prompt_save_dir: Union[str, os.PathLike],
    prompt_settings: List[Dict[str, Any]],
) -> None:
    """Compute the prompts for all ground-truth files and all prompt settings ahead of time.

    With the prompts already available, different inference tasks
    can be run in parallel afterwards.

    Args:
        gt_paths: The file paths to the ground-truth segmentations.
        prompt_save_dir: The directory where the prompt files will be saved.
        prompt_settings: The settings for which the prompts will be computed.
            Each setting is read via the keys 'use_points', 'use_boxes', 'n_positives'
            and 'n_negatives'; the optional key 'dilation' defaults to 5.
    """
    os.makedirs(prompt_save_dir, exist_ok=True)

    for setting in tqdm(prompt_settings, desc="Precompute prompts"):
        use_points = setting["use_points"]
        use_boxes = setting["use_boxes"]
        n_positives = setting["n_positives"]
        n_negatives = setting["n_negatives"]
        dilation = setting.get("dilation", 5)

        # Box-only prompts go into a single file, everything else into a file per point setting.
        boxes_only = use_boxes and not use_points
        file_name = "boxes.pkl" if boxes_only else f"points-p{n_positives}-n{n_negatives}.pkl"
        save_path = os.path.join(prompt_save_dir, file_name)

        # Nothing to do if the prompts for this setting were already computed.
        if os.path.exists(save_path):
            continue

        # Map the name of each ground-truth file to its prompts.
        prompts_per_image = {}
        for gt_path in tqdm(gt_paths, desc=f"Precompute prompts for p{n_positives}-n{n_negatives}"):
            name, prompts = _precompute_prompts(
                gt_path,
                use_points=use_points,
                use_boxes=use_boxes,
                n_positives=n_positives,
                n_negatives=n_negatives,
                dilation=dilation,
            )
            prompts_per_image[name] = prompts

        with open(save_path, "wb") as f:
            pickle.dump(prompts_per_image, f)
Precompute all point prompts.
To enable running different inference tasks in parallel afterwards.
Arguments:
- gt_paths: The file paths to the ground-truth segmentations.
- prompt_save_dir: The directory where the prompt files will be saved.
- prompt_settings: The settings for which the prompts will be computed.
def run_inference_with_prompts(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Union[str, os.PathLike],
    prediction_dir: Union[str, os.PathLike],
    use_points: bool,
    use_boxes: bool,
    n_positives: int,
    n_negatives: int,
    dilation: int = 5,
    prompt_save_dir: Optional[Union[str, os.PathLike]] = None,
    batch_size: int = 512,
) -> None:
    """Run Segment Anything inference for multiple images using prompts derived from ground-truth.

    The segmentation of each image is written to `prediction_dir` under the file name
    of the image. Images for which this file already exists are skipped.

    Args:
        predictor: The Segment Anything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
        prediction_dir: The directory where the predictions will be saved.
        use_points: Whether to use point prompts.
        use_boxes: Whether to use box prompts.
        n_positives: The number of positive point prompts that will be sampled.
        n_negatives: The number of negative point prompts that will be sampled.
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        prompt_save_dir: The directory where point prompts will be saved or are already saved.
            This enables running multiple experiments in a reproducible manner.
        batch_size: The batch size used for batched prediction.

    Raises:
        ValueError: If neither point nor box prompts are used, or if the number
            of images and ground-truth files differs.
    """
    if not use_points and not use_boxes:
        raise ValueError("You need to use at least one of point or box prompts.")

    n_images, n_labels = len(image_paths), len(gt_paths)
    if n_images != n_labels:
        raise ValueError(f"Expect same number of images and gt images, got {n_images}, {n_labels}")

    prompt_caching = _get_prompt_caching(prompt_save_dir, use_points, use_boxes, n_positives, n_negatives)
    (cached_point_prompts, save_point_prompts, point_prompt_save_path,
     cached_box_prompts, save_box_prompts, box_prompt_save_path) = prompt_caching

    os.makedirs(prediction_dir, exist_ok=True)
    path_pairs = zip(image_paths, gt_paths)
    for image_path, gt_path in tqdm(path_pairs, total=n_images, desc="Run inference with prompts"):
        image_name, label_name = os.path.basename(image_path), os.path.basename(gt_path)

        # Images that were segmented in a previous run are not processed again.
        # NOTE: such images also do not contribute prompts to the caches saved below.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        image = imageio.imread(image_path)
        # Relabel the ground-truth so that its object ids are consecutive.
        labels = relabel_sequential(imageio.imread(gt_path).astype("uint32"))[0]

        # The embeddings of an image are stored as zarr file named after the image.
        embedding_name = f"{os.path.splitext(image_name)[0]}.zarr"
        embedding_path = os.path.join(embedding_dir, embedding_name)

        # The prompts are cached under the name of the ground-truth file.
        prompts, cached_point_prompts, cached_box_prompts = _load_prompts(
            cached_point_prompts, save_point_prompts, cached_box_prompts, save_box_prompts, label_name
        )
        instances, prompts = _run_inference_with_prompts_for_image(
            predictor,
            image,
            labels,
            n_positives=n_positives,
            n_negatives=n_negatives,
            dilation=dilation,
            use_points=use_points,
            use_boxes=use_boxes,
            batch_size=batch_size,
            cached_prompts=prompts,
            embedding_path=embedding_path,
        )

        # The first two entries of the returned prompts are stored as point prompts,
        # the last entry is stored as box prompts.
        if save_point_prompts:
            cached_point_prompts[label_name] = prompts[:2]
        if save_box_prompts:
            cached_box_prompts[label_name] = prompts[-1]

        # The predictions are compressed, otherwise they would take up a lot of space.
        imageio.imwrite(prediction_path, instances, compression=5)

    # Write the prompts to disk if they were computed for the first time with prompt caching.
    if save_point_prompts:
        with open(point_prompt_save_path, "wb") as f:
            pickle.dump(cached_point_prompts, f)
    if save_box_prompts:
        with open(box_prompt_save_path, "wb") as f:
            pickle.dump(cached_box_prompts, f)
Run segment anything inference for multiple images using prompts derived from groundtruth.
Arguments:
- predictor: The Segment Anything predictor.
- image_paths: The image file paths.
- gt_paths: The ground-truth segmentation file paths.
- embedding_dir: The directory where the image embeddings will be saved or are already saved.
- prediction_dir: The directory where the predictions will be saved.
- use_points: Whether to use point prompts.
- use_boxes: Whether to use box prompts
- n_positives: The number of positive point prompts that will be sampled.
- n_negatives: The number of negative point prompts that will be sampled.
- dilation: The dilation factor for the radius around the ground-truth object around which points will not be sampled.
- prompt_save_dir: The directory where point prompts will be saved or are already saved. This enables running multiple experiments in a reproducible manner.
- batch_size: The batch size used for batched prediction.
def run_inference_with_iterative_prompting(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Union[str, os.PathLike],
    prediction_dir: Union[str, os.PathLike],
    start_with_box_prompt: bool = True,
    dilation: int = 5,
    batch_size: int = 32,
    n_iterations: int = 8,
    use_masks: bool = False
) -> None:
    """Run Segment Anything inference for multiple images using prompts iteratively
    derived from model outputs and ground-truth.

    The predictions of iteration `i` are stored in the sub-folder `iteration{i:02}` of
    `prediction_dir`. Images that have a prediction file for every iteration are skipped.

    Args:
        predictor: The Segment Anything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
            If None, no embedding path is passed on for the images.
        prediction_dir: The directory where the predictions from Segment Anything will be saved per iteration.
        start_with_box_prompt: Whether to use the first prompt as bounding box or a single point.
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        batch_size: The batch size used for batched predictions.
        n_iterations: The number of iterations for iterative prompting.
        use_masks: Whether to make use of logits from previous prompt-based segmentation.

    Raises:
        ValueError: If the number of images and ground-truth files differs.
    """
    n_images, n_labels = len(image_paths), len(gt_paths)
    if n_images != n_labels:
        raise ValueError(f"Expect same number of images and gt images, got {n_images}, {n_labels}")

    # Every iteration gets its own prediction folder.
    iteration_dirs = [os.path.join(prediction_dir, f"iteration{i:02}") for i in range(n_iterations)]
    for iteration_dir in iteration_dirs:
        os.makedirs(iteration_dir, exist_ok=True)

    if use_masks:
        print("The iterative prompting will make use of logits masks from previous iterations.")

    path_pairs = zip(image_paths, gt_paths)
    for image_path, gt_path in tqdm(
        path_pairs, total=n_images, desc="Run inference with iterative prompting for all images",
    ):
        image_name = os.path.basename(image_path)

        # Skip the image if the predictions for all iterations exist already.
        prediction_paths = [os.path.join(iteration_dir, image_name) for iteration_dir in iteration_dirs]
        already_segmented = all(map(os.path.exists, prediction_paths))
        if already_segmented:
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        image = imageio.imread(image_path)
        # Relabel the ground-truth so that its object ids are consecutive.
        labels = relabel_sequential(imageio.imread(gt_path).astype("uint32"))[0]

        # The embeddings of an image are stored as zarr file named after the image.
        embedding_path = (
            None if embedding_dir is None
            else os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
        )

        _run_inference_with_iterative_prompting_for_image(
            predictor,
            image,
            labels,
            start_with_box_prompt=start_with_box_prompt,
            dilation=dilation,
            batch_size=batch_size,
            embedding_path=embedding_path,
            n_iterations=n_iterations,
            prediction_paths=prediction_paths,
            use_masks=use_masks,
        )
Run Segment Anything inference for multiple images using prompts iteratively derived from model outputs and ground-truth.
Arguments:
- predictor: The Segment Anything predictor.
- image_paths: The image file paths.
- gt_paths: The ground-truth segmentation file paths.
- embedding_dir: The directory where the image embeddings will be saved or are already saved.
- prediction_dir: The directory where the predictions from Segment Anything will be saved per iteration.
- start_with_box_prompt: Whether to use the first prompt as bounding box or a single point
- dilation: The dilation factor for the radius around the ground-truth object around which points will not be sampled.
- batch_size: The batch size used for batched predictions.
- n_iterations: The number of iterations for iterative prompting.
- use_masks: Whether to make use of logits from previous prompt-based segmentation.
def run_amg(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    iou_thresh_values: Optional[List[float]] = None,
    stability_score_values: Optional[List[float]] = None,
    peft_kwargs: Optional[Dict] = None,
    cache_embeddings: bool = False,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> str:
    """Run Segment Anything inference for multiple images using automatic mask generation (AMG).

    The predictions are saved in `<experiment_folder>/amg/inference` and the grid-search
    results in `<experiment_folder>/amg/grid_search`.

    Args:
        checkpoint: The filepath to model checkpoints.
        model_type: The segment anything model choice.
        experiment_folder: The directory where the relevant files are saved.
        val_image_paths: The list of filepaths of input images for grid-search.
        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
        iou_thresh_values: Optional choice of values for grid search of `iou_thresh` parameter.
        stability_score_values: Optional choice of values for grid search of `stability_score` parameter.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        cache_embeddings: Whether to cache embeddings in experiment folder.
        tiling_window_params: The parameters to decide whether to use tiling window operation for AMG.
            If given (and not empty), it must be a dictionary with the keys 'tile_shape' and 'halo'.

    Returns:
        The folder where the predictions have been saved.

    Raises:
        RuntimeError: If `tiling_window_params` is given but is not a dictionary,
            or if it misses the 'tile_shape' or 'halo' key.
    """
    # Validate the tiling parameters up front, so that invalid settings fail
    # before any folder is created and before the model is loaded.
    use_tiling = bool(tiling_window_params)
    if use_tiling:
        if not isinstance(tiling_window_params, dict):
            raise RuntimeError("The tiling window parameters are expected to be provided as a dictionary of params.")
        for required_key in ("tile_shape", "halo"):
            if required_key not in tiling_window_params:
                raise RuntimeError(f"'{required_key}' parameter is missing from the provided parameters.")

    if cache_embeddings:
        embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
        os.makedirs(embedding_folder, exist_ok=True)
    else:
        embedding_folder = None

    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs)

    # Get the AMG class.
    amg_class = TiledAutomaticMaskGenerator if use_tiling else AutomaticMaskGenerator
    amg = amg_class(predictor)
    amg_prefix = "amg"

    # where the predictions are saved
    prediction_folder = os.path.join(experiment_folder, amg_prefix, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # where the grid-search results are saved
    gs_result_folder = os.path.join(experiment_folder, amg_prefix, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    grid_search_values = instance_segmentation.default_grid_search_values_amg(
        iou_thresh_values=iou_thresh_values,
        stability_score_values=stability_score_values,
    )

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter=amg,
        grid_search_values=grid_search_values,
        val_image_paths=val_image_paths,
        val_gt_paths=val_gt_paths,
        test_image_paths=test_image_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
        experiment_folder=experiment_folder,
        tiling_window_params=tiling_window_params,
    )
    return prediction_folder
Run Segment Anything inference for multiple images using automatic mask generation (AMG).
Arguments:
- checkpoint: The filepath to model checkpoints.
- model_type: The segment anything model choice.
- experiment_folder: The directory where the relevant files are saved.
- val_image_paths: The list of filepaths of input images for grid-search.
- val_gt_paths: The list of filepaths of corresponding labels for grid-search.
- test_image_paths: The list of filepaths of input images for automatic instance segmentation.
- iou_thresh_values: Optional choice of values for grid search of `iou_thresh` parameter.
- stability_score_values: Optional choice of values for grid search of `stability_score` parameter.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
- cache_embeddings: Whether to cache embeddings in experiment folder.
- tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
Returns:
Filepath where the predictions have been saved.
def run_apg(
    checkpoint: Optional[Union[str, os.PathLike]],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    peft_kwargs: Optional[Dict] = None,
    cache_embeddings: bool = False,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> str:
    """Run Segment Anything inference for multiple images using automatic prompt generation (APG).

    The predictions are saved in `<experiment_folder>/apg/inference` and the grid-search
    results in `<experiment_folder>/apg/grid_search`.

    Args:
        checkpoint: The filepath to model checkpoints. It is passed on as `checkpoint_path`
            when the predictor and decoder are loaded.
        model_type: The segment anything model choice.
        experiment_folder: The directory where the relevant files are saved.
        val_image_paths: The list of filepaths of input images for grid-search.
        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        cache_embeddings: Whether to cache embeddings in experiment folder.
        tiling_window_params: The parameters for the tiling window operation.
            Tiling is not implemented for APG yet, so this must be None or empty.

    Returns:
        The folder where the predictions have been saved.

    Raises:
        NotImplementedError: If `tiling_window_params` is given (and not empty).
    """
    # Tiling is not supported for APG. Fail right away instead of creating folders
    # and loading the model first.
    if tiling_window_params:
        raise NotImplementedError("Tiling window operation is not implemented for automatic prompt generation.")

    if cache_embeddings:
        embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
        os.makedirs(embedding_folder, exist_ok=True)
    else:
        embedding_folder = None

    predictor, decoder = get_predictor_and_decoder(
        model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs,
    )

    segmenter = AutomaticPromptGenerator(predictor, decoder)
    seg_prefix = "apg"

    # where the predictions are saved
    prediction_folder = os.path.join(experiment_folder, seg_prefix, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # where the grid-search results are saved
    gs_result_folder = os.path.join(experiment_folder, seg_prefix, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    grid_search_values = instance_segmentation.default_grid_search_values_apg()

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter=segmenter,
        grid_search_values=grid_search_values,
        val_image_paths=val_image_paths,
        val_gt_paths=val_gt_paths,
        test_image_paths=test_image_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
        experiment_folder=experiment_folder,
        tiling_window_params=tiling_window_params,
    )
    return prediction_folder
Run Segment Anything inference for multiple images using automatic prompt generation (APG).
Arguments:
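- checkpoint: The filepath to model checkpoints.
- model_type: The segment anything model choice.
- experiment_folder: The directory where the relevant files are saved.
- val_image_paths: The list of filepaths of input images for grid-search.
- val_gt_paths: The list of filepaths of corresponding labels for grid-search.
- test_image_paths: The list of filepaths of input images for automatic instance segmentation.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
- cache_embeddings: Whether to cache embeddings in experiment folder.
- tiling_window_params: The parameters for the tiling window operation. Tiling is not implemented for APG.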
Returns:
Filepath where the predictions have been saved.
def run_instance_segmentation_with_decoder(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    peft_kwargs: Optional[Dict] = None,
    cache_embeddings: bool = False,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> str:
    """Run Segment Anything inference for multiple images using additional automatic instance segmentation (AIS).

    The predictions are saved in `<experiment_folder>/instance_segmentation_with_decoder/inference`
    and the grid-search results in `<experiment_folder>/instance_segmentation_with_decoder/grid_search`.

    Args:
        checkpoint: The filepath to model checkpoints.
        model_type: The segment anything model choice.
        experiment_folder: The directory where the relevant files are saved.
        val_image_paths: The list of filepaths of input images for grid-search.
        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        cache_embeddings: Whether to cache embeddings in experiment folder.
        tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
            If given (and not empty), it must be a dictionary with the keys 'tile_shape' and 'halo'.

    Returns:
        The folder where the predictions have been saved.

    Raises:
        RuntimeError: If `tiling_window_params` is given but is not a dictionary,
            or if it misses the 'tile_shape' or 'halo' key.
    """
    # Validate the tiling parameters up front, so that invalid settings fail
    # before any folder is created and before the model is loaded.
    use_tiling = bool(tiling_window_params)
    if use_tiling:
        if not isinstance(tiling_window_params, dict):
            raise RuntimeError("The tiling window parameters are expected to be provided as a dictionary of params.")
        for required_key in ("tile_shape", "halo"):
            if required_key not in tiling_window_params:
                raise RuntimeError(f"'{required_key}' parameter is missing from the provided parameters.")

    if cache_embeddings:
        embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
        os.makedirs(embedding_folder, exist_ok=True)
    else:
        embedding_folder = None

    predictor, decoder = get_predictor_and_decoder(
        model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs,
    )

    # Get the AIS class.
    ais_class = TiledInstanceSegmentationWithDecoder if use_tiling else InstanceSegmentationWithDecoder
    segmenter = ais_class(predictor, decoder)
    seg_prefix = "instance_segmentation_with_decoder"

    # where the predictions are saved
    prediction_folder = os.path.join(experiment_folder, seg_prefix, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # where the grid-search results are saved
    gs_result_folder = os.path.join(experiment_folder, seg_prefix, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    grid_search_values = instance_segmentation.default_grid_search_values_instance_segmentation_with_decoder()

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter=segmenter,
        grid_search_values=grid_search_values,
        val_image_paths=val_image_paths,
        val_gt_paths=val_gt_paths,
        test_image_paths=test_image_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
        experiment_folder=experiment_folder,
        tiling_window_params=tiling_window_params,
    )
    return prediction_folder
Run Segment Anything inference for multiple images using additional automatic instance segmentation (AIS).
Arguments:
- checkpoint: The filepath to model checkpoints.
- model_type: The segment anything model choice.
- experiment_folder: The directory where the relevant files are saved.
- val_image_paths: The list of filepaths of input images for grid-search.
- val_gt_paths: The list of filepaths of corresponding labels for grid-search.
- test_image_paths: The list of filepaths of input images for automatic instance segmentation.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
- cache_embeddings: Whether to cache embeddings in experiment folder.
- tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
Returns:
Filepath where the predictions have been saved.