micro_sam.evaluation.inference

Inference with Segment Anything models and different prompt strategies.

  1"""Inference with Segment Anything models and different prompt strategies.
  2"""
  3
  4import os
  5import pickle
  6import numpy as np
  7from tqdm import tqdm
  8from copy import deepcopy
  9from typing import Any, Dict, List, Optional, Union, Tuple
 10
 11import imageio.v3 as imageio
 12from skimage.segmentation import relabel_sequential
 13
 14import torch
 15
 16from segment_anything import SamPredictor
 17
 18from .. import util as util
 19from ..inference import batched_inference
 20from ..instance_segmentation import (
 21    get_predictor_and_decoder,
 22    AutomaticMaskGenerator, InstanceSegmentationWithDecoder,
 23    TiledAutomaticMaskGenerator, TiledInstanceSegmentationWithDecoder,
 24    AutomaticPromptGenerator,
 25)
 26from . import instance_segmentation
 27from ..prompt_generators import PointAndBoxPromptGenerator, IterativePromptGenerator
 28
 29
 30def _load_prompts(
 31    cached_point_prompts, save_point_prompts, cached_box_prompts, save_box_prompts, image_name
 32):
 33
 34    def load_prompt_type(cached_prompts, save_prompts):
 35        # If no prompts are cached, or the prompts still have to be computed, there is nothing to load.
 36        if cached_prompts is None or save_prompts:
 37            return cached_prompts, None
 38
 39        # we have cached prompts, but they have not been loaded yet
 40        if isinstance(cached_prompts, str):
 41            with open(cached_prompts, "rb") as f:
 42                cached_prompts = pickle.load(f)
 43
 44        prompts = cached_prompts[image_name]
 45        return cached_prompts, prompts
 46
 47    cached_point_prompts, point_prompts = load_prompt_type(cached_point_prompts, save_point_prompts)
 48    cached_box_prompts, box_prompts = load_prompt_type(cached_box_prompts, save_box_prompts)
 49
 50    # we don't have anything cached
 51    if point_prompts is None and box_prompts is None:
 52        return None, cached_point_prompts, cached_box_prompts
 53
 54    if point_prompts is None:
 55        input_point, input_label = [], []
 56    else:
 57        input_point, input_label = point_prompts
 58
 59    if box_prompts is None:
 60        input_box = []
 61    else:
 62        input_box = box_prompts
 63
 64    prompts = (input_point, input_label, input_box)
 65    return prompts, cached_point_prompts, cached_box_prompts
 66
 67
 68def _get_batched_prompts(gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation):
 69
 70    # Initialize the prompt generator.
 71    prompt_generator = PointAndBoxPromptGenerator(
 72        n_positive_points=n_positives, n_negative_points=n_negatives,
 73        dilation_strength=dilation, get_point_prompts=use_points,
 74        get_box_prompts=use_boxes
 75    )
 76
 77    # Generate the prompts.
 78    center_coordinates, bbox_coordinates = util.get_centers_and_bounding_boxes(gt)
 79    center_coordinates = [center_coordinates[gt_id] for gt_id in gt_ids]
 80    bbox_coordinates = [bbox_coordinates[gt_id] for gt_id in gt_ids]
 81    masks = util.segmentation_to_one_hot(gt.astype("int64"), gt_ids)
 82
 83    points, point_labels, boxes, _ = prompt_generator(
 84        masks, bbox_coordinates, center_coordinates
 85    )
 86
 87    def to_numpy(x):
 88        if x is None:
 89            return x
 90        return x.numpy()
 91
 92    return to_numpy(points), to_numpy(point_labels), to_numpy(boxes)
 93
 94
 95def _run_inference_with_prompts_for_image(
 96    predictor,
 97    image,
 98    gt,
 99    use_points,
100    use_boxes,
101    n_positives,
102    n_negatives,
103    dilation,
104    batch_size,
105    cached_prompts,
106    embedding_path,
107):
108    gt_ids = np.unique(gt)[1:]
109    if cached_prompts is None:
110        points, point_labels, boxes = _get_batched_prompts(
111            gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation
112        )
113    else:
114        points, point_labels, boxes = cached_prompts
115
116    # Make a copy of the point prompts to return them at the end.
117    prompts = deepcopy((points, point_labels, boxes))
118
119    # Use multi-masking only if we have a single positive point without box
120    multimasking = False
121    if not use_boxes and (n_positives == 1 and n_negatives == 0):
122        multimasking = True
123
124    instance_labels = batched_inference(
125        predictor, image, batch_size,
126        boxes=boxes, points=points, point_labels=point_labels,
127        multimasking=multimasking, embedding_path=embedding_path,
128        return_instance_segmentation=True,
129    )
130
131    return instance_labels, prompts
132
133
134def precompute_all_embeddings(
135    predictor: SamPredictor, image_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike],
136) -> None:
137    """Precompute all image embeddings.
138
139    To enable running different inference tasks in parallel afterwards.
140
141    Args:
142        predictor: The Segment Anything predictor.
143        image_paths: The image file paths.
144        embedding_dir: The directory where the embeddings will be saved.
145    """
146    for image_path in tqdm(image_paths, desc="Precompute embeddings"):
147        image_name = os.path.basename(image_path)
148        im = imageio.imread(image_path)
149        embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
150        util.precompute_image_embeddings(predictor, im, embedding_path, ndim=2)
151
152
153def _precompute_prompts(gt_path, use_points, use_boxes, n_positives, n_negatives, dilation):
154    name = os.path.basename(gt_path)
155
156    gt = imageio.imread(gt_path).astype("uint32")
157    gt = relabel_sequential(gt)[0]
158    gt_ids = np.unique(gt)[1:]
159
160    input_point, input_label, input_box = _get_batched_prompts(
161        gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation
162    )
163
164    if use_boxes and not use_points:
165        return name, input_box
166    return name, (input_point, input_label)
167
168
169def precompute_all_prompts(
170    gt_paths: List[Union[str, os.PathLike]],
171    prompt_save_dir: Union[str, os.PathLike],
172    prompt_settings: List[Dict[str, Any]],
173) -> None:
174    """Precompute all point prompts.
175
176    To enable running different inference tasks in parallel afterwards.
177
178    Args:
179        gt_paths: The file paths to the ground-truth segmentations.
180        prompt_save_dir: The directory where the prompt files will be saved.
181        prompt_settings: The settings for which the prompts will be computed.
182    """
183    os.makedirs(prompt_save_dir, exist_ok=True)
184
185    for settings in tqdm(prompt_settings, desc="Precompute prompts"):
186
187        use_points, use_boxes = settings["use_points"], settings["use_boxes"]
188        n_positives, n_negatives = settings["n_positives"], settings["n_negatives"]
189        dilation = settings.get("dilation", 5)
190
191        # check if the prompts were already computed
192        if use_boxes and not use_points:
193            prompt_save_path = os.path.join(prompt_save_dir, "boxes.pkl")
194        else:
195            prompt_save_path = os.path.join(prompt_save_dir, f"points-p{n_positives}-n{n_negatives}.pkl")
196        if os.path.exists(prompt_save_path):
197            continue
198
199        results = []
200        for gt_path in tqdm(gt_paths, desc=f"Precompute prompts for p{n_positives}-n{n_negatives}"):
201            prompts = _precompute_prompts(
202                gt_path,
203                use_points=use_points,
204                use_boxes=use_boxes,
205                n_positives=n_positives,
206                n_negatives=n_negatives,
207                dilation=dilation,
208            )
209            results.append(prompts)
210
211        saved_prompts = {res[0]: res[1] for res in results}
212        with open(prompt_save_path, "wb") as f:
213            pickle.dump(saved_prompts, f)
214
215
216def _get_prompt_caching(prompt_save_dir, use_points, use_boxes, n_positives, n_negatives):
217
218    def get_prompt_type_caching(use_type, save_name):
219        if not use_type:
220            return None, False, None
221
222        prompt_save_path = os.path.join(prompt_save_dir, save_name)
223        if os.path.exists(prompt_save_path):
224            print("Using precomputed prompts from", prompt_save_path)
225            # We delay loading the prompts, so we only have to load them once they're needed the first time.
226            # This avoids loading the prompts (which are in a big pickle file) if all predictions are done already.
227            cached_prompts = prompt_save_path
228            save_prompts = False
229        else:
230            print("Saving prompts in", prompt_save_path)
231            cached_prompts = {}
232            save_prompts = True
233        return cached_prompts, save_prompts, prompt_save_path
234
235    # Check if prompt serialization is enabled.
236    # If it is then load the prompts if they are already cached and otherwise store them.
237    if prompt_save_dir is None:
238        print("Prompts are not cached.")
239        cached_point_prompts, cached_box_prompts = None, None
240        save_point_prompts, save_box_prompts = False, False
241        point_prompt_save_path, box_prompt_save_path = None, None
242    else:
243        cached_point_prompts, save_point_prompts, point_prompt_save_path = get_prompt_type_caching(
244            use_points, f"points-p{n_positives}-n{n_negatives}.pkl"
245        )
246        cached_box_prompts, save_box_prompts, box_prompt_save_path = get_prompt_type_caching(
247            use_boxes, "boxes.pkl"
248        )
249
250    return (cached_point_prompts, save_point_prompts, point_prompt_save_path,
251            cached_box_prompts, save_box_prompts, box_prompt_save_path)
252
253
254def run_inference_with_prompts(
255    predictor: SamPredictor,
256    image_paths: List[Union[str, os.PathLike]],
257    gt_paths: List[Union[str, os.PathLike]],
258    embedding_dir: Union[str, os.PathLike],
259    prediction_dir: Union[str, os.PathLike],
260    use_points: bool,
261    use_boxes: bool,
262    n_positives: int,
263    n_negatives: int,
264    dilation: int = 5,
265    prompt_save_dir: Optional[Union[str, os.PathLike]] = None,
266    batch_size: int = 512,
267) -> None:
268    """Run Segment Anything inference for multiple images using prompts derived from the ground-truth.
269
270    Args:
271        predictor: The Segment Anything predictor.
272        image_paths: The image file paths.
273        gt_paths: The ground-truth segmentation file paths.
274        embedding_dir: The directory where the image embeddings will be saved or are already saved.
275        use_points: Whether to use point prompts.
276        use_boxes: Whether to use box prompts.
277        n_positives: The number of positive point prompts that will be sampled.
278        n_negatives: The number of negative point prompts that will be sampled.
279        dilation: The dilation factor for the radius around the ground-truth object
280            around which points will not be sampled.
281        prompt_save_dir: The directory where point prompts will be saved or are already saved.
282            This enables running multiple experiments in a reproducible manner.
283        batch_size: The batch size used for batched prediction.
284    """
285    if not (use_points or use_boxes):
286        raise ValueError("You need to use at least one of point or box prompts.")
287
288    if len(image_paths) != len(gt_paths):
289        raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")
290
291    (cached_point_prompts, save_point_prompts, point_prompt_save_path,
292     cached_box_prompts, save_box_prompts, box_prompt_save_path) = _get_prompt_caching(
293         prompt_save_dir, use_points, use_boxes, n_positives, n_negatives
294     )
295
296    os.makedirs(prediction_dir, exist_ok=True)
297    for image_path, gt_path in tqdm(
298        zip(image_paths, gt_paths), total=len(image_paths), desc="Run inference with prompts"
299    ):
300        image_name = os.path.basename(image_path)
301        label_name = os.path.basename(gt_path)
302
303        # We skip the images that already have been segmented.
304        prediction_path = os.path.join(prediction_dir, image_name)
305        if os.path.exists(prediction_path):
306            continue
307
308        assert os.path.exists(image_path), image_path
309        assert os.path.exists(gt_path), gt_path
310
311        im = imageio.imread(image_path)
312        gt = imageio.imread(gt_path).astype("uint32")
313        gt = relabel_sequential(gt)[0]
314
315        embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
316        this_prompts, cached_point_prompts, cached_box_prompts = _load_prompts(
317            cached_point_prompts, save_point_prompts,
318            cached_box_prompts, save_box_prompts,
319            label_name
320        )
321        instances, this_prompts = _run_inference_with_prompts_for_image(
322            predictor, im, gt, n_positives=n_positives, n_negatives=n_negatives,
323            dilation=dilation, use_points=use_points, use_boxes=use_boxes,
324            batch_size=batch_size, cached_prompts=this_prompts,
325            embedding_path=embedding_path,
326        )
327
328        if save_point_prompts:
329            cached_point_prompts[label_name] = this_prompts[:2]
330        if save_box_prompts:
331            cached_box_prompts[label_name] = this_prompts[-1]
332
333        # It's important to compress here, otherwise the predictions would take up a lot of space.
334        imageio.imwrite(prediction_path, instances, compression=5)
335
336    # Save the prompts if we run experiments with prompt caching and have computed them
337    # for the first time.
338    if save_point_prompts:
339        with open(point_prompt_save_path, "wb") as f:
340            pickle.dump(cached_point_prompts, f)
341    if save_box_prompts:
342        with open(box_prompt_save_path, "wb") as f:
343            pickle.dump(cached_box_prompts, f)
344
345
346def _save_segmentation(masks, prediction_path):
347    # masks to segmentation
348    masks = masks.cpu().numpy().squeeze(1).astype("bool")
349    masks = [{"segmentation": mask, "area": mask.sum()} for mask in masks]
350    segmentation = util.mask_data_to_segmentation(masks)
351    imageio.imwrite(prediction_path, segmentation, compression=5)
352
353
354def _get_batched_iterative_prompts(sampled_binary_gt, masks, batch_size, prompt_generator):
355    n_samples = sampled_binary_gt.shape[0]
356    n_batches = int(np.ceil(float(n_samples) / batch_size))
357    next_coords, next_labels = [], []
358    for batch_idx in range(n_batches):
359        batch_start = batch_idx * batch_size
360        batch_stop = min((batch_idx + 1) * batch_size, n_samples)
361
362        batch_coords, batch_labels, _, _ = prompt_generator(
363            sampled_binary_gt[batch_start: batch_stop], masks[batch_start: batch_stop]
364        )
365        next_coords.append(batch_coords)
366        next_labels.append(batch_labels)
367
368    next_coords = torch.concatenate(next_coords)
369    next_labels = torch.concatenate(next_labels)
370
371    return next_coords, next_labels
372
373
374@torch.no_grad()
375def _run_inference_with_iterative_prompting_for_image(
376    predictor,
377    image,
378    gt,
379    start_with_box_prompt,
380    dilation,
381    batch_size,
382    embedding_path,
383    n_iterations,
384    prediction_paths,
385    use_masks=False
386) -> None:
387    verbose_embeddings = False
388
389    prompt_generator = IterativePromptGenerator()
390
391    gt_ids = np.unique(gt)[1:]
392
393    # Use multi-masking only if we have a single positive point without box
394    if start_with_box_prompt:
395        use_boxes, use_points = True, False
396        n_positives = 0
397        multimasking = False
398    else:
399        use_boxes, use_points = False, True
400        n_positives = 1
401        multimasking = True
402
403    points, point_labels, boxes = _get_batched_prompts(
404        gt, gt_ids,
405        use_points=use_points,
406        use_boxes=use_boxes,
407        n_positives=n_positives,
408        n_negatives=0,
409        dilation=dilation
410    )
411
412    sampled_binary_gt = util.segmentation_to_one_hot(gt.astype("int64"), gt_ids)
413
414    for iteration in range(n_iterations):
415        if iteration == 0:  # logits mask can not be used for the first iteration.
416            logits_masks = None
417        else:
418            if not use_masks:  # logits mask should not be used when not desired.
419                logits_masks = None
420
421        batched_outputs = batched_inference(
422            predictor=predictor,
423            image=image,
424            batch_size=batch_size,
425            boxes=boxes,
426            points=points,
427            point_labels=point_labels,
428            multimasking=multimasking,
429            embedding_path=embedding_path,
430            return_instance_segmentation=False,
431            logits_masks=logits_masks,
432            verbose_embeddings=verbose_embeddings,
433        )
434
435        # switching off multimasking after first iter, as next iters (with multiple prompts) don't expect multimasking
436        multimasking = False
437
438        masks = torch.stack([m["segmentation"][None] for m in batched_outputs]).to(torch.float32)
439
440        next_coords, next_labels = _get_batched_iterative_prompts(
441            sampled_binary_gt, masks, batch_size, prompt_generator
442        )
443        next_coords, next_labels = next_coords.detach().cpu().numpy(), next_labels.detach().cpu().numpy()
444
445        if points is not None:
446            points = np.concatenate([points, next_coords], axis=1)
447        else:
448            points = next_coords
449
450        if point_labels is not None:
451            point_labels = np.concatenate([point_labels, next_labels], axis=1)
452        else:
453            point_labels = next_labels
454
455        if use_masks:
456            logits_masks = torch.stack([m["logits"] for m in batched_outputs])
457
458        _save_segmentation(masks, prediction_paths[iteration])
459
460
461def run_inference_with_iterative_prompting(
462    predictor: SamPredictor,
463    image_paths: List[Union[str, os.PathLike]],
464    gt_paths: List[Union[str, os.PathLike]],
465    embedding_dir: Union[str, os.PathLike],
466    prediction_dir: Union[str, os.PathLike],
467    start_with_box_prompt: bool = True,
468    dilation: int = 5,
469    batch_size: int = 32,
470    n_iterations: int = 8,
471    use_masks: bool = False
472) -> None:
473    """Run Segment Anything inference for multiple images using prompts iteratively
474    derived from model outputs and ground-truth.
475
476    Args:
477        predictor: The Segment Anything predictor.
478        image_paths: The image file paths.
479        gt_paths: The ground-truth segmentation file paths.
480        embedding_dir: The directory where the image embeddings will be saved or are already saved.
481        prediction_dir: The directory where the predictions from Segment Anything will be saved per iteration.
482        start_with_box_prompt: Whether to use a box (instead of a single point) as the first prompt.
483        dilation: The dilation factor for the radius around the ground-truth object
484            around which points will not be sampled.
485        batch_size: The batch size used for batched predictions.
486        n_iterations: The number of iterations for iterative prompting.
487        use_masks: Whether to make use of logits from previous prompt-based segmentation.
488    """
489    if len(image_paths) != len(gt_paths):
490        raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")
491
492    # create all prediction folders for all intermediate iterations
493    for i in range(n_iterations):
494        os.makedirs(os.path.join(prediction_dir, f"iteration{i:02}"), exist_ok=True)
495
496    if use_masks:
497        print("The iterative prompting will make use of logits masks from previous iterations.")
498
499    for image_path, gt_path in tqdm(
500        zip(image_paths, gt_paths), total=len(image_paths),
501        desc="Run inference with iterative prompting for all images",
502    ):
503        image_name = os.path.basename(image_path)
504
505        # We skip the images that already have been segmented
506        prediction_paths = [os.path.join(prediction_dir, f"iteration{i:02}", image_name) for i in range(n_iterations)]
507        if all(os.path.exists(prediction_path) for prediction_path in prediction_paths):
508            continue
509
510        assert os.path.exists(image_path), image_path
511        assert os.path.exists(gt_path), gt_path
512
513        image = imageio.imread(image_path)
514        gt = imageio.imread(gt_path).astype("uint32")
515        gt = relabel_sequential(gt)[0]
516
517        if embedding_dir is None:
518            embedding_path = None
519        else:
520            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
521
522        _run_inference_with_iterative_prompting_for_image(
523            predictor, image, gt, start_with_box_prompt=start_with_box_prompt,
524            dilation=dilation, batch_size=batch_size, embedding_path=embedding_path,
525            n_iterations=n_iterations, prediction_paths=prediction_paths, use_masks=use_masks
526        )
527
528
529#
530# AMG FUNCTION
531#
532
533
534def run_amg(
535    checkpoint: Union[str, os.PathLike],
536    model_type: str,
537    experiment_folder: Union[str, os.PathLike],
538    val_image_paths: List[Union[str, os.PathLike]],
539    val_gt_paths: List[Union[str, os.PathLike]],
540    test_image_paths: List[Union[str, os.PathLike]],
541    iou_thresh_values: Optional[List[float]] = None,
542    stability_score_values: Optional[List[float]] = None,
543    peft_kwargs: Optional[Dict] = None,
544    cache_embeddings: bool = False,
545    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
546) -> str:
547    """Run Segment Anything inference for multiple images using automatic mask generation (AMG).
548
549    Args:
550        checkpoint: The filepath to the model checkpoint.
551        model_type: The Segment Anything model choice.
552        experiment_folder: The directory where the relevant files are saved.
553        val_image_paths: The list of filepaths of input images for grid-search.
554        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
555        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
556        iou_thresh_values: Optional choice of values for grid search of the `iou_thresh` parameter.
557        stability_score_values: Optional choice of values for grid search of the `stability_score` parameter.
558        peft_kwargs: Keyword arguments for the PEFT wrapper class.
559        cache_embeddings: Whether to cache the embeddings in the experiment folder.
560        tiling_window_params: The parameters that decide whether to run tiled automatic mask generation (AMG).
561
562    Returns:
563        Filepath where the predictions have been saved.
564    """
565
566    if cache_embeddings:
567        embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
568        os.makedirs(embedding_folder, exist_ok=True)
569    else:
570        embedding_folder = None
571
572    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs)
573
574    # Get the AMG class.
575    if tiling_window_params:
576        if not isinstance(tiling_window_params, dict):
577            raise RuntimeError("The tiling window parameters are expected to be provided as a dictionary of params.")
578
579        if "tile_shape" not in tiling_window_params:
580            raise RuntimeError("'tile_shape' parameter is missing from the provided parameters.")
581
582        if "halo" not in tiling_window_params:
583            raise RuntimeError("'halo' parameter is missing from the provided parameters.")
584
585        amg_class = TiledAutomaticMaskGenerator
586    else:
587        amg_class = AutomaticMaskGenerator
588
589    amg = amg_class(predictor)
590    amg_prefix = "amg"
591
592    # where the predictions are saved
593    prediction_folder = os.path.join(experiment_folder, amg_prefix, "inference")
594    os.makedirs(prediction_folder, exist_ok=True)
595
596    # where the grid-search results are saved
597    gs_result_folder = os.path.join(experiment_folder, amg_prefix, "grid_search")
598    os.makedirs(gs_result_folder, exist_ok=True)
599
600    grid_search_values = instance_segmentation.default_grid_search_values_amg(
601        iou_thresh_values=iou_thresh_values,
602        stability_score_values=stability_score_values,
603    )
604
605    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
606        segmenter=amg,
607        grid_search_values=grid_search_values,
608        val_image_paths=val_image_paths,
609        val_gt_paths=val_gt_paths,
610        test_image_paths=test_image_paths,
611        embedding_dir=embedding_folder,
612        prediction_dir=prediction_folder,
613        result_dir=gs_result_folder,
614        experiment_folder=experiment_folder,
615        tiling_window_params=tiling_window_params,
616    )
617    return prediction_folder
618
619
620def run_apg(
621    checkpoint: Optional[Union[str, os.PathLike]],
622    model_type: str,
623    experiment_folder: Union[str, os.PathLike],
624    val_image_paths: List[Union[str, os.PathLike]],
625    val_gt_paths: List[Union[str, os.PathLike]],
626    test_image_paths: List[Union[str, os.PathLike]],
627    peft_kwargs: Optional[Dict] = None,
628    cache_embeddings: bool = False,
629    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
630) -> str:
631    """Run Segment Anything inference for multiple images using automatic prompt generation (APG).
632
633    Args:
634        ...
635
636    Returns:
637        Filepath where the predictions have been saved.
638    """
639    if cache_embeddings:
640        embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
641        os.makedirs(embedding_folder, exist_ok=True)
642    else:
643        embedding_folder = None
644
645    predictor, decoder = get_predictor_and_decoder(
646        model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs,
647    )
648
649    # Get the APG class.
650    if tiling_window_params:
651        raise NotImplementedError
652    else:
653        apg_class = AutomaticPromptGenerator
654
655    segmenter = apg_class(predictor, decoder)
656    seg_prefix = "apg"
657
658    # where the predictions are saved
659    prediction_folder = os.path.join(experiment_folder, seg_prefix, "inference")
660    os.makedirs(prediction_folder, exist_ok=True)
661
662    # where the grid-search results are saved
663    gs_result_folder = os.path.join(experiment_folder, seg_prefix, "grid_search")
664    os.makedirs(gs_result_folder, exist_ok=True)
665
666    grid_search_values = instance_segmentation.default_grid_search_values_apg()
667
668    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
669        segmenter=segmenter,
670        grid_search_values=grid_search_values,
671        val_image_paths=val_image_paths,
672        val_gt_paths=val_gt_paths,
673        test_image_paths=test_image_paths,
674        embedding_dir=embedding_folder,
675        prediction_dir=prediction_folder,
676        result_dir=gs_result_folder,
677        experiment_folder=experiment_folder,
678        tiling_window_params=tiling_window_params,
679    )
680    return prediction_folder
681
682
683#
684# INSTANCE SEGMENTATION FUNCTION
685#
686
687
688def run_instance_segmentation_with_decoder(
689    checkpoint: Union[str, os.PathLike],
690    model_type: str,
691    experiment_folder: Union[str, os.PathLike],
692    val_image_paths: List[Union[str, os.PathLike]],
693    val_gt_paths: List[Union[str, os.PathLike]],
694    test_image_paths: List[Union[str, os.PathLike]],
695    peft_kwargs: Optional[Dict] = None,
696    cache_embeddings: bool = False,
697    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
698) -> str:
699    """Run Segment Anything inference for multiple images using additional automatic instance segmentation (AIS).
700
701    Args:
702        checkpoint: The filepath to the model checkpoint.
703        model_type: The Segment Anything model choice.
704        experiment_folder: The directory where the relevant files are saved.
705        val_image_paths: The list of filepaths of input images for grid-search.
706        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
707        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
708        peft_kwargs: Keyword arguments for the PEFT wrapper class.
709        cache_embeddings: Whether to cache the embeddings in the experiment folder.
710        tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
711
712    Returns:
713        Filepath where the predictions have been saved.
714    """
715
716    if cache_embeddings:
717        embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
718        os.makedirs(embedding_folder, exist_ok=True)
719    else:
720        embedding_folder = None
721
722    predictor, decoder = get_predictor_and_decoder(
723        model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs,
724    )
725
726    # Get the AIS class.
727    if tiling_window_params:
728        if not isinstance(tiling_window_params, dict):
729            raise RuntimeError("The tiling window parameters are expected to be provided as a dictionary of params.")
730
731        if "tile_shape" not in tiling_window_params:
732            raise RuntimeError("'tile_shape' parameter is missing from the provided parameters.")
733
734        if "halo" not in tiling_window_params:
735            raise RuntimeError("'halo' parameter is missing from the provided parameters.")
736
737        ais_class = TiledInstanceSegmentationWithDecoder
738    else:
739        ais_class = InstanceSegmentationWithDecoder
740
741    segmenter = ais_class(predictor, decoder)
742    seg_prefix = "instance_segmentation_with_decoder"
743
744    # where the predictions are saved
745    prediction_folder = os.path.join(experiment_folder, seg_prefix, "inference")
746    os.makedirs(prediction_folder, exist_ok=True)
747
748    # where the grid-search results are saved
749    gs_result_folder = os.path.join(experiment_folder, seg_prefix, "grid_search")
750    os.makedirs(gs_result_folder, exist_ok=True)
751
752    grid_search_values = instance_segmentation.default_grid_search_values_instance_segmentation_with_decoder()
753
754    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
755        segmenter=segmenter,
756        grid_search_values=grid_search_values,
757        val_image_paths=val_image_paths,
758        val_gt_paths=val_gt_paths,
759        test_image_paths=test_image_paths,
760        embedding_dir=embedding_folder,
761        prediction_dir=prediction_folder,
762        result_dir=gs_result_folder,
763        experiment_folder=experiment_folder,
764        tiling_window_params=tiling_window_params,
765    )
766    return prediction_folder
def precompute_all_embeddings( predictor: segment_anything.predictor.SamPredictor, image_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike]) -> None:

Precompute all image embeddings.

To enable running different inference tasks in parallel afterwards.

Arguments:
  • predictor: The Segment Anything predictor.
  • image_paths: The image file paths.
  • embedding_dir: The directory where the embeddings will be saved.
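
A minimal usage sketch (the "vit_b" model type and the data paths below are placeholder assumptions, not part of this module):

from glob import glob
from micro_sam import util
from micro_sam.evaluation.inference import precompute_all_embeddings

# Load a Segment Anything predictor; get_sam_model is the same helper used by run_amg below.
predictor = util.get_sam_model(model_type="vit_b")

# Compute and cache one zarr embedding file per image, so that later
# inference runs can be started in parallel without recomputing them.
image_paths = sorted(glob("data/images/*.tif"))
precompute_all_embeddings(predictor, image_paths, embedding_dir="embeddings")
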
def precompute_all_prompts( gt_paths: List[Union[str, os.PathLike]], prompt_save_dir: Union[str, os.PathLike], prompt_settings: List[Dict[str, Any]]) -> None:

Precompute all point prompts.

To enable running different inference tasks in parallel afterwards.

Arguments:
  • gt_paths: The file paths to the ground-truth segmentations.
  • prompt_save_dir: The directory where the prompt files will be saved.
  • prompt_settings: The settings for which the prompts will be computed.
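
A minimal usage sketch (the label paths are placeholder assumptions); each settings dictionary uses the keys read by this function, with "dilation" being optional:

from glob import glob
from micro_sam.evaluation.inference import precompute_all_prompts

gt_paths = sorted(glob("data/labels/*.tif"))
prompt_settings = [
    {"use_points": True, "use_boxes": False, "n_positives": 1, "n_negatives": 0},
    {"use_points": True, "use_boxes": False, "n_positives": 2, "n_negatives": 4},
    {"use_points": False, "use_boxes": True, "n_positives": 0, "n_negatives": 0},
]
# Writes boxes.pkl and points-p{P}-n{N}.pkl files into the prompt directory.
precompute_all_prompts(gt_paths, prompt_save_dir="prompts", prompt_settings=prompt_settings)
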
def run_inference_with_prompts( predictor: segment_anything.predictor.SamPredictor, image_paths: List[Union[str, os.PathLike]], gt_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike], prediction_dir: Union[str, os.PathLike], use_points: bool, use_boxes: bool, n_positives: int, n_negatives: int, dilation: int = 5, prompt_save_dir: Union[str, os.PathLike, NoneType] = None, batch_size: int = 512) -> None:

Run Segment Anything inference for multiple images using prompts derived from the ground-truth.

Arguments:
  • predictor: The Segment Anything predictor.
  • image_paths: The image file paths.
  • gt_paths: The ground-truth segmentation file paths.
  • embedding_dir: The directory where the image embeddings will be saved or are already saved.
  • prediction_dir: The directory where the segmentation results will be saved.
  • use_points: Whether to use point prompts.
  • use_boxes: Whether to use box prompts.
  • n_positives: The number of positive point prompts that will be sampled.
  • n_negatives: The number of negative point prompts that will be sampled.
  • dilation: The dilation factor for the radius around the ground-truth object around which points will not be sampled.
  • prompt_save_dir: The directory where point prompts will be saved or are already saved. This enables running multiple experiments in a reproducible manner.
  • batch_size: The batch size used for batched prediction.
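
A minimal usage sketch (the model type and data paths are placeholder assumptions):

from glob import glob
from micro_sam import util
from micro_sam.evaluation.inference import run_inference_with_prompts

predictor = util.get_sam_model(model_type="vit_b")
image_paths = sorted(glob("data/images/*.tif"))
gt_paths = sorted(glob("data/labels/*.tif"))

# Segment each image with a single positive point prompt per ground-truth object.
# Prompts are cached in "prompts" so that repeated runs reuse the same sampled points.
run_inference_with_prompts(
    predictor, image_paths, gt_paths,
    embedding_dir="embeddings", prediction_dir="predictions/points-p1-n0",
    use_points=True, use_boxes=False, n_positives=1, n_negatives=0,
    prompt_save_dir="prompts",
)
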
def run_inference_with_iterative_prompting( predictor: segment_anything.predictor.SamPredictor, image_paths: List[Union[str, os.PathLike]], gt_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike], prediction_dir: Union[str, os.PathLike], start_with_box_prompt: bool = True, dilation: int = 5, batch_size: int = 32, n_iterations: int = 8, use_masks: bool = False) -> None:

Run Segment Anything inference for multiple images using prompts iteratively derived from model outputs and ground-truth.

Arguments:
  • predictor: The Segment Anything predictor.
  • image_paths: The image file paths.
  • gt_paths: The ground-truth segmentation file paths.
  • embedding_dir: The directory where the image embeddings will be saved or are already saved.
  • prediction_dir: The directory where the predictions from Segment Anything will be saved per iteration.
  • start_with_box_prompt: Whether to use a box (instead of a single point) as the first prompt.
  • dilation: The dilation factor for the radius around the ground-truth object around which points will not be sampled.
  • batch_size: The batch size used for batched predictions.
  • n_iterations: The number of iterations for iterative prompting.
  • use_masks: Whether to make use of logits from previous prompt-based segmentation.
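
A minimal usage sketch (the model type and data paths are placeholder assumptions):

from glob import glob
from micro_sam import util
from micro_sam.evaluation.inference import run_inference_with_iterative_prompting

predictor = util.get_sam_model(model_type="vit_b")
image_paths = sorted(glob("data/images/*.tif"))
gt_paths = sorted(glob("data/labels/*.tif"))

# Start from a box prompt and add corrective point prompts for 8 iterations.
# One sub-folder per iteration (iteration00, iteration01, ...) is created in the prediction directory.
run_inference_with_iterative_prompting(
    predictor, image_paths, gt_paths,
    embedding_dir="embeddings", prediction_dir="predictions/iterative",
    start_with_box_prompt=True, n_iterations=8, use_masks=False,
)
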
def run_amg( checkpoint: Union[str, os.PathLike], model_type: str, experiment_folder: Union[str, os.PathLike], val_image_paths: List[Union[str, os.PathLike]], val_gt_paths: List[Union[str, os.PathLike]], test_image_paths: List[Union[str, os.PathLike]], iou_thresh_values: Optional[List[float]] = None, stability_score_values: Optional[List[float]] = None, peft_kwargs: Optional[Dict] = None, cache_embeddings: bool = False, tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None) -> str:

Run Segment Anything inference for multiple images using automatic mask generation (AMG).

Arguments:
  • checkpoint: The filepath to the model checkpoint.
  • model_type: The Segment Anything model choice.
  • experiment_folder: The directory where the relevant files are saved.
  • val_image_paths: The list of filepaths of input images for grid-search.
  • val_gt_paths: The list of filepaths of corresponding labels for grid-search.
  • test_image_paths: The list of filepaths of input images for automatic instance segmentation.
  • iou_thresh_values: Optional choice of values for grid search of the iou_thresh parameter.
  • stability_score_values: Optional choice of values for grid search of the stability_score parameter.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class.
  • cache_embeddings: Whether to cache the embeddings in the experiment folder.
  • tiling_window_params: The parameters that decide whether to run tiled automatic mask generation (AMG).
Returns:

Filepath where the predictions have been saved.
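
A minimal usage sketch (the checkpoint path, model type and data paths are placeholder assumptions):

from glob import glob
from micro_sam.evaluation.inference import run_amg

prediction_folder = run_amg(
    checkpoint="checkpoints/vit_b_lm.pt",
    model_type="vit_b",
    experiment_folder="experiments/amg",
    val_image_paths=sorted(glob("data/val/images/*.tif")),
    val_gt_paths=sorted(glob("data/val/labels/*.tif")),
    test_image_paths=sorted(glob("data/test/images/*.tif")),
    cache_embeddings=True,  # store the embeddings under experiments/amg/embeddings
)
# The grid search runs on the validation split; the best setting is then applied to the test images.
print("AMG predictions saved to", prediction_folder)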

def run_apg( checkpoint: Union[str, os.PathLike, NoneType], model_type: str, experiment_folder: Union[str, os.PathLike], val_image_paths: List[Union[str, os.PathLike]], val_gt_paths: List[Union[str, os.PathLike]], test_image_paths: List[Union[str, os.PathLike]], peft_kwargs: Optional[Dict] = None, cache_embeddings: bool = False, tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None) -> str:

Run Segment Anything inference for multiple images using automatic prompt generation (APG).

Arguments:
  • checkpoint: The filepath to the model checkpoint.
  • model_type: The Segment Anything model choice.
  • experiment_folder: The directory where the relevant files are saved.
  • val_image_paths: The list of filepaths of input images for grid-search.
  • val_gt_paths: The list of filepaths of corresponding labels for grid-search.
  • test_image_paths: The list of filepaths of input images for automatic instance segmentation.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class.
  • cache_embeddings: Whether to cache the embeddings in the experiment folder.
  • tiling_window_params: The parameters for tiled prediction (not yet supported for APG).
Returns:

Filepath where the predictions have been saved.
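
A minimal usage sketch (the checkpoint path, model type and data paths are placeholder assumptions); note that tiled prediction is not implemented for APG:

from glob import glob
from micro_sam.evaluation.inference import run_apg

prediction_folder = run_apg(
    checkpoint="checkpoints/vit_b_lm.pt",
    model_type="vit_b",
    experiment_folder="experiments/apg",
    val_image_paths=sorted(glob("data/val/images/*.tif")),
    val_gt_paths=sorted(glob("data/val/labels/*.tif")),
    test_image_paths=sorted(glob("data/test/images/*.tif")),
    # tiling_window_params must be left as None; passing it raises NotImplementedError.
)
print("APG predictions saved to", prediction_folder)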

def run_instance_segmentation_with_decoder( checkpoint: Union[str, os.PathLike], model_type: str, experiment_folder: Union[str, os.PathLike], val_image_paths: List[Union[str, os.PathLike]], val_gt_paths: List[Union[str, os.PathLike]], test_image_paths: List[Union[str, os.PathLike]], peft_kwargs: Optional[Dict] = None, cache_embeddings: bool = False, tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None) -> str:

Run Segment Anything inference for multiple images using the additional automatic instance segmentation (AIS) decoder.

Arguments:
  • checkpoint: The filepath to the model checkpoint.
  • model_type: The Segment Anything model choice.
  • experiment_folder: The directory where the relevant files are saved.
  • val_image_paths: The list of filepaths of input images for grid-search.
  • val_gt_paths: The list of filepaths of corresponding labels for grid-search.
  • test_image_paths: The list of filepaths of input images for automatic instance segmentation.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class.
  • cache_embeddings: Whether to cache the embeddings in the experiment folder.
  • tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
Returns:

Filepath where the predictions have been saved.
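
A minimal usage sketch (the checkpoint path, model type, data paths and tiling values are placeholder assumptions); the checkpoint must contain the additional decoder expected by get_predictor_and_decoder:

from glob import glob
from micro_sam.evaluation.inference import run_instance_segmentation_with_decoder

prediction_folder = run_instance_segmentation_with_decoder(
    checkpoint="checkpoints/vit_b_lm.pt",
    model_type="vit_b",
    experiment_folder="experiments/ais",
    val_image_paths=sorted(glob("data/val/images/*.tif")),
    val_gt_paths=sorted(glob("data/val/labels/*.tif")),
    test_image_paths=sorted(glob("data/test/images/*.tif")),
    # Optional: run tiled AIS on large images; both keys are required when provided.
    tiling_window_params={"tile_shape": (1024, 1024), "halo": (256, 256)},
)
print("AIS predictions saved to", prediction_folder)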