micro_sam.evaluation.inference

Inference with Segment Anything models and different prompt strategies.

  1"""Inference with Segment Anything models and different prompt strategies.
  2"""
  3
  4import os
  5import pickle
  6import numpy as np
  7from tqdm import tqdm
  8from copy import deepcopy
  9from typing import Any, Dict, List, Optional, Union
 10
 11import imageio.v3 as imageio
 12from skimage.segmentation import relabel_sequential
 13
 14import torch
 15
 16from segment_anything import SamPredictor
 17
 18from .. import util as util
 19from ..inference import batched_inference
 20from ..instance_segmentation import (
 21    mask_data_to_segmentation, get_predictor_and_decoder,
 22    AutomaticMaskGenerator, InstanceSegmentationWithDecoder,
 23)
 24from . import instance_segmentation
 25from ..prompt_generators import PointAndBoxPromptGenerator, IterativePromptGenerator
 26
 27
 28def _load_prompts(
 29    cached_point_prompts, save_point_prompts,
 30    cached_box_prompts, save_box_prompts,
 31    image_name
 32):
 33
 34    def load_prompt_type(cached_prompts, save_prompts):
 35        # Check if we have saved prompts.
 36        if cached_prompts is None or save_prompts:  # we don't have cached prompts
 37            return cached_prompts, None
 38
 39        # we have cached prompts, but they have not been loaded yet
 40        if isinstance(cached_prompts, str):
 41            with open(cached_prompts, "rb") as f:
 42                cached_prompts = pickle.load(f)
 43
 44        prompts = cached_prompts[image_name]
 45        return cached_prompts, prompts
 46
 47    cached_point_prompts, point_prompts = load_prompt_type(cached_point_prompts, save_point_prompts)
 48    cached_box_prompts, box_prompts = load_prompt_type(cached_box_prompts, save_box_prompts)
 49
 50    # we don't have anything cached
 51    if point_prompts is None and box_prompts is None:
 52        return None, cached_point_prompts, cached_box_prompts
 53
 54    if point_prompts is None:
 55        input_point, input_label = [], []
 56    else:
 57        input_point, input_label = point_prompts
 58
 59    if box_prompts is None:
 60        input_box = []
 61    else:
 62        input_box = box_prompts
 63
 64    prompts = (input_point, input_label, input_box)
 65    return prompts, cached_point_prompts, cached_box_prompts
 66
 67
 68def _get_batched_prompts(
 69    gt,
 70    gt_ids,
 71    use_points,
 72    use_boxes,
 73    n_positives,
 74    n_negatives,
 75    dilation,
 76):
 77    # Initialize the prompt generator.
 78    prompt_generator = PointAndBoxPromptGenerator(
 79        n_positive_points=n_positives, n_negative_points=n_negatives,
 80        dilation_strength=dilation, get_point_prompts=use_points,
 81        get_box_prompts=use_boxes
 82    )
 83
 84    # Generate the prompts.
 85    center_coordinates, bbox_coordinates = util.get_centers_and_bounding_boxes(gt)
 86    center_coordinates = [center_coordinates[gt_id] for gt_id in gt_ids]
 87    bbox_coordinates = [bbox_coordinates[gt_id] for gt_id in gt_ids]
 88    masks = util.segmentation_to_one_hot(gt.astype("int64"), gt_ids)
 89
 90    points, point_labels, boxes, _ = prompt_generator(
 91        masks, bbox_coordinates, center_coordinates
 92    )
 93
 94    def to_numpy(x):
 95        if x is None:
 96            return x
 97        return x.numpy()
 98
 99    return to_numpy(points), to_numpy(point_labels), to_numpy(boxes)
100
101
def _run_inference_with_prompts_for_image(
    predictor,
    image,
    gt,
    use_points,
    use_boxes,
    n_positives,
    n_negatives,
    dilation,
    batch_size,
    cached_prompts,
    embedding_path,
):
    """Segment all objects of one image with prompts derived from its ground-truth.

    Args:
        predictor: The SegmentAnything predictor.
        image: The input image.
        gt: The ground-truth segmentation. The smallest label (presumably background 0)
            is not treated as an object.
        use_points: Whether to use point prompts.
        use_boxes: Whether to use box prompts.
        n_positives: The number of positive point prompts per object.
        n_negatives: The number of negative point prompts per object.
        dilation: The dilation strength passed to the prompt generator.
        batch_size: The batch size used for batched prediction.
        cached_prompts: Prompts read from the cache as (points, point_labels, boxes), or None
            to derive all prompts from the ground-truth. A prompt type that was not cached is
            an empty list; if that type is requested it is derived from the ground-truth here.
        embedding_path: The path to the (precomputed) image embedding, passed to the inference.

    Returns:
        The instance segmentation and a copy of the prompts (points, point_labels, boxes) that were used.
    """
    gt_ids = np.unique(gt)[1:]

    def _is_missing(prompt):
        return prompt is None or len(prompt) == 0

    if cached_prompts is None:
        points, point_labels, boxes = _get_batched_prompts(
            gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation
        )
    else:
        points, point_labels, boxes = cached_prompts
        # Only one prompt type may be cached, e.g. the point prompts exist already but the
        # box prompts are computed for the first time. Derive the missing requested type
        # instead of silently running (and caching) inference without it.
        if len(gt_ids) > 0:
            if use_points and _is_missing(points):
                points, point_labels, _ = _get_batched_prompts(
                    gt, gt_ids, True, False, n_positives, n_negatives, dilation
                )
            if use_boxes and _is_missing(boxes):
                _, _, boxes = _get_batched_prompts(
                    gt, gt_ids, False, True, n_positives, n_negatives, dilation
                )

    # Keep a copy of the prompts so that the caller can cache them after inference.
    prompts = deepcopy((points, point_labels, boxes))

    # Multi-masking is only used for a single positive point without a box.
    multimasking = not use_boxes and n_positives == 1 and n_negatives == 0

    instance_labels = batched_inference(
        predictor, image, batch_size,
        boxes=boxes, points=points, point_labels=point_labels,
        multimasking=multimasking, embedding_path=embedding_path,
        return_instance_segmentation=True,
    )

    return instance_labels, prompts
139
140
def precompute_all_embeddings(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Union[str, os.PathLike],
) -> None:
    """Compute and store the image embeddings for a list of images.

    Precomputing the embeddings makes it possible to run different inference tasks in
    parallel afterwards. The embedding of an image is stored at
    `<embedding_dir>/<image file name without extension>.zarr`.

    Args:
        predictor: The SegmentAnything predictor.
        image_paths: The image file paths.
        embedding_dir: The directory where the embeddings will be saved.
    """
    for path in tqdm(image_paths, desc="Precompute embeddings"):
        stem = os.path.splitext(os.path.basename(path))[0]
        image = imageio.imread(path)
        target = os.path.join(embedding_dir, f"{stem}.zarr")
        util.precompute_image_embeddings(predictor, image, target, ndim=2)
160
161
def _precompute_prompts(gt_path, use_points, use_boxes, n_positives, n_negatives, dilation):
    """Derive the prompts for a single ground-truth image.

    Returns:
        The file name of the ground-truth and its prompts: only the boxes for a
        box-only setting, otherwise the tuple (points, point_labels).
    """
    segmentation = relabel_sequential(imageio.imread(gt_path).astype("uint32"))[0]
    object_ids = np.unique(segmentation)[1:]

    points, point_labels, boxes = _get_batched_prompts(
        segmentation, object_ids, use_points, use_boxes, n_positives, n_negatives, dilation
    )

    box_only = use_boxes and not use_points
    prompts = boxes if box_only else (points, point_labels)
    return os.path.basename(gt_path), prompts
176
177
def precompute_all_prompts(
    gt_paths: List[Union[str, os.PathLike]],
    prompt_save_dir: Union[str, os.PathLike],
    prompt_settings: List[Dict[str, Any]],
) -> None:
    """Precompute the point and box prompts for a list of ground-truth images.

    This enables running different inference tasks in parallel afterwards.
    Box-only settings are stored in `boxes.pkl`; all other settings are stored in
    `points-p{n_positives}-n{n_negatives}.pkl`. Each file holds a dict that maps the
    ground-truth file name to its prompts. Settings whose file exists already are skipped.

    Args:
        gt_paths: The file paths to the ground-truth segmentations.
        prompt_save_dir: The directory where the prompt files will be saved.
        prompt_settings: The settings for which the prompts will be computed. Each entry
            must contain "use_points", "use_boxes", "n_positives" and "n_negatives",
            and may contain "dilation" (default: 5).
    """
    os.makedirs(prompt_save_dir, exist_ok=True)

    for setting in tqdm(prompt_settings, desc="Precompute prompts"):
        use_points = setting["use_points"]
        use_boxes = setting["use_boxes"]
        n_positives = setting["n_positives"]
        n_negatives = setting["n_negatives"]
        dilation = setting.get("dilation", 5)

        if use_boxes and not use_points:
            file_name = "boxes.pkl"
        else:
            file_name = f"points-p{n_positives}-n{n_negatives}.pkl"
        prompt_save_path = os.path.join(prompt_save_dir, file_name)

        # Nothing to do if the prompts for this setting were computed already.
        if os.path.exists(prompt_save_path):
            continue

        saved_prompts = {}
        for gt_path in tqdm(gt_paths, desc=f"Precompute prompts for p{n_positives}-n{n_negatives}"):
            name, prompts = _precompute_prompts(
                gt_path,
                use_points=use_points,
                use_boxes=use_boxes,
                n_positives=n_positives,
                n_negatives=n_negatives,
                dilation=dilation,
            )
            saved_prompts[name] = prompts

        with open(prompt_save_path, "wb") as f:
            pickle.dump(saved_prompts, f)
223
224
def _get_prompt_caching(prompt_save_dir, use_points, use_boxes, n_positives, n_negatives):
    """Determine how the point and box prompts are cached for an inference run.

    Returns:
        (cached_point_prompts, save_point_prompts, point_prompt_save_path,
         cached_box_prompts, save_box_prompts, box_prompt_save_path).
        For each prompt type, the cache is None (type not used or caching disabled),
        the path to an existing pickle file (read lazily later), or an empty dict that
        is filled and saved in this run.
    """
    if prompt_save_dir is None:
        print("Prompts are not cached.")
        return (None, False, None, None, False, None)

    def _caching_for(enabled, file_name):
        if not enabled:
            return None, False, None
        save_path = os.path.join(prompt_save_dir, file_name)
        if not os.path.exists(save_path):
            print("Saving prompts in", save_path)
            return {}, True, save_path
        print("Using precomputed prompts from", save_path)
        # Loading is deferred: the (potentially large) pickle is only read once it is
        # needed, which never happens if all predictions exist already.
        return save_path, False, save_path

    point_caching = _caching_for(use_points, f"points-p{n_positives}-n{n_negatives}.pkl")
    box_caching = _caching_for(use_boxes, "boxes.pkl")
    return point_caching + box_caching
261
262
def run_inference_with_prompts(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    use_points: bool,
    use_boxes: bool,
    n_positives: int,
    n_negatives: int,
    dilation: int = 5,
    prompt_save_dir: Optional[Union[str, os.PathLike]] = None,
    batch_size: int = 512,
) -> None:
    """Run segment anything inference for multiple images using prompts derived from groundtruth.

    Images whose prediction already exists in `prediction_dir` are skipped.

    Args:
        predictor: The SegmentAnything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
            If None, no embedding path is passed to the inference.
        prediction_dir: The directory where the predicted instance segmentations will be saved.
        use_points: Whether to use point prompts.
        use_boxes: Whether to use box prompts.
        n_positives: The number of positive point prompts that will be sampled.
        n_negatives: The number of negative point prompts that will be sampled.
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        prompt_save_dir: The directory where the prompts will be saved or are already saved.
            This enables running multiple experiments in a reproducible manner.
        batch_size: The batch size used for batched prediction.

    Raises:
        ValueError: If neither point nor box prompts are used, or if the number of
            images and ground-truth images differs.
    """
    if not (use_points or use_boxes):
        raise ValueError("You need to use at least one of point or box prompts.")

    if len(image_paths) != len(gt_paths):
        raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")

    (cached_point_prompts, save_point_prompts, point_prompt_save_path,
     cached_box_prompts, save_box_prompts, box_prompt_save_path) = _get_prompt_caching(
         prompt_save_dir, use_points, use_boxes, n_positives, n_negatives
     )

    os.makedirs(prediction_dir, exist_ok=True)
    for image_path, gt_path in tqdm(
        zip(image_paths, gt_paths), total=len(image_paths), desc="Run inference with prompts"
    ):
        image_name = os.path.basename(image_path)
        # The prompts are cached under the file name of the ground-truth.
        label_name = os.path.basename(gt_path)

        # We skip the images that already have been segmented.
        # NOTE(review): skipped images are not added to a cache that is being filled in this
        # run, so such a cache can still be incomplete.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        im = imageio.imread(image_path)
        gt = imageio.imread(gt_path).astype("uint32")
        gt = relabel_sequential(gt)[0]

        # Same convention as run_inference_with_iterative_prompting: no embedding dir, no embedding path.
        if embedding_dir is None:
            embedding_path = None
        else:
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        this_prompts, cached_point_prompts, cached_box_prompts = _load_prompts(
            cached_point_prompts, save_point_prompts,
            cached_box_prompts, save_box_prompts,
            label_name
        )
        instances, this_prompts = _run_inference_with_prompts_for_image(
            predictor, im, gt, n_positives=n_positives, n_negatives=n_negatives,
            dilation=dilation, use_points=use_points, use_boxes=use_boxes,
            batch_size=batch_size, cached_prompts=this_prompts,
            embedding_path=embedding_path,
        )

        # Record freshly computed prompts so that they can be saved below.
        if save_point_prompts:
            cached_point_prompts[label_name] = this_prompts[:2]
        if save_box_prompts:
            cached_box_prompts[label_name] = this_prompts[-1]

        # It's important to compress here, otherwise the predictions would take up a lot of space.
        imageio.imwrite(prediction_path, instances, compression=5)

    # Save the prompts if we run experiments with prompt caching and have computed them
    # for the first time. Empty caches (e.g. because all predictions existed already) are
    # not written: later runs would read the file as a cache and fail with a KeyError.
    if save_point_prompts and cached_point_prompts:
        with open(point_prompt_save_path, "wb") as f:
            pickle.dump(cached_point_prompts, f)
    if save_box_prompts and cached_box_prompts:
        with open(box_prompt_save_path, "wb") as f:
            pickle.dump(cached_box_prompts, f)
353
354
def _save_segmentation(masks, prediction_path):
    """Convert per-object masks into a label image and write it (compressed) to `prediction_path`.

    Args:
        masks: A tensor whose axis 1 has size 1 (it is removed via `squeeze(1)`);
            presumably (n_objects, 1, H, W) — TODO confirm.
        prediction_path: The output file path.
    """
    binary_masks = masks.cpu().numpy().squeeze(1).astype("bool")
    mask_data = [{"segmentation": m, "area": m.sum()} for m in binary_masks]
    label_image = mask_data_to_segmentation(mask_data, with_background=True)
    imageio.imwrite(prediction_path, label_image, compression=5)
361
362
def _get_batched_iterative_prompts(sampled_binary_gt, masks, batch_size, prompt_generator):
    """Run the iterative prompt generator batch-wise and concatenate its outputs.

    Args:
        sampled_binary_gt: The per-object ground-truth masks (first axis indexes the objects).
        masks: The current per-object predictions, aligned with `sampled_binary_gt`.
        batch_size: The number of objects passed to the generator at once.
        prompt_generator: Callable returning (coords, labels, _, _) for a batch.

    Returns:
        The concatenated point coordinates and point labels for all objects.
    """
    total = sampled_binary_gt.shape[0]
    n_chunks = -(-total // batch_size)  # ceil division

    coord_chunks, label_chunks = [], []
    for chunk in range(n_chunks):
        start = chunk * batch_size
        stop = min(start + batch_size, total)
        coords, labels, _, _ = prompt_generator(sampled_binary_gt[start:stop], masks[start:stop])
        coord_chunks.append(coords)
        label_chunks.append(labels)

    return torch.cat(coord_chunks), torch.cat(label_chunks)
381
382
@torch.no_grad()
def _run_inference_with_iterative_prompting_for_image(
    predictor,
    image,
    gt,
    start_with_box_prompt,
    dilation,
    batch_size,
    embedding_path,
    n_iterations,
    prediction_paths,
    use_masks=False
) -> None:
    """Iteratively prompt SAM for all objects of one image and save every iteration's segmentation.

    The initial prompt per object is either its bounding box or a single positive point
    derived from the ground-truth. After each prediction, the IterativePromptGenerator derives
    new point prompts from the prediction and the ground-truth masks. These are appended to
    the point prompts used in the next iteration.

    Args:
        predictor: The Segment Anything predictor.
        image: The input image.
        gt: The ground-truth segmentation. The smallest label (presumably background 0)
            is not treated as an object.
        start_with_box_prompt: Whether the initial prompt is a box (otherwise a single point).
        dilation: The dilation strength for deriving the initial prompts.
        batch_size: The batch size for inference and prompt generation.
        embedding_path: The path to the (precomputed) image embedding, or None.
        n_iterations: The number of prompting iterations.
        prediction_paths: One output path per iteration.
        use_masks: Whether to pass the logits of the previous iteration as mask input.
    """
    iterative_prompt_generator = IterativePromptGenerator()
    object_ids = np.unique(gt)[1:]

    if start_with_box_prompt:
        use_points, use_boxes, n_positives = False, True, 0
    else:
        use_points, use_boxes, n_positives = True, False, 1
    # Multi-masking is only used for the single positive point of the first iteration.
    multimasking = use_points

    points, point_labels, boxes = _get_batched_prompts(
        gt, object_ids,
        use_points=use_points,
        use_boxes=use_boxes,
        n_positives=n_positives,
        n_negatives=0,
        dilation=dilation,
    )
    binary_gt = util.segmentation_to_one_hot(gt.astype("int64"), object_ids)

    # No logits exist before the first prediction; afterwards they are only passed on with use_masks.
    logits_masks = None
    for iteration in range(n_iterations):
        outputs = batched_inference(
            predictor=predictor,
            image=image,
            batch_size=batch_size,
            boxes=boxes,
            points=points,
            point_labels=point_labels,
            multimasking=multimasking,
            embedding_path=embedding_path,
            return_instance_segmentation=False,
            logits_masks=logits_masks,
            verbose_embeddings=False,
        )
        # Later iterations carry several prompts per object and do not expect multi-masking.
        multimasking = False

        masks = torch.stack([output["segmentation"][None] for output in outputs]).to(torch.float32)

        new_coords, new_labels = _get_batched_iterative_prompts(
            binary_gt, masks, batch_size, iterative_prompt_generator
        )
        new_coords = new_coords.detach().cpu().numpy()
        new_labels = new_labels.detach().cpu().numpy()

        points = new_coords if points is None else np.concatenate([points, new_coords], axis=1)
        point_labels = new_labels if point_labels is None else np.concatenate([point_labels, new_labels], axis=1)

        if use_masks:
            logits_masks = torch.stack([output["logits"] for output in outputs])

        _save_segmentation(masks, prediction_paths[iteration])
468
469
def run_inference_with_iterative_prompting(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Union[str, os.PathLike],
    prediction_dir: Union[str, os.PathLike],
    start_with_box_prompt: bool = True,
    dilation: int = 5,
    batch_size: int = 32,
    n_iterations: int = 8,
    use_masks: bool = False
) -> None:
    """Run Segment Anything inference for multiple images with prompts derived iteratively
    from the model outputs and the ground-truth.

    The prediction of iteration `i` is saved in `<prediction_dir>/iteration{i:02}`.
    Images whose predictions exist for all iterations are skipped.

    Args:
        predictor: The Segment Anything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
            If None, no embedding path is passed to the inference.
        prediction_dir: The directory where the predictions from Segment Anything will be saved per iteration.
        start_with_box_prompt: Whether the first prompt is a bounding box (otherwise a single point).
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        batch_size: The batch size used for batched predictions.
        n_iterations: The number of iterations for iterative prompting.
        use_masks: Whether to make use of logits from previous prompt-based segmentation.

    Raises:
        ValueError: If the number of images and ground-truth images differs.
    """
    if len(image_paths) != len(gt_paths):
        raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")

    # One output folder per iteration.
    iteration_dirs = [os.path.join(prediction_dir, f"iteration{i:02}") for i in range(n_iterations)]
    for iteration_dir in iteration_dirs:
        os.makedirs(iteration_dir, exist_ok=True)

    if use_masks:
        print("The iterative prompting will make use of logits masks from previous iterations.")

    for image_path, gt_path in tqdm(
        zip(image_paths, gt_paths), total=len(image_paths),
        desc="Run inference with iterative prompting for all images",
    ):
        image_name = os.path.basename(image_path)

        prediction_paths = [os.path.join(iteration_dir, image_name) for iteration_dir in iteration_dirs]
        # Skip images that were segmented in all iterations already.
        if all(os.path.exists(path) for path in prediction_paths):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        image = imageio.imread(image_path)
        gt = relabel_sequential(imageio.imread(gt_path).astype("uint32"))[0]

        embedding_path = (
            None if embedding_dir is None
            else os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
        )

        _run_inference_with_iterative_prompting_for_image(
            predictor, image, gt, start_with_box_prompt=start_with_box_prompt,
            dilation=dilation, batch_size=batch_size, embedding_path=embedding_path,
            n_iterations=n_iterations, prediction_paths=prediction_paths, use_masks=use_masks
        )
536
537
538#
539# AMG FUNCTION
540#
541
542
def run_amg(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    iou_thresh_values: Optional[List[float]] = None,
    stability_score_values: Optional[List[float]] = None,
    peft_kwargs: Optional[Dict] = None,
) -> str:
    """Run grid search and inference for automatic mask generation (AMG).

    Outputs are written below `experiment_folder`: the embeddings to `embeddings`,
    the predictions to `amg/inference` and the grid-search results to `amg/grid_search`.

    Args:
        checkpoint: The model checkpoint path.
        model_type: The Segment Anything model type.
        experiment_folder: The folder where all outputs are written.
        val_image_paths: The validation image paths, passed on to the grid search and inference.
        val_gt_paths: The validation ground-truth paths, passed on to the grid search and inference.
        test_image_paths: The test image paths, passed on to the grid search and inference.
        iou_thresh_values: Optional IoU threshold values for `default_grid_search_values_amg`.
        stability_score_values: Optional stability score values for `default_grid_search_values_amg`.
        peft_kwargs: Optional keyword arguments passed to `util.get_sam_model`.

    Returns:
        The folder where the predictions are saved.
    """

    def _create(*parts):
        # Join the path parts, make sure the folder exists and return it.
        folder = os.path.join(*parts)
        os.makedirs(folder, exist_ok=True)
        return folder

    # Where the precomputed embeddings are saved.
    embedding_folder = _create(experiment_folder, "embeddings")

    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs)
    amg = AutomaticMaskGenerator(predictor)

    amg_prefix = "amg"
    prediction_folder = _create(experiment_folder, amg_prefix, "inference")
    gs_result_folder = _create(experiment_folder, amg_prefix, "grid_search")

    grid_search_values = instance_segmentation.default_grid_search_values_amg(
        iou_thresh_values=iou_thresh_values,
        stability_score_values=stability_score_values,
    )

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        amg, grid_search_values,
        val_image_paths, val_gt_paths, test_image_paths,
        embedding_folder, prediction_folder, gs_result_folder,
    )
    return prediction_folder
580
581
582#
583# INSTANCE SEGMENTATION FUNCTION
584#
585
586
def run_instance_segmentation_with_decoder(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    peft_kwargs: Optional[Dict] = None,
) -> str:
    """Run grid search and inference for instance segmentation with the additional decoder.

    Outputs are written below `experiment_folder`: the embeddings to `embeddings`, the
    predictions to `instance_segmentation_with_decoder/inference` and the grid-search
    results to `instance_segmentation_with_decoder/grid_search`.

    Args:
        checkpoint: The model checkpoint path.
        model_type: The Segment Anything model type.
        experiment_folder: The folder where all outputs are written.
        val_image_paths: The validation image paths, passed on to the grid search and inference.
        val_gt_paths: The validation ground-truth paths, passed on to the grid search and inference.
        test_image_paths: The test image paths, passed on to the grid search and inference.
        peft_kwargs: Optional keyword arguments passed to `get_predictor_and_decoder`.

    Returns:
        The folder where the predictions are saved.
    """

    def _create(*parts):
        # Join the path parts, make sure the folder exists and return it.
        folder = os.path.join(*parts)
        os.makedirs(folder, exist_ok=True)
        return folder

    # Where the precomputed embeddings are saved.
    embedding_folder = _create(experiment_folder, "embeddings")

    predictor, decoder = get_predictor_and_decoder(
        model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs,
    )
    segmenter = InstanceSegmentationWithDecoder(predictor, decoder)

    seg_prefix = "instance_segmentation_with_decoder"
    prediction_folder = _create(experiment_folder, seg_prefix, "inference")
    gs_result_folder = _create(experiment_folder, seg_prefix, "grid_search")

    grid_search_values = instance_segmentation.default_grid_search_values_instance_segmentation_with_decoder()

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter, grid_search_values,
        val_image_paths, val_gt_paths, test_image_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
    )
    return prediction_folder
def precompute_all_embeddings( predictor: segment_anything.predictor.SamPredictor, image_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike]) -> None:
def precompute_all_embeddings(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    embedding_dir: Union[str, os.PathLike],
) -> None:
    """Compute and store the image embeddings for a list of images.

    Precomputing the embeddings makes it possible to run different inference tasks in
    parallel afterwards. The embedding of an image is stored at
    `<embedding_dir>/<image file name without extension>.zarr`.

    Args:
        predictor: The SegmentAnything predictor.
        image_paths: The image file paths.
        embedding_dir: The directory where the embeddings will be saved.
    """
    for path in tqdm(image_paths, desc="Precompute embeddings"):
        stem = os.path.splitext(os.path.basename(path))[0]
        image = imageio.imread(path)
        target = os.path.join(embedding_dir, f"{stem}.zarr")
        util.precompute_image_embeddings(predictor, image, target, ndim=2)

Precompute all image embeddings.

To enable running different inference tasks in parallel afterwards.

Arguments:
  • predictor: The SegmentAnything predictor.
  • image_paths: The image file paths.
  • embedding_dir: The directory where the embeddings will be saved.
def precompute_all_prompts( gt_paths: List[Union[str, os.PathLike]], prompt_save_dir: Union[str, os.PathLike], prompt_settings: List[Dict[str, Any]]) -> None:
def precompute_all_prompts(
    gt_paths: List[Union[str, os.PathLike]],
    prompt_save_dir: Union[str, os.PathLike],
    prompt_settings: List[Dict[str, Any]],
) -> None:
    """Precompute the point and box prompts for a list of ground-truth images.

    This enables running different inference tasks in parallel afterwards.
    Box-only settings are stored in `boxes.pkl`; all other settings are stored in
    `points-p{n_positives}-n{n_negatives}.pkl`. Each file holds a dict that maps the
    ground-truth file name to its prompts. Settings whose file exists already are skipped.

    Args:
        gt_paths: The file paths to the ground-truth segmentations.
        prompt_save_dir: The directory where the prompt files will be saved.
        prompt_settings: The settings for which the prompts will be computed. Each entry
            must contain "use_points", "use_boxes", "n_positives" and "n_negatives",
            and may contain "dilation" (default: 5).
    """
    os.makedirs(prompt_save_dir, exist_ok=True)

    for setting in tqdm(prompt_settings, desc="Precompute prompts"):
        use_points = setting["use_points"]
        use_boxes = setting["use_boxes"]
        n_positives = setting["n_positives"]
        n_negatives = setting["n_negatives"]
        dilation = setting.get("dilation", 5)

        if use_boxes and not use_points:
            file_name = "boxes.pkl"
        else:
            file_name = f"points-p{n_positives}-n{n_negatives}.pkl"
        prompt_save_path = os.path.join(prompt_save_dir, file_name)

        # Nothing to do if the prompts for this setting were computed already.
        if os.path.exists(prompt_save_path):
            continue

        saved_prompts = {}
        for gt_path in tqdm(gt_paths, desc=f"Precompute prompts for p{n_positives}-n{n_negatives}"):
            name, prompts = _precompute_prompts(
                gt_path,
                use_points=use_points,
                use_boxes=use_boxes,
                n_positives=n_positives,
                n_negatives=n_negatives,
                dilation=dilation,
            )
            saved_prompts[name] = prompts

        with open(prompt_save_path, "wb") as f:
            pickle.dump(saved_prompts, f)

Precompute all point and box prompts.

To enable running different inference tasks in parallel afterwards.

Arguments:
  • gt_paths: The file paths to the ground-truth segmentations.
  • prompt_save_dir: The directory where the prompt files will be saved.
  • prompt_settings: The settings for which the prompts will be computed.
def run_inference_with_prompts( predictor: segment_anything.predictor.SamPredictor, image_paths: List[Union[str, os.PathLike]], gt_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike], prediction_dir: Union[str, os.PathLike], use_points: bool, use_boxes: bool, n_positives: int, n_negatives: int, dilation: int = 5, prompt_save_dir: Union[str, os.PathLike, NoneType] = None, batch_size: int = 512) -> None:
def run_inference_with_prompts(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    use_points: bool,
    use_boxes: bool,
    n_positives: int,
    n_negatives: int,
    dilation: int = 5,
    prompt_save_dir: Optional[Union[str, os.PathLike]] = None,
    batch_size: int = 512,
) -> None:
    """Run segment anything inference for multiple images using prompts derived from groundtruth.

    Images whose prediction already exists in `prediction_dir` are skipped.

    Args:
        predictor: The SegmentAnything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
            If None, no embedding path is passed to the inference.
        prediction_dir: The directory where the predicted instance segmentations will be saved.
        use_points: Whether to use point prompts.
        use_boxes: Whether to use box prompts.
        n_positives: The number of positive point prompts that will be sampled.
        n_negatives: The number of negative point prompts that will be sampled.
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        prompt_save_dir: The directory where the prompts will be saved or are already saved.
            This enables running multiple experiments in a reproducible manner.
        batch_size: The batch size used for batched prediction.

    Raises:
        ValueError: If neither point nor box prompts are used, or if the number of
            images and ground-truth images differs.
    """
    if not (use_points or use_boxes):
        raise ValueError("You need to use at least one of point or box prompts.")

    if len(image_paths) != len(gt_paths):
        raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")

    (cached_point_prompts, save_point_prompts, point_prompt_save_path,
     cached_box_prompts, save_box_prompts, box_prompt_save_path) = _get_prompt_caching(
         prompt_save_dir, use_points, use_boxes, n_positives, n_negatives
     )

    os.makedirs(prediction_dir, exist_ok=True)
    for image_path, gt_path in tqdm(
        zip(image_paths, gt_paths), total=len(image_paths), desc="Run inference with prompts"
    ):
        image_name = os.path.basename(image_path)
        # The prompts are cached under the file name of the ground-truth.
        label_name = os.path.basename(gt_path)

        # We skip the images that already have been segmented.
        # NOTE(review): skipped images are not added to a cache that is being filled in this
        # run, so such a cache can still be incomplete.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        im = imageio.imread(image_path)
        gt = imageio.imread(gt_path).astype("uint32")
        gt = relabel_sequential(gt)[0]

        # Same convention as run_inference_with_iterative_prompting: no embedding dir, no embedding path.
        if embedding_dir is None:
            embedding_path = None
        else:
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        this_prompts, cached_point_prompts, cached_box_prompts = _load_prompts(
            cached_point_prompts, save_point_prompts,
            cached_box_prompts, save_box_prompts,
            label_name
        )
        instances, this_prompts = _run_inference_with_prompts_for_image(
            predictor, im, gt, n_positives=n_positives, n_negatives=n_negatives,
            dilation=dilation, use_points=use_points, use_boxes=use_boxes,
            batch_size=batch_size, cached_prompts=this_prompts,
            embedding_path=embedding_path,
        )

        # Record freshly computed prompts so that they can be saved below.
        if save_point_prompts:
            cached_point_prompts[label_name] = this_prompts[:2]
        if save_box_prompts:
            cached_box_prompts[label_name] = this_prompts[-1]

        # It's important to compress here, otherwise the predictions would take up a lot of space.
        imageio.imwrite(prediction_path, instances, compression=5)

    # Save the prompts if we run experiments with prompt caching and have computed them
    # for the first time. Empty caches (e.g. because all predictions existed already) are
    # not written: later runs would read the file as a cache and fail with a KeyError.
    if save_point_prompts and cached_point_prompts:
        with open(point_prompt_save_path, "wb") as f:
            pickle.dump(cached_point_prompts, f)
    if save_box_prompts and cached_box_prompts:
        with open(box_prompt_save_path, "wb") as f:
            pickle.dump(cached_box_prompts, f)

Run segment anything inference for multiple images using prompts derived from groundtruth.

Arguments:
  • predictor: The SegmentAnything predictor.
  • image_paths: The image file paths.
  • gt_paths: The ground-truth segmentation file paths.
  • embedding_dir: The directory where the image embeddings will be saved or are already saved.
  • prediction_dir: The directory where the predicted instance segmentations will be saved.
  • use_points: Whether to use point prompts.
  • use_boxes: Whether to use box prompts
  • n_positives: The number of positive point prompts that will be sampled.
  • n_negatives: The number of negative point prompts that will be sampled.
  • dilation: The dilation factor for the radius around the ground-truth object around which points will not be sampled.
  • prompt_save_dir: The directory where point prompts will be saved or are already saved. This enables running multiple experiments in a reproducible manner.
  • batch_size: The batch size used for batched prediction.
def run_inference_with_iterative_prompting( predictor: segment_anything.predictor.SamPredictor, image_paths: List[Union[str, os.PathLike]], gt_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike], prediction_dir: Union[str, os.PathLike], start_with_box_prompt: bool = True, dilation: int = 5, batch_size: int = 32, n_iterations: int = 8, use_masks: bool = False) -> None:
def run_inference_with_iterative_prompting(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    start_with_box_prompt: bool = True,
    dilation: int = 5,
    batch_size: int = 32,
    n_iterations: int = 8,
    use_masks: bool = False
) -> None:
    """Run Segment Anything inference for multiple images using prompts iteratively
    derived from model outputs and ground-truth.

    The prediction of iteration `i` for an image is written to
    `prediction_dir/iteration{i:02}/<image file name>`. Images for which the predictions
    of all iterations already exist are skipped.

    Args:
        predictor: The Segment Anything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths, in the same order as `image_paths`.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
            If None, `embedding_path=None` is passed on for every image.
        prediction_dir: The directory where the predictions from Segment Anything will be saved per iteration.
        start_with_box_prompt: Whether to use a bounding box or a single point as the first prompt.
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        batch_size: The batch size used for batched predictions.
        n_iterations: The number of iterations for iterative prompting. Must be at least 1.
        use_masks: Whether to make use of logits from previous prompt-based segmentation.

    Raises:
        ValueError: If `image_paths` and `gt_paths` have different lengths,
            or if `n_iterations` is smaller than 1.
        FileNotFoundError: If an image or ground-truth file that still has to be processed does not exist.
    """
    if len(image_paths) != len(gt_paths):
        raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")
    # With zero iterations there are no prediction paths; `all([])` is True, so every
    # image would be skipped silently. Reject this explicitly instead.
    if n_iterations < 1:
        raise ValueError(f"Expect at least one iteration, got n_iterations={n_iterations}")

    # Create the prediction folders for all intermediate iterations (computed once, reused per image).
    iteration_dirs = [os.path.join(prediction_dir, f"iteration{i:02}") for i in range(n_iterations)]
    for iteration_dir in iteration_dirs:
        os.makedirs(iteration_dir, exist_ok=True)

    if use_masks:
        print("The iterative prompting will make use of logits masks from previous iterations.")

    for image_path, gt_path in tqdm(
        zip(image_paths, gt_paths), total=len(image_paths),
        desc="Run inference with iterative prompting for all images",
    ):
        image_name = os.path.basename(image_path)

        # Skip the images that have already been segmented in every iteration.
        prediction_paths = [os.path.join(iteration_dir, image_name) for iteration_dir in iteration_dirs]
        if all(os.path.exists(prediction_path) for prediction_path in prediction_paths):
            continue

        # Explicit exceptions instead of `assert`: assertions are removed when Python runs with `-O`.
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image file not found: {image_path}")
        if not os.path.exists(gt_path):
            raise FileNotFoundError(f"Ground-truth file not found: {gt_path}")

        image = imageio.imread(image_path)
        # Load the ground-truth as uint32 and relabel its object ids sequentially.
        gt = imageio.imread(gt_path).astype("uint32")
        gt = relabel_sequential(gt)[0]

        if embedding_dir is None:
            embedding_path = None
        else:
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        _run_inference_with_iterative_prompting_for_image(
            predictor, image, gt, start_with_box_prompt=start_with_box_prompt,
            dilation=dilation, batch_size=batch_size, embedding_path=embedding_path,
            n_iterations=n_iterations, prediction_paths=prediction_paths, use_masks=use_masks
        )

Run Segment Anything inference for multiple images using prompts iteratively derived from model outputs and ground-truth.

Arguments:
  • predictor: The Segment Anything predictor.
  • image_paths: The image file paths.
  • gt_paths: The ground-truth segmentation file paths.
  • embedding_dir: The directory where the image embeddings will be saved or are already saved.
  • prediction_dir: The directory where the predictions from Segment Anything will be saved per iteration.
  • start_with_box_prompt: Whether to use a bounding box or a single point as the first prompt.
  • dilation: The dilation factor for the radius around the ground-truth object around which points will not be sampled.
  • batch_size: The batch size used for batched predictions.
  • n_iterations: The number of iterations for iterative prompting.
  • use_masks: Whether to make use of logits from previous prompt-based segmentation.
def run_amg( checkpoint: Union[str, os.PathLike], model_type: str, experiment_folder: Union[str, os.PathLike], val_image_paths: List[Union[str, os.PathLike]], val_gt_paths: List[Union[str, os.PathLike]], test_image_paths: List[Union[str, os.PathLike]], iou_thresh_values: Optional[List[float]] = None, stability_score_values: Optional[List[float]] = None, peft_kwargs: Optional[Dict] = None) -> str:
def run_amg(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    iou_thresh_values: Optional[List[float]] = None,
    stability_score_values: Optional[List[float]] = None,
    peft_kwargs: Optional[Dict] = None,
) -> str:
    """Run automatic mask generation (AMG) with a parameter grid search and inference.

    The model is loaded via `util.get_sam_model` and wrapped in an `AutomaticMaskGenerator`.
    The work is delegated to `instance_segmentation.run_instance_segmentation_grid_search_and_inference`,
    which (judging by its name and arguments) grid-searches the AMG parameters on the validation
    data and then segments the test images.

    Output layout inside `experiment_folder`:
        - `embeddings/`: image embeddings.
        - `amg/inference/`: predictions.
        - `amg/grid_search/`: grid-search results.

    Args:
        checkpoint: The path to the model checkpoint.
        model_type: The Segment Anything model type.
        experiment_folder: The root folder for all outputs of this experiment.
        val_image_paths: The validation image file paths used for the grid search.
        val_gt_paths: The validation ground-truth file paths used for the grid search.
        test_image_paths: The test image file paths to run inference on.
        iou_thresh_values: Candidate IoU-threshold values for the grid search, forwarded to
            `default_grid_search_values_amg` (None presumably selects its defaults).
        stability_score_values: Candidate stability-score values for the grid search, forwarded to
            `default_grid_search_values_amg` (None presumably selects its defaults).
        peft_kwargs: Optional keyword arguments forwarded to `util.get_sam_model`.

    Returns:
        The folder where the predictions are saved.
    """
    embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
    os.makedirs(embedding_folder, exist_ok=True)

    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs)
    amg = AutomaticMaskGenerator(predictor)
    amg_prefix = "amg"

    # where the predictions are saved
    prediction_folder = os.path.join(experiment_folder, amg_prefix, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # where the grid-search results are saved
    gs_result_folder = os.path.join(experiment_folder, amg_prefix, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    grid_search_values = instance_segmentation.default_grid_search_values_amg(
        iou_thresh_values=iou_thresh_values,
        stability_score_values=stability_score_values,
    )

    # Pass the folders by keyword (as `run_instance_segmentation_with_decoder` does) so that
    # the call does not depend on the positional order of these parameters.
    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        amg, grid_search_values,
        val_image_paths, val_gt_paths, test_image_paths,
        embedding_dir=embedding_folder, prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
    )
    return prediction_folder
def run_instance_segmentation_with_decoder( checkpoint: Union[str, os.PathLike], model_type: str, experiment_folder: Union[str, os.PathLike], val_image_paths: List[Union[str, os.PathLike]], val_gt_paths: List[Union[str, os.PathLike]], test_image_paths: List[Union[str, os.PathLike]], peft_kwargs: Optional[Dict] = None) -> str:
def run_instance_segmentation_with_decoder(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    peft_kwargs: Optional[Dict] = None,
) -> str:
    """Run instance segmentation with the additional decoder, including a parameter grid search.

    The predictor and decoder are obtained from `get_predictor_and_decoder` and wrapped in an
    `InstanceSegmentationWithDecoder`. The work is delegated to
    `instance_segmentation.run_instance_segmentation_grid_search_and_inference`, using the
    default grid-search values for this segmenter.

    Output layout inside `experiment_folder`:
        - `embeddings/`: image embeddings.
        - `instance_segmentation_with_decoder/inference/`: predictions.
        - `instance_segmentation_with_decoder/grid_search/`: grid-search results.

    Args:
        checkpoint: The path to the model checkpoint.
        model_type: The Segment Anything model type.
        experiment_folder: The root folder for all outputs of this experiment.
        val_image_paths: The validation image file paths used for the grid search.
        val_gt_paths: The validation ground-truth file paths used for the grid search.
        test_image_paths: The test image file paths to run inference on.
        peft_kwargs: Optional keyword arguments forwarded to `get_predictor_and_decoder`.

    Returns:
        The folder where the predictions are saved.
    """
    # Precomputed embeddings live directly under the experiment folder.
    embedding_dir = os.path.join(experiment_folder, "embeddings")
    os.makedirs(embedding_dir, exist_ok=True)

    sam_predictor, instance_decoder = get_predictor_and_decoder(
        model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs,
    )
    decoder_segmenter = InstanceSegmentationWithDecoder(sam_predictor, instance_decoder)

    # Predictions and grid-search results share one method-specific root folder.
    method_root = os.path.join(experiment_folder, "instance_segmentation_with_decoder")
    inference_dir = os.path.join(method_root, "inference")
    grid_search_dir = os.path.join(method_root, "grid_search")
    for output_dir in (inference_dir, grid_search_dir):
        os.makedirs(output_dir, exist_ok=True)

    search_space = instance_segmentation.default_grid_search_values_instance_segmentation_with_decoder()

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        decoder_segmenter,
        search_space,
        val_image_paths,
        val_gt_paths,
        test_image_paths,
        embedding_dir=embedding_dir,
        prediction_dir=inference_dir,
        result_dir=grid_search_dir,
    )
    return inference_dir