micro_sam.evaluation.inference
Inference with Segment Anything models and different prompt strategies.
"""Inference with Segment Anything models and different prompt strategies.
"""

import os
import pickle
import numpy as np
from tqdm import tqdm
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union, Tuple

import imageio.v3 as imageio

from bioimage_cpp.segmentation import relabel_sequential

import torch

from segment_anything import SamPredictor

from .. import util as util
from ..inference import batched_inference
from ..instance_segmentation import (
    get_predictor_and_decoder,
    AutomaticMaskGenerator, InstanceSegmentationWithDecoder,
    TiledAutomaticMaskGenerator, TiledInstanceSegmentationWithDecoder,
    AutomaticPromptGenerator,
)
from . import instance_segmentation
from ..prompt_generators import PointAndBoxPromptGenerator, IterativePromptGenerator


def _get_object_ids(gt):
    """Return the sorted ids of all objects in a label image, excluding the background id 0.

    ``np.unique(gt)[1:]`` would silently drop the first object if the image contains no
    background pixels, so the background id is filtered out explicitly instead.
    """
    ids = np.unique(gt)
    return ids[ids != 0]


def _load_prompts(
    cached_point_prompts, save_point_prompts, cached_box_prompts, save_box_prompts, image_name
):
    """Look up the cached point and/or box prompts for a single image.

    Each ``cached_*_prompts`` argument is either None, a path to a pickle file that has not been
    loaded yet, or an already loaded dict keyed by image name. A path is loaded on first use and
    the loaded dict is returned, so that the caller can pass it back in for the next image.

    Args:
        cached_point_prompts: The point prompt cache (None, pickle file path or dict).
        save_point_prompts: Whether point prompts are computed and saved in this run;
            if True nothing is looked up for points.
        cached_box_prompts: The box prompt cache (None, pickle file path or dict).
        save_box_prompts: Whether box prompts are computed and saved in this run;
            if True nothing is looked up for boxes.
        image_name: The key to look up in the caches.

    Returns:
        The prompts ``(input_point, input_label, input_box)`` or None if neither cache yielded
        prompts, followed by the (possibly loaded) point cache and box cache.
    """

    def load_prompt_type(cached_prompts, save_prompts):
        # Nothing to look up: either there is no cache or it is being populated in this run.
        if cached_prompts is None or save_prompts:
            return cached_prompts, None

        # The cache is still a file path: load it once and hand the dict back to the caller.
        # NOTE: pickle.load can execute arbitrary code; only use prompt files produced by this module.
        if isinstance(cached_prompts, str):
            with open(cached_prompts, "rb") as f:
                cached_prompts = pickle.load(f)

        prompts = cached_prompts[image_name]
        return cached_prompts, prompts

    cached_point_prompts, point_prompts = load_prompt_type(cached_point_prompts, save_point_prompts)
    cached_box_prompts, box_prompts = load_prompt_type(cached_box_prompts, save_box_prompts)

    # We don't have anything cached.
    if point_prompts is None and box_prompts is None:
        return None, cached_point_prompts, cached_box_prompts

    if point_prompts is None:
        input_point, input_label = [], []
    else:
        input_point, input_label = point_prompts

    input_box = [] if box_prompts is None else box_prompts

    prompts = (input_point, input_label, input_box)
    return prompts, cached_point_prompts, cached_box_prompts


def _get_batched_prompts(gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation):
    """Derive point and/or box prompts for the given objects of a ground-truth segmentation.

    Args:
        gt: The ground-truth label image.
        gt_ids: The object ids in ``gt`` to derive prompts for.
        use_points: Whether to generate point prompts.
        use_boxes: Whether to generate box prompts.
        n_positives: The number of positive points per object.
        n_negatives: The number of negative points per object.
        dilation: The dilation strength passed to the prompt generator.

    Returns:
        The points, point labels and boxes converted to numpy; an entry is None if the
        prompt generator returned None for it.
    """
    # Initialize the prompt generator.
    prompt_generator = PointAndBoxPromptGenerator(
        n_positive_points=n_positives, n_negative_points=n_negatives,
        dilation_strength=dilation, get_point_prompts=use_points,
        get_box_prompts=use_boxes
    )

    # Generate the prompts.
    center_coordinates, bbox_coordinates = util.get_centers_and_bounding_boxes(gt)
    center_coordinates = [center_coordinates[gt_id] for gt_id in gt_ids]
    bbox_coordinates = [bbox_coordinates[gt_id] for gt_id in gt_ids]
    masks = util.segmentation_to_one_hot(gt.astype("int64"), gt_ids)

    points, point_labels, boxes, _ = prompt_generator(
        masks, bbox_coordinates, center_coordinates
    )

    def to_numpy(x):
        if x is None:
            return x
        return x.numpy()

    return to_numpy(points), to_numpy(point_labels), to_numpy(boxes)


def _run_inference_with_prompts_for_image(
    predictor,
    image,
    gt,
    use_points,
    use_boxes,
    n_positives,
    n_negatives,
    dilation,
    batch_size,
    cached_prompts,
    embedding_path,
):
    """Segment all objects of one image with prompts derived from (or cached for) its ground-truth.

    Returns:
        The instance segmentation and a deep copy of the prompts ``(points, point_labels, boxes)``
        that were used, so that they can be cached by the caller.
    """
    gt_ids = _get_object_ids(gt)
    if cached_prompts is None:
        points, point_labels, boxes = _get_batched_prompts(
            gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation
        )
    else:
        points, point_labels, boxes = cached_prompts

    # Make a copy of the prompts to return them at the end.
    prompts = deepcopy((points, point_labels, boxes))

    # Use multi-masking only if we have a single positive point without box.
    multimasking = not use_boxes and n_positives == 1 and n_negatives == 0

    instance_labels = batched_inference(
        predictor, image, batch_size,
        boxes=boxes, points=points, point_labels=point_labels,
        multimasking=multimasking, embedding_path=embedding_path,
        return_instance_segmentation=True,
    )

    return instance_labels, prompts


def precompute_all_embeddings(
    predictor: SamPredictor, image_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike],
) -> None:
    """Precompute all image embeddings.

    This enables running different inference tasks in parallel afterwards.

    Args:
        predictor: The Segment Anything predictor.
        image_paths: The image file paths.
        embedding_dir: The directory where the embeddings will be saved. It is created if it does not exist.
    """
    os.makedirs(embedding_dir, exist_ok=True)
    for image_path in tqdm(image_paths, desc="Precompute embeddings"):
        image_name = os.path.basename(image_path)
        im = imageio.imread(image_path)
        embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
        util.precompute_image_embeddings(predictor, im, embedding_path, ndim=2)


def _precompute_prompts(gt_path, use_points, use_boxes, n_positives, n_negatives, dilation):
    """Derive the prompts for one ground-truth file.

    Returns:
        The file's basename and either the boxes (box-only setting) or the tuple
        ``(points, point_labels)`` (all other settings).
    """
    name = os.path.basename(gt_path)

    gt = imageio.imread(gt_path).astype("uint32")
    gt = relabel_sequential(gt)[0]
    gt_ids = _get_object_ids(gt)

    input_point, input_label, input_box = _get_batched_prompts(
        gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation
    )

    if use_boxes and not use_points:
        return name, input_box
    return name, (input_point, input_label)


def precompute_all_prompts(
    gt_paths: List[Union[str, os.PathLike]],
    prompt_save_dir: Union[str, os.PathLike],
    prompt_settings: List[Dict[str, Any]],
) -> None:
    """Precompute all point and box prompts.

    This enables running different inference tasks in parallel afterwards.
    Box-only settings are saved to ``boxes.pkl``. All other settings are saved to
    ``points-p{n_positives}-n{n_negatives}.pkl``, and only their point prompts are stored.
    Settings whose prompt file already exists are skipped.

    Args:
        gt_paths: The file paths to the ground-truth segmentations.
        prompt_save_dir: The directory where the prompt files will be saved.
        prompt_settings: The settings for which the prompts will be computed. Each entry must contain
            the keys 'use_points', 'use_boxes', 'n_positives' and 'n_negatives' and may contain
            'dilation' (default: 5).

    Raises:
        ValueError: If a setting uses neither point nor box prompts.
    """
    os.makedirs(prompt_save_dir, exist_ok=True)

    for settings in tqdm(prompt_settings, desc="Precompute prompts"):

        use_points, use_boxes = settings["use_points"], settings["use_boxes"]
        n_positives, n_negatives = settings["n_positives"], settings["n_negatives"]
        dilation = settings.get("dilation", 5)

        if not (use_points or use_boxes):
            raise ValueError("You need to use at least one of point or box prompts.")

        # Box prompts do not depend on the number of points, so they share a single file.
        if use_boxes and not use_points:
            prompt_name = "boxes"
        else:
            prompt_name = f"points-p{n_positives}-n{n_negatives}"
        prompt_save_path = os.path.join(prompt_save_dir, f"{prompt_name}.pkl")

        # Check if the prompts were already computed.
        if os.path.exists(prompt_save_path):
            continue

        saved_prompts = dict(
            _precompute_prompts(
                gt_path,
                use_points=use_points,
                use_boxes=use_boxes,
                n_positives=n_positives,
                n_negatives=n_negatives,
                dilation=dilation,
            )
            for gt_path in tqdm(gt_paths, desc=f"Precompute prompts for {prompt_name}")
        )
        with open(prompt_save_path, "wb") as f:
            pickle.dump(saved_prompts, f)


def _get_prompt_caching(prompt_save_dir, use_points, use_boxes, n_positives, n_negatives):
    """Set up prompt caching for points and boxes.

    Returns:
        For points and then for boxes: the cache (None, a pickle file path to load lazily, or an empty
        dict to fill), whether the prompts must be saved after inference, and the pickle file path.
    """

    def get_prompt_type_caching(use_type, save_name):
        if not use_type:
            return None, False, None

        prompt_save_path = os.path.join(prompt_save_dir, save_name)
        if os.path.exists(prompt_save_path):
            print("Using precomputed prompts from", prompt_save_path)
            # We delay loading the prompts, so we only have to load them once they're needed the first time.
            # This avoids loading the prompts (which are in a big pickle file) if all predictions are done already.
            cached_prompts = prompt_save_path
            save_prompts = False
        else:
            print("Saving prompts in", prompt_save_path)
            cached_prompts = {}
            save_prompts = True
        return cached_prompts, save_prompts, prompt_save_path

    # Check if prompt serialization is enabled.
    # If it is then load the prompts if they are already cached and otherwise store them.
    if prompt_save_dir is None:
        print("Prompts are not cached.")
        cached_point_prompts, cached_box_prompts = None, None
        save_point_prompts, save_box_prompts = False, False
        point_prompt_save_path, box_prompt_save_path = None, None
    else:
        cached_point_prompts, save_point_prompts, point_prompt_save_path = get_prompt_type_caching(
            use_points, f"points-p{n_positives}-n{n_negatives}.pkl"
        )
        cached_box_prompts, save_box_prompts, box_prompt_save_path = get_prompt_type_caching(
            use_boxes, "boxes.pkl"
        )

    return (cached_point_prompts, save_point_prompts, point_prompt_save_path,
            cached_box_prompts, save_box_prompts, box_prompt_save_path)


def run_inference_with_prompts(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    use_points: bool,
    use_boxes: bool,
    n_positives: int,
    n_negatives: int,
    dilation: int = 5,
    prompt_save_dir: Optional[Union[str, os.PathLike]] = None,
    batch_size: int = 512,
) -> None:
    """Run Segment Anything inference for multiple images using prompts derived from the ground-truth.

    Args:
        predictor: The Segment Anything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
            If None, no embedding path is passed to the batched inference.
        prediction_dir: The directory where the predictions will be saved.
            Images that already have a prediction in this directory are skipped.
        use_points: Whether to use point prompts.
        use_boxes: Whether to use box prompts.
        n_positives: The number of positive point prompts that will be sampled.
        n_negatives: The number of negative point prompts that will be sampled.
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        prompt_save_dir: The directory where prompts will be saved or are already saved.
            This enables running multiple experiments in a reproducible manner.
        batch_size: The batch size used for batched prediction.

    Raises:
        ValueError: If neither point nor box prompts are used, or if the numbers of images
            and ground-truth images differ.
    """
    if not (use_points or use_boxes):
        raise ValueError("You need to use at least one of point or box prompts.")

    if len(image_paths) != len(gt_paths):
        raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")

    (cached_point_prompts, save_point_prompts, point_prompt_save_path,
     cached_box_prompts, save_box_prompts, box_prompt_save_path) = _get_prompt_caching(
        prompt_save_dir, use_points, use_boxes, n_positives, n_negatives
    )

    os.makedirs(prediction_dir, exist_ok=True)
    for image_path, gt_path in tqdm(
        zip(image_paths, gt_paths), total=len(image_paths), desc="Run inference with prompts"
    ):
        image_name = os.path.basename(image_path)
        label_name = os.path.basename(gt_path)

        # We skip the images that already have been segmented.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        im = imageio.imread(image_path)
        gt = imageio.imread(gt_path).astype("uint32")
        gt = relabel_sequential(gt)[0]

        if embedding_dir is None:
            embedding_path = None
        else:
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        this_prompts, cached_point_prompts, cached_box_prompts = _load_prompts(
            cached_point_prompts, save_point_prompts,
            cached_box_prompts, save_box_prompts,
            label_name
        )
        instances, this_prompts = _run_inference_with_prompts_for_image(
            predictor, im, gt, n_positives=n_positives, n_negatives=n_negatives,
            dilation=dilation, use_points=use_points, use_boxes=use_boxes,
            batch_size=batch_size, cached_prompts=this_prompts,
            embedding_path=embedding_path,
        )

        if save_point_prompts:
            cached_point_prompts[label_name] = this_prompts[:2]
        if save_box_prompts:
            cached_box_prompts[label_name] = this_prompts[-1]

        # It's important to compress here, otherwise the predictions would take up a lot of space.
        imageio.imwrite(prediction_path, instances, compression=5)

    # Save the prompts if we run experiments with prompt caching and have computed them
    # for the first time.
    # NOTE(review): images skipped above (prediction already present) have no entry in a newly
    # created cache, so a later run that loads this cache for those images would hit a KeyError.
    if save_point_prompts:
        with open(point_prompt_save_path, "wb") as f:
            pickle.dump(cached_point_prompts, f)
    if save_box_prompts:
        with open(box_prompt_save_path, "wb") as f:
            pickle.dump(cached_box_prompts, f)


def _save_segmentation(masks, prediction_path):
    """Merge per-object masks into one instance segmentation and write it compressed to ``prediction_path``.

    ``masks`` is a torch tensor with a singleton axis 1 (removed via ``squeeze(1)``);
    presumably of shape (n_objects, 1, H, W) — TODO confirm.
    """
    masks = masks.cpu().numpy().squeeze(1).astype("bool")
    masks = [{"segmentation": mask, "area": mask.sum()} for mask in masks]
    segmentation = util.mask_data_to_segmentation(masks)
    imageio.imwrite(prediction_path, segmentation, compression=5)


def _get_batched_iterative_prompts(sampled_binary_gt, masks, batch_size, prompt_generator):
    """Run the iterative prompt generator in batches of ``batch_size`` objects.

    Returns:
        The concatenated next point coordinates and point labels as torch tensors.
    """
    n_samples = sampled_binary_gt.shape[0]
    next_coords, next_labels = [], []
    for batch_start in range(0, n_samples, batch_size):
        batch_stop = min(batch_start + batch_size, n_samples)

        batch_coords, batch_labels, _, _ = prompt_generator(
            sampled_binary_gt[batch_start: batch_stop], masks[batch_start: batch_stop]
        )
        next_coords.append(batch_coords)
        next_labels.append(batch_labels)

    next_coords = torch.concatenate(next_coords)
    next_labels = torch.concatenate(next_labels)

    return next_coords, next_labels


@torch.no_grad()
def _run_inference_with_iterative_prompting_for_image(
    predictor,
    image,
    gt,
    start_with_box_prompt,
    dilation,
    batch_size,
    embedding_path,
    n_iterations,
    prediction_paths,
    use_masks=False
) -> None:
    """Segment all objects of one image by iteratively adding corrective point prompts.

    The first iteration uses a box or a single positive point per object. Every following
    iteration appends the points proposed by the ``IterativePromptGenerator`` from the previous
    prediction. The segmentation of iteration ``i`` is written to ``prediction_paths[i]``.
    """
    verbose_embeddings = False

    prompt_generator = IterativePromptGenerator()

    gt_ids = _get_object_ids(gt)

    # Use multi-masking only if we have a single positive point without box.
    if start_with_box_prompt:
        use_boxes, use_points = True, False
        n_positives = 0
        multimasking = False
    else:
        use_boxes, use_points = False, True
        n_positives = 1
        multimasking = True

    points, point_labels, boxes = _get_batched_prompts(
        gt, gt_ids,
        use_points=use_points,
        use_boxes=use_boxes,
        n_positives=n_positives,
        n_negatives=0,
        dilation=dilation
    )

    sampled_binary_gt = util.segmentation_to_one_hot(gt.astype("int64"), gt_ids)

    # Logits masks can not be used in the first iteration; afterwards they are only
    # updated (and thus passed on) if use_masks is set.
    logits_masks = None
    for iteration in range(n_iterations):
        batched_outputs = batched_inference(
            predictor=predictor,
            image=image,
            batch_size=batch_size,
            boxes=boxes,
            points=points,
            point_labels=point_labels,
            multimasking=multimasking,
            embedding_path=embedding_path,
            return_instance_segmentation=False,
            logits_masks=logits_masks,
            verbose_embeddings=verbose_embeddings,
        )

        # Switching off multimasking after the first iteration, as the next iterations
        # (with multiple prompts) don't expect multimasking.
        multimasking = False

        masks = torch.stack([m["segmentation"][None] for m in batched_outputs]).to(torch.float32)

        next_coords, next_labels = _get_batched_iterative_prompts(
            sampled_binary_gt, masks, batch_size, prompt_generator
        )
        next_coords, next_labels = next_coords.detach().cpu().numpy(), next_labels.detach().cpu().numpy()

        if points is not None:
            points = np.concatenate([points, next_coords], axis=1)
        else:
            points = next_coords

        if point_labels is not None:
            point_labels = np.concatenate([point_labels, next_labels], axis=1)
        else:
            point_labels = next_labels

        if use_masks:
            logits_masks = torch.stack([m["logits"] for m in batched_outputs])

        _save_segmentation(masks, prediction_paths[iteration])


def run_inference_with_iterative_prompting(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Union[str, os.PathLike],
    prediction_dir: Union[str, os.PathLike],
    start_with_box_prompt: bool = True,
    dilation: int = 5,
    batch_size: int = 32,
    n_iterations: int = 8,
    use_masks: bool = False
) -> None:
    """Run Segment Anything inference for multiple images using prompts iteratively
    derived from model outputs and ground-truth.

    Args:
        predictor: The Segment Anything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
            If None, no embedding path is passed to the batched inference.
        prediction_dir: The directory where the predictions from Segment Anything will be saved per iteration,
            in the subfolders 'iteration00', 'iteration01', ...
        start_with_box_prompt: Whether to use the first prompt as bounding box or a single point.
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        batch_size: The batch size used for batched predictions.
        n_iterations: The number of iterations for iterative prompting.
        use_masks: Whether to make use of logits from previous prompt-based segmentation.

    Raises:
        ValueError: If the numbers of images and ground-truth images differ.
    """
    if len(image_paths) != len(gt_paths):
        raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")

    # Create the prediction folders for all intermediate iterations.
    for i in range(n_iterations):
        os.makedirs(os.path.join(prediction_dir, f"iteration{i:02}"), exist_ok=True)

    if use_masks:
        print("The iterative prompting will make use of logits masks from previous iterations.")

    for image_path, gt_path in tqdm(
        zip(image_paths, gt_paths), total=len(image_paths),
        desc="Run inference with iterative prompting for all images",
    ):
        image_name = os.path.basename(image_path)

        # We skip the images that already have been segmented in all iterations.
        prediction_paths = [os.path.join(prediction_dir, f"iteration{i:02}", image_name) for i in range(n_iterations)]
        if all(os.path.exists(prediction_path) for prediction_path in prediction_paths):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        image = imageio.imread(image_path)
        gt = imageio.imread(gt_path).astype("uint32")
        gt = relabel_sequential(gt)[0]

        if embedding_dir is None:
            embedding_path = None
        else:
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        _run_inference_with_iterative_prompting_for_image(
            predictor, image, gt, start_with_box_prompt=start_with_box_prompt,
            dilation=dilation, batch_size=batch_size, embedding_path=embedding_path,
            n_iterations=n_iterations, prediction_paths=prediction_paths, use_masks=use_masks
        )


#
# SHARED HELPERS FOR AUTOMATIC SEGMENTATION
#


def _get_embedding_folder(experiment_folder, cache_embeddings):
    """Return the embedding folder of an experiment (created if needed), or None if caching is disabled."""
    if not cache_embeddings:
        return None
    embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
    os.makedirs(embedding_folder, exist_ok=True)
    return embedding_folder


def _validate_tiling_window_params(tiling_window_params):
    """Check that the tiling window parameters are a dict containing 'tile_shape' and 'halo'.

    Raises:
        RuntimeError: If the parameters are not a dict or a required key is missing.
    """
    if not isinstance(tiling_window_params, dict):
        raise RuntimeError("The tiling window parameters are expected to be provided as a dictionary of params.")

    if "tile_shape" not in tiling_window_params:
        raise RuntimeError("'tile_shape' parameter is missing from the provided parameters.")

    if "halo" not in tiling_window_params:
        raise RuntimeError("'halo' parameter is missing from the provided parameters.")


def _make_result_folders(experiment_folder, seg_prefix):
    """Create and return the prediction folder and the grid-search result folder for a segmentation method."""
    # Where the predictions are saved.
    prediction_folder = os.path.join(experiment_folder, seg_prefix, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # Where the grid-search results are saved.
    gs_result_folder = os.path.join(experiment_folder, seg_prefix, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    return prediction_folder, gs_result_folder


#
# AMG FUNCTION
#


def run_amg(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    iou_thresh_values: Optional[List[float]] = None,
    stability_score_values: Optional[List[float]] = None,
    peft_kwargs: Optional[Dict] = None,
    cache_embeddings: bool = False,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> str:
    """Run Segment Anything inference for multiple images using automatic mask generation (AMG).

    Args:
        checkpoint: The filepath to model checkpoints.
        model_type: The segment anything model choice.
        experiment_folder: The directory where the relevant files are saved.
        val_image_paths: The list of filepaths of input images for grid-search.
        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
        iou_thresh_values: Optional choice of values for grid search of `iou_thresh` parameter.
        stability_score_values: Optional choice of values for grid search of `stability_score` parameter.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        cache_embeddings: Whether to cache embeddings in experiment folder.
        tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
            If given, it must be a dict containing 'tile_shape' and 'halo'.

    Returns:
        Filepath where the predictions have been saved.

    Raises:
        RuntimeError: If invalid tiling window parameters are given.
    """
    embedding_folder = _get_embedding_folder(experiment_folder, cache_embeddings)

    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs)

    # Get the AMG class.
    if tiling_window_params:
        _validate_tiling_window_params(tiling_window_params)
        amg_class = TiledAutomaticMaskGenerator
    else:
        amg_class = AutomaticMaskGenerator

    amg = amg_class(predictor)
    prediction_folder, gs_result_folder = _make_result_folders(experiment_folder, "amg")

    grid_search_values = instance_segmentation.default_grid_search_values_amg(
        iou_thresh_values=iou_thresh_values,
        stability_score_values=stability_score_values,
    )

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter=amg,
        grid_search_values=grid_search_values,
        val_image_paths=val_image_paths,
        val_gt_paths=val_gt_paths,
        test_image_paths=test_image_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
        experiment_folder=experiment_folder,
        tiling_window_params=tiling_window_params,
    )
    return prediction_folder


def run_apg(
    checkpoint: Optional[Union[str, os.PathLike]],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    peft_kwargs: Optional[Dict] = None,
    cache_embeddings: bool = False,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> str:
    """Run Segment Anything inference for multiple images using automatic prompt generation (APG).

    Args:
        checkpoint: The filepath to model checkpoints.
        model_type: The segment anything model choice.
        experiment_folder: The directory where the relevant files are saved.
        val_image_paths: The list of filepaths of input images for grid-search.
        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        cache_embeddings: Whether to cache embeddings in experiment folder.
        tiling_window_params: Tiling is not supported for APG yet; non-empty parameters raise an error.

    Returns:
        Filepath where the predictions have been saved.

    Raises:
        NotImplementedError: If tiling window parameters are given.
    """
    embedding_folder = _get_embedding_folder(experiment_folder, cache_embeddings)

    predictor, decoder = get_predictor_and_decoder(
        model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs,
    )

    # Get the APG class.
    if tiling_window_params:
        raise NotImplementedError("Tiling window operation is not supported for automatic prompt generation.")
    apg_class = AutomaticPromptGenerator

    segmenter = apg_class(predictor, decoder)
    prediction_folder, gs_result_folder = _make_result_folders(experiment_folder, "apg")

    grid_search_values = instance_segmentation.default_grid_search_values_apg()

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter=segmenter,
        grid_search_values=grid_search_values,
        val_image_paths=val_image_paths,
        val_gt_paths=val_gt_paths,
        test_image_paths=test_image_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
        experiment_folder=experiment_folder,
        tiling_window_params=tiling_window_params,
    )
    return prediction_folder


#
# INSTANCE SEGMENTATION FUNCTION
#


def run_instance_segmentation_with_decoder(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    peft_kwargs: Optional[Dict] = None,
    cache_embeddings: bool = False,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> str:
    """Run Segment Anything inference for multiple images using additional automatic instance segmentation (AIS).

    Args:
        checkpoint: The filepath to model checkpoints.
        model_type: The segment anything model choice.
        experiment_folder: The directory where the relevant files are saved.
        val_image_paths: The list of filepaths of input images for grid-search.
        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        cache_embeddings: Whether to cache embeddings in experiment folder.
        tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
            If given, it must be a dict containing 'tile_shape' and 'halo'.

    Returns:
        Filepath where the predictions have been saved.

    Raises:
        RuntimeError: If invalid tiling window parameters are given.
    """
    embedding_folder = _get_embedding_folder(experiment_folder, cache_embeddings)

    predictor, decoder = get_predictor_and_decoder(
        model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs,
    )

    # Get the AIS class.
    if tiling_window_params:
        _validate_tiling_window_params(tiling_window_params)
        ais_class = TiledInstanceSegmentationWithDecoder
    else:
        ais_class = InstanceSegmentationWithDecoder

    segmenter = ais_class(predictor, decoder)
    prediction_folder, gs_result_folder = _make_result_folders(
        experiment_folder, "instance_segmentation_with_decoder"
    )

    grid_search_values = instance_segmentation.default_grid_search_values_instance_segmentation_with_decoder()

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter=segmenter,
        grid_search_values=grid_search_values,
        val_image_paths=val_image_paths,
        val_gt_paths=val_gt_paths,
        test_image_paths=test_image_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
        experiment_folder=experiment_folder,
        tiling_window_params=tiling_window_params,
    )
    return prediction_folder
def precompute_all_embeddings(
    predictor: SamPredictor, image_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike],
) -> None:
    """Precompute all image embeddings.

    This enables running different inference tasks in parallel afterwards.

    Args:
        predictor: The Segment Anything predictor.
        image_paths: The image file paths.
        embedding_dir: The directory where the embeddings will be saved. It is created if it does not exist.
    """
    # Make sure the output directory exists, consistent with precompute_all_prompts.
    os.makedirs(embedding_dir, exist_ok=True)
    for image_path in tqdm(image_paths, desc="Precompute embeddings"):
        image_name = os.path.basename(image_path)
        im = imageio.imread(image_path)
        embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
        util.precompute_image_embeddings(predictor, im, embedding_path, ndim=2)
Precompute all image embeddings.
This enables running different inference tasks in parallel afterwards.
Arguments:
- predictor: The Segment Anything predictor.
- image_paths: The image file paths.
- embedding_dir: The directory where the embeddings will be saved.
def precompute_all_prompts(
    gt_paths: List[Union[str, os.PathLike]],
    prompt_save_dir: Union[str, os.PathLike],
    prompt_settings: List[Dict[str, Any]],
) -> None:
    """Precompute all point and box prompts.

    This enables running different inference tasks in parallel afterwards.
    Box-only settings are saved to ``boxes.pkl``. All other settings are saved to
    ``points-p{n_positives}-n{n_negatives}.pkl``, and only their point prompts are stored.
    Settings whose prompt file already exists are skipped.

    Args:
        gt_paths: The file paths to the ground-truth segmentations.
        prompt_save_dir: The directory where the prompt files will be saved.
        prompt_settings: The settings for which the prompts will be computed. Each entry must contain
            the keys 'use_points', 'use_boxes', 'n_positives' and 'n_negatives' and may contain
            'dilation' (default: 5).

    Raises:
        ValueError: If a setting uses neither point nor box prompts.
    """
    os.makedirs(prompt_save_dir, exist_ok=True)

    for settings in tqdm(prompt_settings, desc="Precompute prompts"):

        use_points, use_boxes = settings["use_points"], settings["use_boxes"]
        n_positives, n_negatives = settings["n_positives"], settings["n_negatives"]
        dilation = settings.get("dilation", 5)

        if not (use_points or use_boxes):
            raise ValueError("You need to use at least one of point or box prompts.")

        # Box prompts do not depend on the number of points, so they share a single file.
        if use_boxes and not use_points:
            prompt_name = "boxes"
        else:
            prompt_name = f"points-p{n_positives}-n{n_negatives}"
        prompt_save_path = os.path.join(prompt_save_dir, f"{prompt_name}.pkl")

        # Check if the prompts were already computed.
        if os.path.exists(prompt_save_path):
            continue

        saved_prompts = dict(
            _precompute_prompts(
                gt_path,
                use_points=use_points,
                use_boxes=use_boxes,
                n_positives=n_positives,
                n_negatives=n_negatives,
                dilation=dilation,
            )
            for gt_path in tqdm(gt_paths, desc=f"Precompute prompts for {prompt_name}")
        )
        with open(prompt_save_path, "wb") as f:
            pickle.dump(saved_prompts, f)
Precompute all point and box prompts.
This enables running different inference tasks in parallel afterwards.
Arguments:
- gt_paths: The file paths to the ground-truth segmentations.
- prompt_save_dir: The directory where the prompt files will be saved.
- prompt_settings: The settings for which the prompts will be computed.
def run_inference_with_prompts(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Union[str, os.PathLike],
    prediction_dir: Union[str, os.PathLike],
    use_points: bool,
    use_boxes: bool,
    n_positives: int,
    n_negatives: int,
    dilation: int = 5,
    prompt_save_dir: Optional[Union[str, os.PathLike]] = None,
    batch_size: int = 512,
) -> None:
    """Run segment anything inference for multiple images using prompts derived from groundtruth.

    Args:
        predictor: The Segment Anything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
            If None, no embedding path is passed to the per-image inference.
        prediction_dir: The directory where the predictions will be saved, one file per image
            with the same file name as the input image. Images with an existing prediction are skipped.
        use_points: Whether to use point prompts.
        use_boxes: Whether to use box prompts.
        n_positives: The number of positive point prompts that will be sampled.
        n_negatives: The number of negative point prompts that will be sampled.
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        prompt_save_dir: The directory where point prompts will be saved or are already saved.
            This enables running multiple experiments in a reproducible manner.
        batch_size: The batch size used for batched prediction.

    Raises:
        ValueError: If neither point nor box prompts are requested, or if the number of images
            and ground-truth images differ.
    """
    if not (use_points or use_boxes):
        raise ValueError("You need to use at least one of point or box prompts.")

    if len(image_paths) != len(gt_paths):
        raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")

    (cached_point_prompts, save_point_prompts, point_prompt_save_path,
     cached_box_prompts, save_box_prompts, box_prompt_save_path) = _get_prompt_caching(
        prompt_save_dir, use_points, use_boxes, n_positives, n_negatives
    )

    os.makedirs(prediction_dir, exist_ok=True)
    for image_path, gt_path in tqdm(
        zip(image_paths, gt_paths), total=len(image_paths), desc="Run inference with prompts"
    ):
        image_name = os.path.basename(image_path)
        # Cached prompts are keyed by the file name of the ground-truth.
        label_name = os.path.basename(gt_path)

        # We skip the images that already have been segmented.
        # NOTE(review): skipped images are not added to prompt caches that are computed in this run,
        # so a cache saved from a resumed run may be incomplete. Confirm whether this matters.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        im = imageio.imread(image_path)
        gt = imageio.imread(gt_path).astype("uint32")
        # Presumably maps the instance ids to a consecutive range; only the relabeled array is kept.
        gt = relabel_sequential(gt)[0]

        # Same handling as in the iterative prompting: no embedding path without an embedding directory.
        if embedding_dir is None:
            embedding_path = None
        else:
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        this_prompts, cached_point_prompts, cached_box_prompts = _load_prompts(
            cached_point_prompts, save_point_prompts,
            cached_box_prompts, save_box_prompts,
            label_name
        )
        instances, this_prompts = _run_inference_with_prompts_for_image(
            predictor, im, gt, n_positives=n_positives, n_negatives=n_negatives,
            dilation=dilation, use_points=use_points, use_boxes=use_boxes,
            batch_size=batch_size, cached_prompts=this_prompts,
            embedding_path=embedding_path,
        )

        # Record freshly computed prompts: (points, labels) for point prompts and the boxes as last entry.
        if save_point_prompts:
            cached_point_prompts[label_name] = this_prompts[:2]
        if save_box_prompts:
            cached_box_prompts[label_name] = this_prompts[-1]

        # It's important to compress here, otherwise the predictions would take up a lot of space.
        imageio.imwrite(prediction_path, instances, compression=5)

    # Save the prompts if we run experiments with prompt caching and have computed them
    # for the first time.
    if save_point_prompts:
        with open(point_prompt_save_path, "wb") as f:
            pickle.dump(cached_point_prompts, f)
    if save_box_prompts:
        with open(box_prompt_save_path, "wb") as f:
            pickle.dump(cached_box_prompts, f)
Run segment anything inference for multiple images using prompts derived from groundtruth.
Arguments:
- predictor: The Segment Anything predictor.
- image_paths: The image file paths.
- gt_paths: The ground-truth segmentation file paths.
- embedding_dir: The directory where the image embeddings will be saved or are already saved.
- prediction_dir: The directory where the predictions will be saved.
- use_points: Whether to use point prompts.
- use_boxes: Whether to use box prompts.
- n_positives: The number of positive point prompts that will be sampled.
- n_negatives: The number of negative point prompts that will be sampled.
- dilation: The dilation factor for the radius around the ground-truth object around which points will not be sampled.
- prompt_save_dir: The directory where point prompts will be saved or are already saved. This enables running multiple experiments in a reproducible manner.
- batch_size: The batch size used for batched prediction.
def run_inference_with_iterative_prompting(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Union[str, os.PathLike],
    prediction_dir: Union[str, os.PathLike],
    start_with_box_prompt: bool = True,
    dilation: int = 5,
    batch_size: int = 32,
    n_iterations: int = 8,
    use_masks: bool = False
) -> None:
    """Run Segment Anything inference for multiple images using prompts iteratively
    derived from model outputs and ground-truth.

    The prediction for iteration `i` of an image is written to
    `<prediction_dir>/iteration{i:02}/<image file name>`. Images for which all iteration
    predictions already exist are skipped.

    Args:
        predictor: The Segment Anything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
            If None, no embedding path is passed to the per-image inference.
        prediction_dir: The directory where the predictions from Segment Anything will be saved per iteration.
        start_with_box_prompt: Whether to use a bounding box (True) or a single point (False) as the first prompt.
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        batch_size: The batch size used for batched predictions.
        n_iterations: The number of iterations for iterative prompting.
        use_masks: Whether to make use of logits from previous prompt-based segmentation.

    Raises:
        ValueError: If the number of images and ground-truth images differ.
    """
    if len(image_paths) != len(gt_paths):
        raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")

    # One output folder per intermediate iteration, created up front.
    iteration_dirs = [os.path.join(prediction_dir, f"iteration{it:02}") for it in range(n_iterations)]
    for iteration_dir in iteration_dirs:
        os.makedirs(iteration_dir, exist_ok=True)

    if use_masks:
        print("The iterative prompting will make use of logits masks from previous iterations.")

    image_gt_pairs = zip(image_paths, gt_paths)
    progress = tqdm(
        image_gt_pairs, total=len(image_paths),
        desc="Run inference with iterative prompting for all images",
    )
    for image_path, gt_path in progress:
        image_name = os.path.basename(image_path)
        prediction_paths = [os.path.join(iteration_dir, image_name) for iteration_dir in iteration_dirs]

        # Only images with at least one missing iteration output are processed.
        if all(map(os.path.exists, prediction_paths)):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        image = imageio.imread(image_path)
        labels = relabel_sequential(imageio.imread(gt_path).astype("uint32"))[0]

        stem = os.path.splitext(image_name)[0]
        embedding_path = None if embedding_dir is None else os.path.join(embedding_dir, f"{stem}.zarr")

        _run_inference_with_iterative_prompting_for_image(
            predictor, image, labels, start_with_box_prompt=start_with_box_prompt,
            dilation=dilation, batch_size=batch_size, embedding_path=embedding_path,
            n_iterations=n_iterations, prediction_paths=prediction_paths, use_masks=use_masks
        )
Run Segment Anything inference for multiple images using prompts iteratively derived from model outputs and ground-truth.
Arguments:
- predictor: The Segment Anything predictor.
- image_paths: The image file paths.
- gt_paths: The ground-truth segmentation file paths.
- embedding_dir: The directory where the image embeddings will be saved or are already saved.
- prediction_dir: The directory where the predictions from Segment Anything will be saved per iteration.
- start_with_box_prompt: Whether to use a bounding box (True) or a single point (False) as the first prompt.
- dilation: The dilation factor for the radius around the ground-truth object around which points will not be sampled.
- batch_size: The batch size used for batched predictions.
- n_iterations: The number of iterations for iterative prompting.
- use_masks: Whether to make use of logits from previous prompt-based segmentation.
def run_amg(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    iou_thresh_values: Optional[List[float]] = None,
    stability_score_values: Optional[List[float]] = None,
    peft_kwargs: Optional[Dict] = None,
    cache_embeddings: bool = False,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> str:
    """Run Segment Anything inference for multiple images using automatic mask generation (AMG).

    The grid search on the validation images and the inference on the test images are both
    delegated to `instance_segmentation.run_instance_segmentation_grid_search_and_inference`.
    Predictions are written to `<experiment_folder>/amg/inference` and grid-search results to
    `<experiment_folder>/amg/grid_search`. If `cache_embeddings` is set, embeddings are stored
    in `<experiment_folder>/embeddings`.

    Args:
        checkpoint: The filepath to model checkpoints.
        model_type: The segment anything model choice.
        experiment_folder: The directory where the relevant files are saved.
        val_image_paths: The list of filepaths of input images for grid-search.
        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
        iou_thresh_values: Optional choice of values for grid search of `iou_thresh` parameter.
        stability_score_values: Optional choice of values for grid search of `stability_score` parameter.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        cache_embeddings: Whether to cache embeddings in experiment folder.
        tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
            If given, this must be a dictionary containing the keys 'tile_shape' and 'halo'.

    Returns:
        Filepath where the predictions have been saved.

    Raises:
        RuntimeError: If `tiling_window_params` is given but is not a dictionary,
            or lacks 'tile_shape' or 'halo'.
    """
    if cache_embeddings:
        embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
        os.makedirs(embedding_folder, exist_ok=True)
    else:
        embedding_folder = None

    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs)

    # Get the AMG class: the tiled variant is used when (non-empty) tiling parameters are given.
    if tiling_window_params:
        if not isinstance(tiling_window_params, dict):
            raise RuntimeError("The tiling window parameters are expected to be provided as a dictionary of params.")

        if "tile_shape" not in tiling_window_params:
            raise RuntimeError("'tile_shape' parameter is missing from the provided parameters.")

        if "halo" not in tiling_window_params:
            raise RuntimeError("'halo' parameter is missing from the provided parameters.")

        amg_class = TiledAutomaticMaskGenerator
    else:
        amg_class = AutomaticMaskGenerator

    amg = amg_class(predictor)
    amg_prefix = "amg"

    # where the predictions are saved
    prediction_folder = os.path.join(experiment_folder, amg_prefix, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # where the grid-search results are saved
    gs_result_folder = os.path.join(experiment_folder, amg_prefix, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    grid_search_values = instance_segmentation.default_grid_search_values_amg(
        iou_thresh_values=iou_thresh_values,
        stability_score_values=stability_score_values,
    )

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter=amg,
        grid_search_values=grid_search_values,
        val_image_paths=val_image_paths,
        val_gt_paths=val_gt_paths,
        test_image_paths=test_image_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
        experiment_folder=experiment_folder,
        tiling_window_params=tiling_window_params,
    )
    return prediction_folder
Run Segment Anything inference for multiple images using automatic mask generation (AMG).
Arguments:
- checkpoint: The filepath to model checkpoints.
- model_type: The segment anything model choice.
- experiment_folder: The directory where the relevant files are saved.
- val_image_paths: The list of filepaths of input images for grid-search.
- val_gt_paths: The list of filepaths of corresponding labels for grid-search.
- test_image_paths: The list of filepaths of input images for automatic instance segmentation.
- iou_thresh_values: Optional choice of values for grid search of the `iou_thresh` parameter.
- stability_score_values: Optional choice of values for grid search of the `stability_score` parameter.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
- cache_embeddings: Whether to cache embeddings in experiment folder.
- tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
Returns:
Filepath where the predictions have been saved.
def run_apg(
    checkpoint: Optional[Union[str, os.PathLike]],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    peft_kwargs: Optional[Dict] = None,
    cache_embeddings: bool = False,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> str:
    """Run Segment Anything inference for multiple images using automatic prompt generation (APG).

    The grid search on the validation images and the inference on the test images are both
    delegated to `instance_segmentation.run_instance_segmentation_grid_search_and_inference`.
    Predictions are written to `<experiment_folder>/apg/inference` and grid-search results to
    `<experiment_folder>/apg/grid_search`. If `cache_embeddings` is set, embeddings are stored
    in `<experiment_folder>/embeddings`.

    Args:
        checkpoint: The filepath to the model checkpoint.
        model_type: The segment anything model choice.
        experiment_folder: The directory where the relevant files are saved.
        val_image_paths: The list of filepaths of input images for grid-search.
        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        cache_embeddings: Whether to cache embeddings in experiment folder.
        tiling_window_params: Tiling is not supported for APG yet; passing non-empty
            parameters raises a NotImplementedError.

    Returns:
        Filepath where the predictions have been saved.

    Raises:
        NotImplementedError: If (non-empty) `tiling_window_params` are given.
    """
    # Fail fast, before loading the model or creating any folders: tiled APG is not implemented.
    if tiling_window_params:
        raise NotImplementedError("Tiling is not supported for automatic prompt generation (APG) yet.")

    if cache_embeddings:
        embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
        os.makedirs(embedding_folder, exist_ok=True)
    else:
        embedding_folder = None

    predictor, decoder = get_predictor_and_decoder(
        model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs,
    )

    segmenter = AutomaticPromptGenerator(predictor, decoder)
    seg_prefix = "apg"

    # where the predictions are saved
    prediction_folder = os.path.join(experiment_folder, seg_prefix, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # where the grid-search results are saved
    gs_result_folder = os.path.join(experiment_folder, seg_prefix, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    grid_search_values = instance_segmentation.default_grid_search_values_apg()

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter=segmenter,
        grid_search_values=grid_search_values,
        val_image_paths=val_image_paths,
        val_gt_paths=val_gt_paths,
        test_image_paths=test_image_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
        experiment_folder=experiment_folder,
        tiling_window_params=tiling_window_params,
    )
    return prediction_folder
Run Segment Anything inference for multiple images using automatic prompt generation (APG).
Arguments:
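- checkpoint: The filepath to the model checkpoint.
- model_type: The segment anything model choice.
- experiment_folder: The directory where the relevant files are saved.
- val_image_paths: The list of filepaths of input images for grid-search.
- val_gt_paths: The list of filepaths of corresponding labels for grid-search.
- test_image_paths: The list of filepaths of input images for automatic instance segmentation.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
- cache_embeddings: Whether to cache embeddings in the experiment folder.
- tiling_window_params: Tiling is not supported for APG yet; passing these parameters raises a NotImplementedError.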
Returns:
Filepath where the predictions have been saved.
def run_instance_segmentation_with_decoder(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    peft_kwargs: Optional[Dict] = None,
    cache_embeddings: bool = False,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> str:
    """Run Segment Anything inference for multiple images using additional automatic instance segmentation (AIS).

    The grid search on the validation images and the inference on the test images are both
    delegated to `instance_segmentation.run_instance_segmentation_grid_search_and_inference`.
    Predictions are written to `<experiment_folder>/instance_segmentation_with_decoder/inference` and
    grid-search results to `<experiment_folder>/instance_segmentation_with_decoder/grid_search`.
    If `cache_embeddings` is set, embeddings are stored in `<experiment_folder>/embeddings`.

    Args:
        checkpoint: The filepath to model checkpoints.
        model_type: The segment anything model choice.
        experiment_folder: The directory where the relevant files are saved.
        val_image_paths: The list of filepaths of input images for grid-search.
        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        cache_embeddings: Whether to cache embeddings in experiment folder.
        tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
            If given, this must be a dictionary containing the keys 'tile_shape' and 'halo'.

    Returns:
        Filepath where the predictions have been saved.

    Raises:
        RuntimeError: If `tiling_window_params` is given but is not a dictionary,
            or lacks 'tile_shape' or 'halo'.
    """
    if cache_embeddings:
        embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
        os.makedirs(embedding_folder, exist_ok=True)
    else:
        embedding_folder = None

    predictor, decoder = get_predictor_and_decoder(
        model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs,
    )

    # Get the AIS class: the tiled variant is used when (non-empty) tiling parameters are given.
    if tiling_window_params:
        if not isinstance(tiling_window_params, dict):
            raise RuntimeError("The tiling window parameters are expected to be provided as a dictionary of params.")

        if "tile_shape" not in tiling_window_params:
            raise RuntimeError("'tile_shape' parameter is missing from the provided parameters.")

        if "halo" not in tiling_window_params:
            raise RuntimeError("'halo' parameter is missing from the provided parameters.")

        ais_class = TiledInstanceSegmentationWithDecoder
    else:
        ais_class = InstanceSegmentationWithDecoder

    segmenter = ais_class(predictor, decoder)
    seg_prefix = "instance_segmentation_with_decoder"

    # where the predictions are saved
    prediction_folder = os.path.join(experiment_folder, seg_prefix, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # where the grid-search results are saved
    gs_result_folder = os.path.join(experiment_folder, seg_prefix, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    grid_search_values = instance_segmentation.default_grid_search_values_instance_segmentation_with_decoder()

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter=segmenter,
        grid_search_values=grid_search_values,
        val_image_paths=val_image_paths,
        val_gt_paths=val_gt_paths,
        test_image_paths=test_image_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
        experiment_folder=experiment_folder,
        tiling_window_params=tiling_window_params,
    )
    return prediction_folder
Run Segment Anything inference for multiple images using additional automatic instance segmentation (AIS).
Arguments:
- checkpoint: The filepath to model checkpoints.
- model_type: The segment anything model choice.
- experiment_folder: The directory where the relevant files are saved.
- val_image_paths: The list of filepaths of input images for grid-search.
- val_gt_paths: The list of filepaths of corresponding labels for grid-search.
- test_image_paths: The list of filepaths of input images for automatic instance segmentation.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
- cache_embeddings: Whether to cache embeddings in experiment folder.
- tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
Returns:
Filepath where the predictions have been saved.