
Inference with Segment Anything models and different prompt strategies.

  1"""Inference with Segment Anything models and different prompt strategies.
  4import os
  5import pickle
  6import numpy as np
  7from tqdm import tqdm
  8from copy import deepcopy
  9from typing import Any, Dict, List, Optional, Union, Tuple
 11import imageio.v3 as imageio
 12from skimage.segmentation import relabel_sequential
 14import torch
 16from segment_anything import SamPredictor
 18from .. import util as util
 19from ..inference import batched_inference
 20from ..instance_segmentation import (
 21    mask_data_to_segmentation, get_predictor_and_decoder,
 22    AutomaticMaskGenerator, InstanceSegmentationWithDecoder,
 23    TiledAutomaticMaskGenerator, TiledInstanceSegmentationWithDecoder,
 25from . import instance_segmentation
 26from ..prompt_generators import PointAndBoxPromptGenerator, IterativePromptGenerator
 29def _load_prompts(
 30    cached_point_prompts, save_point_prompts, cached_box_prompts, save_box_prompts, image_name
 33    def load_prompt_type(cached_prompts, save_prompts):
 34        # Check if we have saved prompts.
 35        if cached_prompts is None or save_prompts:  # we don't have cached prompts
 36            return cached_prompts, None
 38        # we have cached prompts, but they have not been loaded yet
 39        if isinstance(cached_prompts, str):
 40            with open(cached_prompts, "rb") as f:
 41                cached_prompts = pickle.load(f)
 43        prompts = cached_prompts[image_name]
 44        return cached_prompts, prompts
 46    cached_point_prompts, point_prompts = load_prompt_type(cached_point_prompts, save_point_prompts)
 47    cached_box_prompts, box_prompts = load_prompt_type(cached_box_prompts, save_box_prompts)
 49    # we don't have anything cached
 50    if point_prompts is None and box_prompts is None:
 51        return None, cached_point_prompts, cached_box_prompts
 53    if point_prompts is None:
 54        input_point, input_label = [], []
 55    else:
 56        input_point, input_label = point_prompts
 58    if box_prompts is None:
 59        input_box = []
 60    else:
 61        input_box = box_prompts
 63    prompts = (input_point, input_label, input_box)
 64    return prompts, cached_point_prompts, cached_box_prompts
 67def _get_batched_prompts(gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation):
 69    # Initialize the prompt generator.
 70    prompt_generator = PointAndBoxPromptGenerator(
 71        n_positive_points=n_positives, n_negative_points=n_negatives,
 72        dilation_strength=dilation, get_point_prompts=use_points,
 73        get_box_prompts=use_boxes
 74    )
 76    # Generate the prompts.
 77    center_coordinates, bbox_coordinates = util.get_centers_and_bounding_boxes(gt)
 78    center_coordinates = [center_coordinates[gt_id] for gt_id in gt_ids]
 79    bbox_coordinates = [bbox_coordinates[gt_id] for gt_id in gt_ids]
 80    masks = util.segmentation_to_one_hot(gt.astype("int64"), gt_ids)
 82    points, point_labels, boxes, _ = prompt_generator(
 83        masks, bbox_coordinates, center_coordinates
 84    )
 86    def to_numpy(x):
 87        if x is None:
 88            return x
 89        return x.numpy()
 91    return to_numpy(points), to_numpy(point_labels), to_numpy(boxes)
 94def _run_inference_with_prompts_for_image(
 95    predictor,
 96    image,
 97    gt,
 98    use_points,
 99    use_boxes,
100    n_positives,
101    n_negatives,
102    dilation,
103    batch_size,
104    cached_prompts,
105    embedding_path,
107    gt_ids = np.unique(gt)[1:]
108    if cached_prompts is None:
109        points, point_labels, boxes = _get_batched_prompts(
110            gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation
111        )
112    else:
113        points, point_labels, boxes = cached_prompts
115    # Make a copy of the point prompts to return them at the end.
116    prompts = deepcopy((points, point_labels, boxes))
118    # Use multi-masking only if we have a single positive point without box
119    multimasking = False
120    if not use_boxes and (n_positives == 1 and n_negatives == 0):
121        multimasking = True
123    instance_labels = batched_inference(
124        predictor, image, batch_size,
125        boxes=boxes, points=points, point_labels=point_labels,
126        multimasking=multimasking, embedding_path=embedding_path,
127        return_instance_segmentation=True,
128    )
130    return instance_labels, prompts
def precompute_all_embeddings(
    predictor: SamPredictor, image_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike],
) -> None:
    """Precompute all image embeddings.

    To enable running different inference tasks in parallel afterwards.

    The embeddings of each image are stored in `embedding_dir` under the image's
    file name with the extension replaced by ".zarr".

    Args:
        predictor: The SegmentAnything predictor.
        image_paths: The image file paths.
        embedding_dir: The directory where the embeddings will be saved.
    """
    # Make sure that the output directory exists before the embeddings are written.
    os.makedirs(embedding_dir, exist_ok=True)

    for image_path in tqdm(image_paths, desc="Precompute embeddings"):
        image_name = os.path.basename(image_path)
        im = imageio.imread(image_path)
        embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
        util.precompute_image_embeddings(predictor, im, embedding_path, ndim=2)
def _precompute_prompts(gt_path, use_points, use_boxes, n_positives, n_negatives, dilation):
    """Compute the prompts for a single ground-truth file.

    Args:
        gt_path: The file path of the ground-truth segmentation.
        use_points: Whether to generate point prompts.
        use_boxes: Whether to generate box prompts.
        n_positives: The number of positive points per object.
        n_negatives: The number of negative points per object.
        dilation: The dilation strength used for sampling the prompts.

    Returns:
        The file name of the ground-truth and the prompts. The prompts are the boxes if only
        box prompts are requested, otherwise the tuple ``(points, point_labels)``.
    """
    name = os.path.basename(gt_path)

    # Load the ground-truth and relabel it sequentially.
    gt = imageio.imread(gt_path).astype("uint32")
    gt = relabel_sequential(gt)[0]
    # The first id returned by `np.unique` is skipped (presumably the background label).
    gt_ids = np.unique(gt)[1:]

    input_point, input_label, input_box = _get_batched_prompts(
        gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation
    )

    # Boxes are only returned for the pure box setting, in all other cases the points are returned.
    if use_boxes and not use_points:
        return name, input_box
    return name, (input_point, input_label)
def precompute_all_prompts(
    gt_paths: List[Union[str, os.PathLike]],
    prompt_save_dir: Union[str, os.PathLike],
    prompt_settings: List[Dict[str, Any]],
) -> None:
    """Precompute all point prompts.

    To enable running different inference tasks in parallel afterwards.

    For each setting one pickle file is written that maps the ground-truth file names
    to the prompts. The file is called "boxes.pkl" if only box prompts are used and
    "points-p{n_positives}-n{n_negatives}.pkl" otherwise. Settings for which the file
    already exists are skipped.

    Args:
        gt_paths: The file paths to the ground-truth segmentations.
        prompt_save_dir: The directory where the prompt files will be saved.
        prompt_settings: The settings for which the prompts will be computed.
            Each setting needs the keys "use_points", "use_boxes", "n_positives" and "n_negatives"
            and may contain "dilation" (default: 5).
    """
    os.makedirs(prompt_save_dir, exist_ok=True)

    for settings in tqdm(prompt_settings, desc="Precompute prompts"):

        use_points, use_boxes = settings["use_points"], settings["use_boxes"]
        n_positives, n_negatives = settings["n_positives"], settings["n_negatives"]
        dilation = settings.get("dilation", 5)

        # Check if the prompts were already computed.
        if use_boxes and not use_points:
            prompt_save_path = os.path.join(prompt_save_dir, "boxes.pkl")
        else:
            prompt_save_path = os.path.join(prompt_save_dir, f"points-p{n_positives}-n{n_negatives}.pkl")
        if os.path.exists(prompt_save_path):
            continue

        # Each result is a (file name, prompts) pair, so the results directly define the mapping.
        saved_prompts = dict(
            _precompute_prompts(
                gt_path,
                use_points=use_points,
                use_boxes=use_boxes,
                n_positives=n_positives,
                n_negatives=n_negatives,
                dilation=dilation,
            )
            for gt_path in tqdm(gt_paths, desc=f"Precompute prompts for p{n_positives}-n{n_negatives}")
        )

        with open(prompt_save_path, "wb") as f:
            pickle.dump(saved_prompts, f)
def _get_prompt_caching(prompt_save_dir, use_points, use_boxes, n_positives, n_negatives):
    """Determine for both prompt types whether cached prompts are loaded or new prompts are saved.

    Args:
        prompt_save_dir: The directory with the prompt files, or ``None`` to disable caching.
        use_points: Whether point prompts are used.
        use_boxes: Whether box prompts are used.
        n_positives: The number of positive points, used in the name of the point prompt file.
        n_negatives: The number of negative points, used in the name of the point prompt file.

    Returns:
        The tuple ``(cached_point_prompts, save_point_prompts, point_prompt_save_path,
        cached_box_prompts, save_box_prompts, box_prompt_save_path)``.
        For a prompt type that is not used or not cached this is ``(None, False, None)``.
        If the prompt file exists, ``cached_*_prompts`` is the path to this file (to be loaded later)
        and ``save_*_prompts`` is false. Otherwise ``cached_*_prompts`` is an empty dict that
        will be filled and ``save_*_prompts`` is true.
    """

    def get_prompt_type_caching(use_type, save_name):
        if not use_type:
            return None, False, None

        prompt_save_path = os.path.join(prompt_save_dir, save_name)
        if os.path.exists(prompt_save_path):
            print("Using precomputed prompts from", prompt_save_path)
            # We delay loading the prompts, so we only have to load them once they're needed the first time.
            # This avoids loading the prompts (which are in a big pickle file) if all predictions are done already.
            return prompt_save_path, False, prompt_save_path

        print("Saving prompts in", prompt_save_path)
        return {}, True, prompt_save_path

    # Check if prompt serialization is enabled.
    # If it is then load the prompts if they are already cached and otherwise store them.
    if prompt_save_dir is None:
        print("Prompts are not cached.")
        cached_point_prompts, cached_box_prompts = None, None
        save_point_prompts, save_box_prompts = False, False
        point_prompt_save_path, box_prompt_save_path = None, None
    else:
        cached_point_prompts, save_point_prompts, point_prompt_save_path = get_prompt_type_caching(
            use_points, f"points-p{n_positives}-n{n_negatives}.pkl"
        )
        cached_box_prompts, save_box_prompts, box_prompt_save_path = get_prompt_type_caching(
            use_boxes, "boxes.pkl"
        )

    return (cached_point_prompts, save_point_prompts, point_prompt_save_path,
            cached_box_prompts, save_box_prompts, box_prompt_save_path)
def run_inference_with_prompts(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    use_points: bool,
    use_boxes: bool,
    n_positives: int,
    n_negatives: int,
    dilation: int = 5,
    prompt_save_dir: Optional[Union[str, os.PathLike]] = None,
    batch_size: int = 512,
) -> None:
    """Run segment anything inference for multiple images using prompts derived from groundtruth.

    Images for which a prediction already exists in `prediction_dir` are skipped.

    Args:
        predictor: The SegmentAnything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
            If None, no embedding path is passed on to the inference.
        prediction_dir: The directory where the predictions will be saved.
        use_points: Whether to use point prompts.
        use_boxes: Whether to use box prompts.
        n_positives: The number of positive point prompts that will be sampled.
        n_negatives: The number of negative point prompts that will be sampled.
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        prompt_save_dir: The directory where point prompts will be saved or are already saved.
            This enables running multiple experiments in a reproducible manner.
        batch_size: The batch size used for batched prediction.

    Raises:
        ValueError: If neither point nor box prompts are used or if the number of
            images and ground-truth files does not match.
    """
    if not (use_points or use_boxes):
        raise ValueError("You need to use at least one of point or box prompts.")

    if len(image_paths) != len(gt_paths):
        raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")

    (cached_point_prompts, save_point_prompts, point_prompt_save_path,
     cached_box_prompts, save_box_prompts, box_prompt_save_path) = _get_prompt_caching(
         prompt_save_dir, use_points, use_boxes, n_positives, n_negatives
     )

    os.makedirs(prediction_dir, exist_ok=True)
    for image_path, gt_path in tqdm(
        zip(image_paths, gt_paths), total=len(image_paths), desc="Run inference with prompts"
    ):
        image_name = os.path.basename(image_path)
        label_name = os.path.basename(gt_path)

        # We skip the images that already have been segmented.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        im = imageio.imread(image_path)
        gt = imageio.imread(gt_path).astype("uint32")
        gt = relabel_sequential(gt)[0]

        # The embedding directory is optional, same as for the inference with iterative prompting.
        if embedding_dir is None:
            embedding_path = None
        else:
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        # The prompts are cached under the name of the ground-truth file.
        this_prompts, cached_point_prompts, cached_box_prompts = _load_prompts(
            cached_point_prompts, save_point_prompts,
            cached_box_prompts, save_box_prompts,
            label_name
        )
        instances, this_prompts = _run_inference_with_prompts_for_image(
            predictor, im, gt, n_positives=n_positives, n_negatives=n_negatives,
            dilation=dilation, use_points=use_points, use_boxes=use_boxes,
            batch_size=batch_size, cached_prompts=this_prompts,
            embedding_path=embedding_path,
        )

        # `this_prompts` is the tuple (points, point_labels, boxes).
        if save_point_prompts:
            cached_point_prompts[label_name] = this_prompts[:2]
        if save_box_prompts:
            cached_box_prompts[label_name] = this_prompts[-1]

        # It's important to compress here, otherwise the predictions would take up a lot of space.
        imageio.imwrite(prediction_path, instances, compression=5)

    # Save the prompts if we run experiments with prompt caching and have computed them
    # for the first time.
    if save_point_prompts:
        with open(point_prompt_save_path, "wb") as f:
            pickle.dump(cached_point_prompts, f)
    if save_box_prompts:
        with open(box_prompt_save_path, "wb") as f:
            pickle.dump(cached_box_prompts, f)
def _save_segmentation(masks, prediction_path):
    """Convert the predicted masks into a single segmentation and save it.

    Args:
        masks: The predicted masks as a tensor. Axis 1 is squeezed, so it must have length 1.
        prediction_path: The file path where the segmentation is written (with compression).
    """
    # Bring the masks into the list-of-dict format expected by `mask_data_to_segmentation`.
    masks = masks.cpu().numpy().squeeze(1).astype("bool")
    masks = [{"segmentation": mask, "area": mask.sum()} for mask in masks]
    # masks to segmentation
    segmentation = mask_data_to_segmentation(masks, with_background=True)
    imageio.imwrite(prediction_path, segmentation, compression=5)
def _get_batched_iterative_prompts(sampled_binary_gt, masks, batch_size, prompt_generator):
    """Generate the next prompts from the ground-truth and the current predictions, batch by batch.

    Args:
        sampled_binary_gt: The per-object ground-truth masks. The first axis enumerates the objects.
        masks: The current predictions, aligned with `sampled_binary_gt` along the first axis.
        batch_size: The number of objects passed to the prompt generator at once.
        prompt_generator: A callable that is called with a batch of ground-truth masks and predictions
            and returns four values; the first two (coordinates and labels, as tensors) are used.

    Returns:
        The coordinates and the labels of all batches, each concatenated along the first axis.
    """
    n_samples = sampled_binary_gt.shape[0]

    next_coords, next_labels = [], []
    for batch_start in range(0, n_samples, batch_size):
        batch_stop = min(batch_start + batch_size, n_samples)

        batch_coords, batch_labels, _, _ = prompt_generator(
            sampled_binary_gt[batch_start: batch_stop], masks[batch_start: batch_stop]
        )
        next_coords.append(batch_coords)
        next_labels.append(batch_labels)

    # `torch.cat` is used instead of its alias `torch.concatenate`, which older PyTorch versions lack.
    next_coords = torch.cat(next_coords)
    next_labels = torch.cat(next_labels)

    return next_coords, next_labels
@torch.no_grad()
def _run_inference_with_iterative_prompting_for_image(
    predictor,
    image,
    gt,
    start_with_box_prompt,
    dilation,
    batch_size,
    embedding_path,
    n_iterations,
    prediction_paths,
    use_masks=False
) -> None:
    """Run inference with iterative prompting for one image and save the result of each iteration.

    The first iteration uses either a box or a single positive point per object, derived from the
    ground-truth. After each iteration the prompt generator is called with the ground-truth and the
    current prediction, and the points it returns are appended to the point prompts.

    The function runs without gradient tracking, since it only performs inference.

    Args:
        predictor: The predictor that is passed on to `batched_inference`.
        image: The input image.
        gt: The ground-truth label image.
        start_with_box_prompt: Whether to start with a box prompt instead of a single point.
        dilation: The dilation strength used for deriving the initial prompts.
        batch_size: The batch size for inference and for prompt generation.
        embedding_path: The embedding path passed on to `batched_inference`.
        n_iterations: The number of prompting iterations.
        prediction_paths: The file paths for saving the segmentations, one per iteration.
        use_masks: Whether to pass the logits of the previous iteration to the next iteration.
    """
    verbose_embeddings = False

    prompt_generator = IterativePromptGenerator()

    # The first id returned by `np.unique` is skipped (presumably the background label).
    gt_ids = np.unique(gt)[1:]

    # Use multi-masking only if we have a single positive point without box
    if start_with_box_prompt:
        use_boxes, use_points = True, False
        n_positives = 0
        multimasking = False
    else:
        use_boxes, use_points = False, True
        n_positives = 1
        multimasking = True

    points, point_labels, boxes = _get_batched_prompts(
        gt, gt_ids,
        use_points=use_points,
        use_boxes=use_boxes,
        n_positives=n_positives,
        n_negatives=0,
        dilation=dilation
    )

    sampled_binary_gt = util.segmentation_to_one_hot(gt.astype("int64"), gt_ids)

    # There are no logits for the first iteration. Afterwards they are only updated if `use_masks` is set.
    logits_masks = None

    for iteration in range(n_iterations):
        batched_outputs = batched_inference(
            predictor=predictor,
            image=image,
            batch_size=batch_size,
            boxes=boxes,
            points=points,
            point_labels=point_labels,
            multimasking=multimasking,
            embedding_path=embedding_path,
            return_instance_segmentation=False,
            logits_masks=logits_masks,
            verbose_embeddings=verbose_embeddings,
        )

        # switching off multimasking after first iter, as next iters (with multiple prompts) don't expect multimasking
        multimasking = False

        masks = torch.stack([m["segmentation"][None] for m in batched_outputs]).to(torch.float32)

        next_coords, next_labels = _get_batched_iterative_prompts(
            sampled_binary_gt, masks, batch_size, prompt_generator
        )
        next_coords, next_labels = next_coords.detach().cpu().numpy(), next_labels.detach().cpu().numpy()

        # Append the new points to the existing ones. There are none yet if we started from boxes.
        if points is not None:
            points = np.concatenate([points, next_coords], axis=1)
        else:
            points = next_coords

        if point_labels is not None:
            point_labels = np.concatenate([point_labels, next_labels], axis=1)
        else:
            point_labels = next_labels

        if use_masks:
            logits_masks = torch.stack([m["logits"] for m in batched_outputs])

        _save_segmentation(masks, prediction_paths[iteration])
def run_inference_with_iterative_prompting(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    start_with_box_prompt: bool = True,
    dilation: int = 5,
    batch_size: int = 32,
    n_iterations: int = 8,
    use_masks: bool = False
) -> None:
    """Run Segment Anything inference for multiple images using prompts iteratively
    derived from model outputs and ground-truth.

    The predictions of iteration ``i`` are saved in the sub-folder "iteration{i:02}" of
    `prediction_dir`. Images for which the predictions of all iterations exist are skipped.

    Args:
        predictor: The Segment Anything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
            If None, no embedding path is passed on to the inference.
        prediction_dir: The directory where the predictions from Segment Anything will be saved per iteration.
        start_with_box_prompt: Whether to use the first prompt as bounding box or a single point.
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        batch_size: The batch size used for batched predictions.
        n_iterations: The number of iterations for iterative prompting.
        use_masks: Whether to make use of logits from previous prompt-based segmentation.

    Raises:
        ValueError: If the number of images and ground-truth files does not match.
    """
    if len(image_paths) != len(gt_paths):
        raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")

    # create all prediction folders for all intermediate iterations
    for i in range(n_iterations):
        os.makedirs(os.path.join(prediction_dir, f"iteration{i:02}"), exist_ok=True)

    if use_masks:
        print("The iterative prompting will make use of logits masks from previous iterations.")

    for image_path, gt_path in tqdm(
        zip(image_paths, gt_paths), total=len(image_paths),
        desc="Run inference with iterative prompting for all images",
    ):
        image_name = os.path.basename(image_path)

        # We skip the images that already have been segmented
        prediction_paths = [os.path.join(prediction_dir, f"iteration{i:02}", image_name) for i in range(n_iterations)]
        if all(os.path.exists(prediction_path) for prediction_path in prediction_paths):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        image = imageio.imread(image_path)
        gt = imageio.imread(gt_path).astype("uint32")
        gt = relabel_sequential(gt)[0]

        if embedding_dir is None:
            embedding_path = None
        else:
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        _run_inference_with_iterative_prompting_for_image(
            predictor, image, gt, start_with_box_prompt=start_with_box_prompt,
            dilation=dilation, batch_size=batch_size, embedding_path=embedding_path,
            n_iterations=n_iterations, prediction_paths=prediction_paths, use_masks=use_masks
        )
def run_amg(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    iou_thresh_values: Optional[List[float]] = None,
    stability_score_values: Optional[List[float]] = None,
    peft_kwargs: Optional[Dict] = None,
    cache_embeddings: bool = False,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> str:
    """Run Segment Anything inference for multiple images using automatic mask generation (AMG).

    The predictions are saved in "<experiment_folder>/amg/inference" and the grid-search results
    in "<experiment_folder>/amg/grid_search".

    Args:
        checkpoint: The filepath to model checkpoints.
        model_type: The segment anything model choice.
        experiment_folder: The directory where the relevant files are saved.
        val_image_paths: The list of filepaths of input images for grid-search.
        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
        iou_thresh_values: Optional choice of values for grid search of `iou_thresh` parameter.
        stability_score_values: Optional choice of values for grid search of `stability_score` parameter.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        cache_embeddings: Whether to cache embeddings in experiment folder.
        tiling_window_params: The parameters to decide whether to use tiling window operation for AMG.
            If given, it must be a dictionary that contains the keys 'tile_shape' and 'halo'.

    Returns:
        Filepath where the predictions have been saved.

    Raises:
        RuntimeError: If `tiling_window_params` is not a dictionary or misses a required key.
    """

    if cache_embeddings:
        embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
        os.makedirs(embedding_folder, exist_ok=True)
    else:
        embedding_folder = None

    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs)

    # Get the AMG class.
    if tiling_window_params:
        if not isinstance(tiling_window_params, dict):
            raise RuntimeError("The tiling window parameters are expected to be provided as a dictionary of params.")

        if "tile_shape" not in tiling_window_params:
            raise RuntimeError("'tile_shape' parameter is missing from the provided parameters.")

        if "halo" not in tiling_window_params:
            raise RuntimeError("'halo' parameter is missing from the provided parameters.")

        amg_class = TiledAutomaticMaskGenerator
    else:
        amg_class = AutomaticMaskGenerator

    amg = amg_class(predictor)
    amg_prefix = "amg"

    # where the predictions are saved
    prediction_folder = os.path.join(experiment_folder, amg_prefix, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # where the grid-search results are saved
    gs_result_folder = os.path.join(experiment_folder, amg_prefix, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    grid_search_values = instance_segmentation.default_grid_search_values_amg(
        iou_thresh_values=iou_thresh_values,
        stability_score_values=stability_score_values,
    )

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter=amg,
        grid_search_values=grid_search_values,
        val_image_paths=val_image_paths,
        val_gt_paths=val_gt_paths,
        test_image_paths=test_image_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
        experiment_folder=experiment_folder,
        tiling_window_params=tiling_window_params,
    )
    return prediction_folder
def run_instance_segmentation_with_decoder(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    peft_kwargs: Optional[Dict] = None,
    cache_embeddings: bool = False,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> str:
    """Run Segment Anything inference for multiple images using additional automatic instance segmentation (AIS).

    The predictions are saved in "<experiment_folder>/instance_segmentation_with_decoder/inference"
    and the grid-search results in "<experiment_folder>/instance_segmentation_with_decoder/grid_search".

    Args:
        checkpoint: The filepath to model checkpoints.
        model_type: The segment anything model choice.
        experiment_folder: The directory where the relevant files are saved.
        val_image_paths: The list of filepaths of input images for grid-search.
        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        cache_embeddings: Whether to cache embeddings in experiment folder.
        tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
            If given, it must be a dictionary that contains the keys 'tile_shape' and 'halo'.

    Returns:
        Filepath where the predictions have been saved.

    Raises:
        RuntimeError: If `tiling_window_params` is not a dictionary or misses a required key.
    """

    if cache_embeddings:
        embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
        os.makedirs(embedding_folder, exist_ok=True)
    else:
        embedding_folder = None

    predictor, decoder = get_predictor_and_decoder(
        model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs,
    )

    # Get the AIS class.
    if tiling_window_params:
        if not isinstance(tiling_window_params, dict):
            raise RuntimeError("The tiling window parameters are expected to be provided as a dictionary of params.")

        if "tile_shape" not in tiling_window_params:
            raise RuntimeError("'tile_shape' parameter is missing from the provided parameters.")

        if "halo" not in tiling_window_params:
            raise RuntimeError("'halo' parameter is missing from the provided parameters.")

        ais_class = TiledInstanceSegmentationWithDecoder
    else:
        ais_class = InstanceSegmentationWithDecoder

    segmenter = ais_class(predictor, decoder)
    seg_prefix = "instance_segmentation_with_decoder"

    # where the predictions are saved
    prediction_folder = os.path.join(experiment_folder, seg_prefix, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # where the grid-search results are saved
    gs_result_folder = os.path.join(experiment_folder, seg_prefix, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    grid_search_values = instance_segmentation.default_grid_search_values_instance_segmentation_with_decoder()

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter=segmenter,
        grid_search_values=grid_search_values,
        val_image_paths=val_image_paths,
        val_gt_paths=val_gt_paths,
        test_image_paths=test_image_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
        experiment_folder=experiment_folder,
        tiling_window_params=tiling_window_params,
    )
    return prediction_folder
def precompute_all_embeddings( predictor: segment_anything.predictor.SamPredictor, image_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike]) -> None:
def precompute_all_embeddings(
    predictor: SamPredictor, image_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike],
) -> None:
    """Precompute all image embeddings.

    To enable running different inference tasks in parallel afterwards.

    The embeddings of each image are stored in `embedding_dir` under the image's
    file name with the extension replaced by ".zarr".

    Args:
        predictor: The SegmentAnything predictor.
        image_paths: The image file paths.
        embedding_dir: The directory where the embeddings will be saved.
    """
    # Make sure that the output directory exists before the embeddings are written.
    os.makedirs(embedding_dir, exist_ok=True)

    for image_path in tqdm(image_paths, desc="Precompute embeddings"):
        image_name = os.path.basename(image_path)
        im = imageio.imread(image_path)
        embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
        util.precompute_image_embeddings(predictor, im, embedding_path, ndim=2)

Precompute all image embeddings.

To enable running different inference tasks in parallel afterwards.

  • predictor: The SegmentAnything predictor.
  • image_paths: The image file paths.
  • embedding_dir: The directory where the embeddings will be saved.
def precompute_all_prompts( gt_paths: List[Union[str, os.PathLike]], prompt_save_dir: Union[str, os.PathLike], prompt_settings: List[Dict[str, Any]]) -> None:
def precompute_all_prompts(
    gt_paths: List[Union[str, os.PathLike]],
    prompt_save_dir: Union[str, os.PathLike],
    prompt_settings: List[Dict[str, Any]],
) -> None:
    """Precompute all point prompts.

    To enable running different inference tasks in parallel afterwards.

    For each setting one pickle file is written that maps the ground-truth file names
    to the prompts. The file is called "boxes.pkl" if only box prompts are used and
    "points-p{n_positives}-n{n_negatives}.pkl" otherwise. Settings for which the file
    already exists are skipped.

    Args:
        gt_paths: The file paths to the ground-truth segmentations.
        prompt_save_dir: The directory where the prompt files will be saved.
        prompt_settings: The settings for which the prompts will be computed.
            Each setting needs the keys "use_points", "use_boxes", "n_positives" and "n_negatives"
            and may contain "dilation" (default: 5).
    """
    os.makedirs(prompt_save_dir, exist_ok=True)

    for settings in tqdm(prompt_settings, desc="Precompute prompts"):

        use_points, use_boxes = settings["use_points"], settings["use_boxes"]
        n_positives, n_negatives = settings["n_positives"], settings["n_negatives"]
        dilation = settings.get("dilation", 5)

        # Check if the prompts were already computed.
        if use_boxes and not use_points:
            prompt_save_path = os.path.join(prompt_save_dir, "boxes.pkl")
        else:
            prompt_save_path = os.path.join(prompt_save_dir, f"points-p{n_positives}-n{n_negatives}.pkl")
        if os.path.exists(prompt_save_path):
            continue

        # Each result is a (file name, prompts) pair, so the results directly define the mapping.
        saved_prompts = dict(
            _precompute_prompts(
                gt_path,
                use_points=use_points,
                use_boxes=use_boxes,
                n_positives=n_positives,
                n_negatives=n_negatives,
                dilation=dilation,
            )
            for gt_path in tqdm(gt_paths, desc=f"Precompute prompts for p{n_positives}-n{n_negatives}")
        )

        with open(prompt_save_path, "wb") as f:
            pickle.dump(saved_prompts, f)

Precompute all point prompts.

To enable running different inference tasks in parallel afterwards.

  • gt_paths: The file paths to the ground-truth segmentations.
  • prompt_save_dir: The directory where the prompt files will be saved.
  • prompt_settings: The settings for which the prompts will be computed.
def run_inference_with_prompts( predictor: segment_anything.predictor.SamPredictor, image_paths: List[Union[str, os.PathLike]], gt_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike], prediction_dir: Union[str, os.PathLike], use_points: bool, use_boxes: bool, n_positives: int, n_negatives: int, dilation: int = 5, prompt_save_dir: Union[str, os.PathLike, NoneType] = None, batch_size: int = 512) -> None:
def run_inference_with_prompts(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    use_points: bool,
    use_boxes: bool,
    n_positives: int,
    n_negatives: int,
    dilation: int = 5,
    prompt_save_dir: Optional[Union[str, os.PathLike]] = None,
    batch_size: int = 512,
) -> None:
    """Run Segment Anything inference for multiple images using prompts derived from ground-truth.

    One prediction is written to `prediction_dir` per image, using the file name of the image.
    Images for which a prediction file already exists are skipped.

    Args:
        predictor: The Segment Anything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
            If None, no embedding path is passed on for the image.
        prediction_dir: The directory where the predicted segmentations will be saved.
        use_points: Whether to use point prompts.
        use_boxes: Whether to use box prompts.
        n_positives: The number of positive point prompts that will be sampled.
        n_negatives: The number of negative point prompts that will be sampled.
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        prompt_save_dir: The directory where point prompts will be saved or are already saved.
            This enables running multiple experiments in a reproducible manner.
        batch_size: The batch size used for batched prediction.

    Raises:
        ValueError: If neither point nor box prompts are enabled, or if the number of
            images and ground-truth files differs.
    """
    if not (use_points or use_boxes):
        raise ValueError("You need to use at least one of point or box prompts.")

    if len(image_paths) != len(gt_paths):
        raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")

    (cached_point_prompts, save_point_prompts, point_prompt_save_path,
     cached_box_prompts, save_box_prompts, box_prompt_save_path) = _get_prompt_caching(
         prompt_save_dir, use_points, use_boxes, n_positives, n_negatives
     )

    os.makedirs(prediction_dir, exist_ok=True)
    for image_path, gt_path in tqdm(
        zip(image_paths, gt_paths), total=len(image_paths), desc="Run inference with prompts"
    ):
        image_name = os.path.basename(image_path)
        # The prompt caches are keyed by the file name of the ground-truth, not of the image.
        label_name = os.path.basename(gt_path)

        # We skip the images that already have been segmented.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        im = imageio.imread(image_path)
        gt = imageio.imread(gt_path).astype("uint32")
        gt = relabel_sequential(gt)[0]

        # Mirror `run_inference_with_iterative_prompting`: without an embedding directory
        # we pass no embedding path instead of failing in `os.path.join`.
        # NOTE(review): assumes the per-image helper accepts `embedding_path=None` - confirm.
        if embedding_dir is None:
            embedding_path = None
        else:
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        this_prompts, cached_point_prompts, cached_box_prompts = _load_prompts(
            cached_point_prompts, save_point_prompts,
            cached_box_prompts, save_box_prompts,
            label_name
        )
        instances, this_prompts = _run_inference_with_prompts_for_image(
            predictor, im, gt, n_positives=n_positives, n_negatives=n_negatives,
            dilation=dilation, use_points=use_points, use_boxes=use_boxes,
            batch_size=batch_size, cached_prompts=this_prompts,
            embedding_path=embedding_path,
        )

        # The first two entries of the returned prompts go to the point cache,
        # the last entry goes to the box cache.
        if save_point_prompts:
            cached_point_prompts[label_name] = this_prompts[:2]
        if save_box_prompts:
            cached_box_prompts[label_name] = this_prompts[-1]

        # It's important to compress here, otherwise the predictions would take up a lot of space.
        imageio.imwrite(prediction_path, instances, compression=5)

    # Save the prompts if we run experiments with prompt caching and have computed them
    # for the first time.
    if save_point_prompts:
        with open(point_prompt_save_path, "wb") as f:
            pickle.dump(cached_point_prompts, f)
    if save_box_prompts:
        with open(box_prompt_save_path, "wb") as f:
            pickle.dump(cached_box_prompts, f)

Run segment anything inference for multiple images using prompts derived from groundtruth.

  • predictor: The SegmentAnything predictor.
  • image_paths: The image file paths.
  • gt_paths: The ground-truth segmentation file paths.
  • embedding_dir: The directory where the image embeddings will be saved or are already saved.
  • use_points: Whether to use point prompts.
  • use_boxes: Whether to use box prompts
  • n_positives: The number of positive point prompts that will be sampled.
  • n_negatives: The number of negative point prompts that will be sampled.
  • dilation: The dilation factor for the radius around the ground-truth object around which points will not be sampled.
  • prompt_save_dir: The directory where point prompts will be saved or are already saved. This enables running multiple experiments in a reproducible manner.
  • batch_size: The batch size used for batched prediction.
def run_inference_with_iterative_prompting( predictor: segment_anything.predictor.SamPredictor, image_paths: List[Union[str, os.PathLike]], gt_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike], prediction_dir: Union[str, os.PathLike], start_with_box_prompt: bool = True, dilation: int = 5, batch_size: int = 32, n_iterations: int = 8, use_masks: bool = False) -> None:
def run_inference_with_iterative_prompting(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    start_with_box_prompt: bool = True,
    dilation: int = 5,
    batch_size: int = 32,
    n_iterations: int = 8,
    use_masks: bool = False
) -> None:
    """Run Segment Anything inference for multiple images using prompts iteratively
    derived from model outputs and ground-truth.

    For each iteration `i` a sub-folder `iteration{i:02}` is created in `prediction_dir`.
    An image is skipped only if its prediction file exists in all iteration folders.

    Args:
        predictor: The Segment Anything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
            If None, no embedding path is passed on for the image.
        prediction_dir: The directory where the predictions from Segment Anything will be saved per iteration.
        start_with_box_prompt: Whether to use the first prompt as bounding box or a single point
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        batch_size: The batch size used for batched predictions.
        n_iterations: The number of iterations for iterative prompting.
        use_masks: Whether to make use of logits from previous prompt-based segmentation.

    Raises:
        ValueError: If the number of images and ground-truth files differs.
    """
    if len(image_paths) != len(gt_paths):
        raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")

    # Create all prediction folders for all intermediate iterations. The folder paths do not
    # depend on the image, so they are computed once and reused inside the loop below.
    iteration_dirs = [os.path.join(prediction_dir, f"iteration{i:02}") for i in range(n_iterations)]
    for iteration_dir in iteration_dirs:
        os.makedirs(iteration_dir, exist_ok=True)

    if use_masks:
        print("The iterative prompting will make use of logits masks from previous iterations.")

    for image_path, gt_path in tqdm(
        zip(image_paths, gt_paths), total=len(image_paths),
        desc="Run inference with iterative prompting for all images",
    ):
        image_name = os.path.basename(image_path)

        # We skip the images that already have been segmented in every iteration.
        prediction_paths = [os.path.join(iteration_dir, image_name) for iteration_dir in iteration_dirs]
        if all(os.path.exists(prediction_path) for prediction_path in prediction_paths):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        image = imageio.imread(image_path)
        gt = imageio.imread(gt_path).astype("uint32")
        gt = relabel_sequential(gt)[0]

        if embedding_dir is None:
            embedding_path = None
        else:
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        _run_inference_with_iterative_prompting_for_image(
            predictor, image, gt, start_with_box_prompt=start_with_box_prompt,
            dilation=dilation, batch_size=batch_size, embedding_path=embedding_path,
            n_iterations=n_iterations, prediction_paths=prediction_paths, use_masks=use_masks
        )

Run Segment Anything inference for multiple images using prompts iteratively derived from model outputs and ground-truth.

  • predictor: The Segment Anything predictor.
  • image_paths: The image file paths.
  • gt_paths: The ground-truth segmentation file paths.
  • embedding_dir: The directory where the image embeddings will be saved or are already saved.
  • prediction_dir: The directory where the predictions from Segment Anything will be saved per iteration.
  • start_with_box_prompt: Whether to use the first prompt as bounding box or a single point
  • dilation: The dilation factor for the radius around the ground-truth object around which points will not be sampled.
  • batch_size: The batch size used for batched predictions.
  • n_iterations: The number of iterations for iterative prompting.
  • use_masks: Whether to make use of logits from previous prompt-based segmentation.
def run_amg( checkpoint: Union[str, os.PathLike], model_type: str, experiment_folder: Union[str, os.PathLike], val_image_paths: List[Union[str, os.PathLike]], val_gt_paths: List[Union[str, os.PathLike]], test_image_paths: List[Union[str, os.PathLike]], iou_thresh_values: Optional[List[float]] = None, stability_score_values: Optional[List[float]] = None, peft_kwargs: Optional[Dict] = None, cache_embeddings: bool = False, tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None) -> str:
def run_amg(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    iou_thresh_values: Optional[List[float]] = None,
    stability_score_values: Optional[List[float]] = None,
    peft_kwargs: Optional[Dict] = None,
    cache_embeddings: bool = False,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> str:
    """Run Segment Anything inference for multiple images using automatic mask generation (AMG).

    Args:
        checkpoint: The filepath to model checkpoints.
        model_type: The segment anything model choice.
        experiment_folder: The directory where the relevant files are saved.
        val_image_paths: The list of filepaths of input images for grid-search.
        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
        iou_thresh_values: Optional choice of values for grid search of `iou_thresh` parameter.
        stability_score_values: Optional choice of values for grid search of `stability_score` parameter.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        cache_embeddings: Whether to cache embeddings in experiment folder.
        tiling_window_params: The parameters to decide whether to use tiling window operation for AMG.
            If given (and non-empty) it must be a dictionary with the keys 'tile_shape' and 'halo'.

    Returns:
        Filepath where the predictions have been saved.

    Raises:
        RuntimeError: If `tiling_window_params` is given but is not a dictionary,
            or if it lacks the 'tile_shape' or 'halo' entry.
    """
    # Validate the tiling parameters and pick the AMG class first, so that a wrong
    # configuration is reported before the model checkpoint is loaded.
    if tiling_window_params:
        if not isinstance(tiling_window_params, dict):
            raise RuntimeError("The tiling window parameters are expected to be provided as a dictionary of params.")

        if "tile_shape" not in tiling_window_params:
            raise RuntimeError("'tile_shape' parameter is missing from the provided parameters.")

        if "halo" not in tiling_window_params:
            raise RuntimeError("'halo' parameter is missing from the provided parameters.")

        amg_class = TiledAutomaticMaskGenerator
    else:
        amg_class = AutomaticMaskGenerator

    if cache_embeddings:
        embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
        os.makedirs(embedding_folder, exist_ok=True)
    else:
        embedding_folder = None

    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs)

    amg = amg_class(predictor)
    amg_prefix = "amg"

    # where the predictions are saved
    prediction_folder = os.path.join(experiment_folder, amg_prefix, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # where the grid-search results are saved
    gs_result_folder = os.path.join(experiment_folder, amg_prefix, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    grid_search_values = instance_segmentation.default_grid_search_values_amg(
        iou_thresh_values=iou_thresh_values,
        stability_score_values=stability_score_values,
    )

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter=amg,
        grid_search_values=grid_search_values,
        val_image_paths=val_image_paths,
        val_gt_paths=val_gt_paths,
        test_image_paths=test_image_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
        experiment_folder=experiment_folder,
        tiling_window_params=tiling_window_params,
    )
    return prediction_folder

Run Segment Anything inference for multiple images using automatic mask generation (AMG).

  • checkpoint: The filepath to model checkpoints.
  • model_type: The segment anything model choice.
  • experiment_folder: The directory where the relevant files are saved.
  • val_image_paths: The list of filepaths of input images for grid-search.
  • val_gt_paths: The list of filepaths of corresponding labels for grid-search.
  • test_image_paths: The list of filepaths of input images for automatic instance segmentation.
  • iou_thresh_values: Optional choice of values for grid search of iou_thresh parameter.
  • stability_score_values: Optional choice of values for grid search of stability_score parameter.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class.
  • cache_embeddings: Whether to cache embeddings in experiment folder.
  • tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.

Filepath where the predictions have been saved.

def run_instance_segmentation_with_decoder( checkpoint: Union[str, os.PathLike], model_type: str, experiment_folder: Union[str, os.PathLike], val_image_paths: List[Union[str, os.PathLike]], val_gt_paths: List[Union[str, os.PathLike]], test_image_paths: List[Union[str, os.PathLike]], peft_kwargs: Optional[Dict] = None, cache_embeddings: bool = False, tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None) -> str:
def run_instance_segmentation_with_decoder(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    peft_kwargs: Optional[Dict] = None,
    cache_embeddings: bool = False,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> str:
    """Run Segment Anything inference for multiple images using additional automatic instance segmentation (AIS).

    Args:
        checkpoint: The filepath to model checkpoints.
        model_type: The segment anything model choice.
        experiment_folder: The directory where the relevant files are saved.
        val_image_paths: The list of filepaths of input images for grid-search.
        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        cache_embeddings: Whether to cache embeddings in experiment folder.
        tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
            If given (and non-empty) it must be a dictionary with the keys 'tile_shape' and 'halo'.

    Returns:
        Filepath where the predictions have been saved.

    Raises:
        RuntimeError: If `tiling_window_params` is given but is not a dictionary,
            or if it lacks the 'tile_shape' or 'halo' entry.
    """
    # Validate the tiling parameters and pick the AIS class first, so that a wrong
    # configuration is reported before the model checkpoint is loaded.
    if tiling_window_params:
        if not isinstance(tiling_window_params, dict):
            raise RuntimeError("The tiling window parameters are expected to be provided as a dictionary of params.")

        if "tile_shape" not in tiling_window_params:
            raise RuntimeError("'tile_shape' parameter is missing from the provided parameters.")

        if "halo" not in tiling_window_params:
            raise RuntimeError("'halo' parameter is missing from the provided parameters.")

        ais_class = TiledInstanceSegmentationWithDecoder
    else:
        ais_class = InstanceSegmentationWithDecoder

    if cache_embeddings:
        embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
        os.makedirs(embedding_folder, exist_ok=True)
    else:
        embedding_folder = None

    predictor, decoder = get_predictor_and_decoder(
        model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs,
    )

    segmenter = ais_class(predictor, decoder)
    seg_prefix = "instance_segmentation_with_decoder"

    # where the predictions are saved
    prediction_folder = os.path.join(experiment_folder, seg_prefix, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # where the grid-search results are saved
    gs_result_folder = os.path.join(experiment_folder, seg_prefix, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    grid_search_values = instance_segmentation.default_grid_search_values_instance_segmentation_with_decoder()

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter=segmenter,
        grid_search_values=grid_search_values,
        val_image_paths=val_image_paths,
        val_gt_paths=val_gt_paths,
        test_image_paths=test_image_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
        experiment_folder=experiment_folder,
        tiling_window_params=tiling_window_params,
    )
    return prediction_folder

Run Segment Anything inference for multiple images using additional automatic instance segmentation (AIS).

  • checkpoint: The filepath to model checkpoints.
  • model_type: The segment anything model choice.
  • experiment_folder: The directory where the relevant files are saved.
  • val_image_paths: The list of filepaths of input images for grid-search.
  • val_gt_paths: The list of filepaths of corresponding labels for grid-search.
  • test_image_paths: The list of filepaths of input images for automatic instance segmentation.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class.
  • cache_embeddings: Whether to cache embeddings in experiment folder.
  • tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.

Filepath where the predictions have been saved.