micro_sam.evaluation.inference

Inference with Segment Anything models and different prompt strategies.

  1"""Inference with Segment Anything models and different prompt strategies.
  2"""
  3
  4import os
  5import pickle
  6import numpy as np
  7from tqdm import tqdm
  8from copy import deepcopy
  9from typing import Any, Dict, List, Optional, Union, Tuple
 10
 11import imageio.v3 as imageio
 12
 13from bioimage_cpp.segmentation import relabel_sequential
 14
 15import torch
 16
 17from segment_anything import SamPredictor
 18
 19from .. import util as util
 20from ..inference import batched_inference
 21from ..instance_segmentation import (
 22    get_predictor_and_decoder,
 23    AutomaticMaskGenerator, InstanceSegmentationWithDecoder,
 24    TiledAutomaticMaskGenerator, TiledInstanceSegmentationWithDecoder,
 25    AutomaticPromptGenerator,
 26)
 27from . import instance_segmentation
 28from ..prompt_generators import PointAndBoxPromptGenerator, IterativePromptGenerator
 29
 30
 31def _load_prompts(
 32    cached_point_prompts, save_point_prompts, cached_box_prompts, save_box_prompts, image_name
 33):
 34
 35    def load_prompt_type(cached_prompts, save_prompts):
 36        # Check if we have saved prompts.
 37        if cached_prompts is None or save_prompts:  # we don't have cached prompts
 38            return cached_prompts, None
 39
 40        # we have cached prompts, but they have not been loaded yet
 41        if isinstance(cached_prompts, str):
 42            with open(cached_prompts, "rb") as f:
 43                cached_prompts = pickle.load(f)
 44
 45        prompts = cached_prompts[image_name]
 46        return cached_prompts, prompts
 47
 48    cached_point_prompts, point_prompts = load_prompt_type(cached_point_prompts, save_point_prompts)
 49    cached_box_prompts, box_prompts = load_prompt_type(cached_box_prompts, save_box_prompts)
 50
 51    # we don't have anything cached
 52    if point_prompts is None and box_prompts is None:
 53        return None, cached_point_prompts, cached_box_prompts
 54
 55    if point_prompts is None:
 56        input_point, input_label = [], []
 57    else:
 58        input_point, input_label = point_prompts
 59
 60    if box_prompts is None:
 61        input_box = []
 62    else:
 63        input_box = box_prompts
 64
 65    prompts = (input_point, input_label, input_box)
 66    return prompts, cached_point_prompts, cached_box_prompts
 67
 68
 69def _get_batched_prompts(gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation):
 70
 71    # Initialize the prompt generator.
 72    prompt_generator = PointAndBoxPromptGenerator(
 73        n_positive_points=n_positives, n_negative_points=n_negatives,
 74        dilation_strength=dilation, get_point_prompts=use_points,
 75        get_box_prompts=use_boxes
 76    )
 77
 78    # Generate the prompts.
 79    center_coordinates, bbox_coordinates = util.get_centers_and_bounding_boxes(gt)
 80    center_coordinates = [center_coordinates[gt_id] for gt_id in gt_ids]
 81    bbox_coordinates = [bbox_coordinates[gt_id] for gt_id in gt_ids]
 82    masks = util.segmentation_to_one_hot(gt.astype("int64"), gt_ids)
 83
 84    points, point_labels, boxes, _ = prompt_generator(
 85        masks, bbox_coordinates, center_coordinates
 86    )
 87
 88    def to_numpy(x):
 89        if x is None:
 90            return x
 91        return x.numpy()
 92
 93    return to_numpy(points), to_numpy(point_labels), to_numpy(boxes)
 94
 95
 96def _run_inference_with_prompts_for_image(
 97    predictor,
 98    image,
 99    gt,
100    use_points,
101    use_boxes,
102    n_positives,
103    n_negatives,
104    dilation,
105    batch_size,
106    cached_prompts,
107    embedding_path,
108):
109    gt_ids = np.unique(gt)[1:]
110    if cached_prompts is None:
111        points, point_labels, boxes = _get_batched_prompts(
112            gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation
113        )
114    else:
115        points, point_labels, boxes = cached_prompts
116
117    # Make a copy of the point prompts to return them at the end.
118    prompts = deepcopy((points, point_labels, boxes))
119
120    # Use multi-masking only if we have a single positive point without box
121    multimasking = False
122    if not use_boxes and (n_positives == 1 and n_negatives == 0):
123        multimasking = True
124
125    instance_labels = batched_inference(
126        predictor, image, batch_size,
127        boxes=boxes, points=points, point_labels=point_labels,
128        multimasking=multimasking, embedding_path=embedding_path,
129        return_instance_segmentation=True,
130    )
131
132    return instance_labels, prompts
133
134
def precompute_all_embeddings(
    predictor: SamPredictor, image_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike],
) -> None:
    """Precompute the image embeddings for all images.

    This enables running different inference tasks in parallel afterwards.

    Args:
        predictor: The Segment Anything predictor.
        image_paths: The image file paths.
        embedding_dir: The directory where the embeddings will be saved.
            Each image gets its own file, named after the image file with the extension '.zarr'.
    """
    for path in tqdm(image_paths, desc="Precompute embeddings"):
        stem, _ = os.path.splitext(os.path.basename(path))
        image = imageio.imread(path)
        save_path = os.path.join(embedding_dir, f"{stem}.zarr")
        util.precompute_image_embeddings(predictor, image, save_path, ndim=2)
152
153
def _precompute_prompts(gt_path, use_points, use_boxes, n_positives, n_negatives, dilation):
    """Compute the prompts for a single ground-truth file.

    Args:
        gt_path: The file path to the ground-truth segmentation.
        use_points: Whether to derive point prompts.
        use_boxes: Whether to derive box prompts.
        n_positives: The number of positive points per object.
        n_negatives: The number of negative points per object.
        dilation: The dilation strength used for sampling the prompts.

    Returns:
        The file name of the ground-truth.
        The box prompts if only boxes are requested, otherwise the tuple `(points, point_labels)`.
    """
    name = os.path.basename(gt_path)

    gt = imageio.imread(gt_path).astype("uint32")
    gt = relabel_sequential(gt)[0]

    # Exclude the background label (0) by value rather than by position, so that
    # the first object is not dropped if the ground-truth contains no background.
    gt_ids = np.unique(gt)
    gt_ids = gt_ids[gt_ids != 0]

    input_point, input_label, input_box = _get_batched_prompts(
        gt, gt_ids, use_points, use_boxes, n_positives, n_negatives, dilation
    )

    # NOTE(review): if both points and boxes are requested only the point prompts
    # are returned and the box prompts are dropped - confirm that this is intended.
    if use_boxes and not use_points:
        return name, input_box
    return name, (input_point, input_label)
168
169
def precompute_all_prompts(
    gt_paths: List[Union[str, os.PathLike]],
    prompt_save_dir: Union[str, os.PathLike],
    prompt_settings: List[Dict[str, Any]],
) -> None:
    """Precompute the prompts for all given prompt settings.

    This enables running different inference tasks in parallel afterwards.
    The prompts for a setting are pickled to 'boxes.pkl' if only boxes are used
    and to 'points-p{n_positives}-n{n_negatives}.pkl' otherwise.
    Settings for which the file already exists are skipped.

    Args:
        gt_paths: The file paths to the ground-truth segmentations.
        prompt_save_dir: The directory where the prompt files will be saved.
        prompt_settings: The settings for which the prompts will be computed.
            Each setting needs the keys 'use_points', 'use_boxes', 'n_positives' and 'n_negatives';
            the key 'dilation' is optional and defaults to 5.
    """
    os.makedirs(prompt_save_dir, exist_ok=True)

    for settings in tqdm(prompt_settings, desc="Precompute prompts"):
        use_points = settings["use_points"]
        use_boxes = settings["use_boxes"]
        n_positives = settings["n_positives"]
        n_negatives = settings["n_negatives"]
        dilation = settings.get("dilation", 5)

        # Skip the settings for which the prompts were already computed.
        boxes_only = use_boxes and not use_points
        file_name = "boxes.pkl" if boxes_only else f"points-p{n_positives}-n{n_negatives}.pkl"
        prompt_save_path = os.path.join(prompt_save_dir, file_name)
        if os.path.exists(prompt_save_path):
            continue

        # Map the name of each ground-truth file to its prompts.
        saved_prompts = {}
        for gt_path in tqdm(gt_paths, desc=f"Precompute prompts for p{n_positives}-n{n_negatives}"):
            name, prompts = _precompute_prompts(
                gt_path,
                use_points=use_points,
                use_boxes=use_boxes,
                n_positives=n_positives,
                n_negatives=n_negatives,
                dilation=dilation,
            )
            saved_prompts[name] = prompts

        with open(prompt_save_path, "wb") as f:
            pickle.dump(saved_prompts, f)
215
216
217def _get_prompt_caching(prompt_save_dir, use_points, use_boxes, n_positives, n_negatives):
218
219    def get_prompt_type_caching(use_type, save_name):
220        if not use_type:
221            return None, False, None
222
223        prompt_save_path = os.path.join(prompt_save_dir, save_name)
224        if os.path.exists(prompt_save_path):
225            print("Using precomputed prompts from", prompt_save_path)
226            # We delay loading the prompts, so we only have to load them once they're needed the first time.
227            # This avoids loading the prompts (which are in a big pickle file) if all predictions are done already.
228            cached_prompts = prompt_save_path
229            save_prompts = False
230        else:
231            print("Saving prompts in", prompt_save_path)
232            cached_prompts = {}
233            save_prompts = True
234        return cached_prompts, save_prompts, prompt_save_path
235
236    # Check if prompt serialization is enabled.
237    # If it is then load the prompts if they are already cached and otherwise store them.
238    if prompt_save_dir is None:
239        print("Prompts are not cached.")
240        cached_point_prompts, cached_box_prompts = None, None
241        save_point_prompts, save_box_prompts = False, False
242        point_prompt_save_path, box_prompt_save_path = None, None
243    else:
244        cached_point_prompts, save_point_prompts, point_prompt_save_path = get_prompt_type_caching(
245            use_points, f"points-p{n_positives}-n{n_negatives}.pkl"
246        )
247        cached_box_prompts, save_box_prompts, box_prompt_save_path = get_prompt_type_caching(
248            use_boxes, "boxes.pkl"
249        )
250
251    return (cached_point_prompts, save_point_prompts, point_prompt_save_path,
252            cached_box_prompts, save_box_prompts, box_prompt_save_path)
253
254
def run_inference_with_prompts(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Optional[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    use_points: bool,
    use_boxes: bool,
    n_positives: int,
    n_negatives: int,
    dilation: int = 5,
    prompt_save_dir: Optional[Union[str, os.PathLike]] = None,
    batch_size: int = 512,
) -> None:
    """Run segment anything inference for multiple images using prompts derived from groundtruth.

    Images for which a prediction already exists in `prediction_dir` are skipped.

    Args:
        predictor: The Segment Anything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
            If None, no embedding path is passed on for the prediction.
        prediction_dir: The directory where the predictions will be saved.
        use_points: Whether to use point prompts.
        use_boxes: Whether to use box prompts.
        n_positives: The number of positive point prompts that will be sampled.
        n_negatives: The number of negative point prompts that will be sampled.
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        prompt_save_dir: The directory where point prompts will be saved or are already saved.
            This enables running multiple experiments in a reproducible manner.
        batch_size: The batch size used for batched prediction.

    Raises:
        ValueError: If neither point nor box prompts are used, or if the number
            of images and ground-truth images differs.
    """
    if not (use_points or use_boxes):
        raise ValueError("You need to use at least one of point or box prompts.")

    if len(image_paths) != len(gt_paths):
        raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")

    (cached_point_prompts, save_point_prompts, point_prompt_save_path,
     cached_box_prompts, save_box_prompts, box_prompt_save_path) = _get_prompt_caching(
         prompt_save_dir, use_points, use_boxes, n_positives, n_negatives
     )

    os.makedirs(prediction_dir, exist_ok=True)
    for image_path, gt_path in tqdm(
        zip(image_paths, gt_paths), total=len(image_paths), desc="Run inference with prompts"
    ):
        image_name = os.path.basename(image_path)
        label_name = os.path.basename(gt_path)

        # We skip the images that already have been segmented.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        im = imageio.imread(image_path)
        gt = imageio.imread(gt_path).astype("uint32")
        gt = relabel_sequential(gt)[0]

        # Support running without an embedding directory, consistent with
        # 'run_inference_with_iterative_prompting'.
        if embedding_dir is None:
            embedding_path = None
        else:
            embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        # The prompts are cached by the name of the label file.
        this_prompts, cached_point_prompts, cached_box_prompts = _load_prompts(
            cached_point_prompts, save_point_prompts,
            cached_box_prompts, save_box_prompts,
            label_name
        )
        instances, this_prompts = _run_inference_with_prompts_for_image(
            predictor, im, gt, n_positives=n_positives, n_negatives=n_negatives,
            dilation=dilation, use_points=use_points, use_boxes=use_boxes,
            batch_size=batch_size, cached_prompts=this_prompts,
            embedding_path=embedding_path,
        )

        # The prompts are the tuple (points, point_labels, boxes).
        if save_point_prompts:
            cached_point_prompts[label_name] = this_prompts[:2]
        if save_box_prompts:
            cached_box_prompts[label_name] = this_prompts[-1]

        # It's important to compress here, otherwise the predictions would take up a lot of space.
        imageio.imwrite(prediction_path, instances, compression=5)

    # Save the prompts if we run experiments with prompt caching and have computed them
    # for the first time.
    if save_point_prompts:
        with open(point_prompt_save_path, "wb") as f:
            pickle.dump(cached_point_prompts, f)
    if save_box_prompts:
        with open(box_prompt_save_path, "wb") as f:
            pickle.dump(cached_box_prompts, f)
345
346
def _save_segmentation(masks, prediction_path):
    """Merge a batch of object masks into one segmentation and write it to disk.

    Args:
        masks: The predicted masks, one per object, with a singleton axis 1 that is
            squeezed away - presumably a torch tensor of shape (n_objects, 1, H, W); TODO confirm.
        prediction_path: The file path where the segmentation is saved (with compression).
    """
    binary_masks = masks.cpu().numpy().squeeze(1).astype("bool")

    # Bring the masks into the mask data format expected by 'mask_data_to_segmentation'.
    mask_data = []
    for binary_mask in binary_masks:
        mask_data.append({"segmentation": binary_mask, "area": binary_mask.sum()})

    segmentation = util.mask_data_to_segmentation(mask_data)
    imageio.imwrite(prediction_path, segmentation, compression=5)
353
354
355def _get_batched_iterative_prompts(sampled_binary_gt, masks, batch_size, prompt_generator):
356    n_samples = sampled_binary_gt.shape[0]
357    n_batches = int(np.ceil(float(n_samples) / batch_size))
358    next_coords, next_labels = [], []
359    for batch_idx in range(n_batches):
360        batch_start = batch_idx * batch_size
361        batch_stop = min((batch_idx + 1) * batch_size, n_samples)
362
363        batch_coords, batch_labels, _, _ = prompt_generator(
364            sampled_binary_gt[batch_start: batch_stop], masks[batch_start: batch_stop]
365        )
366        next_coords.append(batch_coords)
367        next_labels.append(batch_labels)
368
369    next_coords = torch.concatenate(next_coords)
370    next_labels = torch.concatenate(next_labels)
371
372    return next_coords, next_labels
373
374
@torch.no_grad()
def _run_inference_with_iterative_prompting_for_image(
    predictor,
    image,
    gt,
    start_with_box_prompt,
    dilation,
    batch_size,
    embedding_path,
    n_iterations,
    prediction_paths,
    use_masks=False
) -> None:
    """Run iterative prompting for a single image and save the prediction of each iteration.

    The first iteration uses either a box or a single positive point per object.
    After each iteration new point prompts are derived by the `IterativePromptGenerator`
    from the ground-truth and the current prediction, and added to the existing prompts.

    Args:
        predictor: The Segment Anything predictor.
        image: The input image.
        gt: The ground-truth label image.
        start_with_box_prompt: Whether to start with a box prompt (True) or a single point (False).
        dilation: The dilation strength used for sampling the initial prompts.
        batch_size: The batch size used for batched prediction and prompt generation.
        embedding_path: The path passed to `batched_inference` for the image embeddings.
        n_iterations: The number of prompting iterations.
        prediction_paths: The file paths for saving the predictions, one per iteration.
        use_masks: Whether to pass the logits of the previous iteration as mask prompt.
    """
    prompt_generator = IterativePromptGenerator()

    # Exclude the background label (0) by value rather than by position, so that
    # the first object is not dropped if the ground-truth contains no background.
    gt_ids = np.unique(gt)
    gt_ids = gt_ids[gt_ids != 0]

    # Use multi-masking only if we have a single positive point without box.
    if start_with_box_prompt:
        use_boxes, use_points = True, False
        n_positives = 0
        multimasking = False
    else:
        use_boxes, use_points = False, True
        n_positives = 1
        multimasking = True

    points, point_labels, boxes = _get_batched_prompts(
        gt, gt_ids,
        use_points=use_points,
        use_boxes=use_boxes,
        n_positives=n_positives,
        n_negatives=0,
        dilation=dilation
    )

    sampled_binary_gt = util.segmentation_to_one_hot(gt.astype("int64"), gt_ids)

    # There are no logits in the first iteration. Afterwards they are only set if 'use_masks' is True.
    logits_masks = None

    for iteration in range(n_iterations):
        batched_outputs = batched_inference(
            predictor=predictor,
            image=image,
            batch_size=batch_size,
            boxes=boxes,
            points=points,
            point_labels=point_labels,
            multimasking=multimasking,
            embedding_path=embedding_path,
            return_instance_segmentation=False,
            logits_masks=logits_masks,
            verbose_embeddings=False,
        )

        # switching off multimasking after first iter, as next iters (with multiple prompts) don't expect multimasking
        multimasking = False

        masks = torch.stack([m["segmentation"][None] for m in batched_outputs]).to(torch.float32)

        next_coords, next_labels = _get_batched_iterative_prompts(
            sampled_binary_gt, masks, batch_size, prompt_generator
        )
        next_coords, next_labels = next_coords.detach().cpu().numpy(), next_labels.detach().cpu().numpy()

        # Add the new points to the existing point prompts (there are none yet when starting from boxes).
        if points is not None:
            points = np.concatenate([points, next_coords], axis=1)
        else:
            points = next_coords

        if point_labels is not None:
            point_labels = np.concatenate([point_labels, next_labels], axis=1)
        else:
            point_labels = next_labels

        if use_masks:
            logits_masks = torch.stack([m["logits"] for m in batched_outputs])

        _save_segmentation(masks, prediction_paths[iteration])
460
461
def run_inference_with_iterative_prompting(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Union[str, os.PathLike],
    prediction_dir: Union[str, os.PathLike],
    start_with_box_prompt: bool = True,
    dilation: int = 5,
    batch_size: int = 32,
    n_iterations: int = 8,
    use_masks: bool = False
) -> None:
    """Run Segment Anything inference for multiple images with iterative prompting.

    The prompts are derived iteratively from the model outputs and the ground-truth.
    The predictions of iteration i are saved in the sub-folder 'iteration{i:02}' of `prediction_dir`.
    Images for which the predictions of all iterations exist already are skipped.

    Args:
        predictor: The Segment Anything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
            If None, no embedding path is passed on for the prediction.
        prediction_dir: The directory where the predictions from Segment Anything will be saved per iteration.
        start_with_box_prompt: Whether to use the first prompt as bounding box or a single point.
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        batch_size: The batch size used for batched predictions.
        n_iterations: The number of iterations for iterative prompting.
        use_masks: Whether to make use of logits from previous prompt-based segmentation.

    Raises:
        ValueError: If the number of images and ground-truth images differs.
    """
    if len(image_paths) != len(gt_paths):
        raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")

    # Create one prediction folder per iteration.
    iteration_dirs = [os.path.join(prediction_dir, f"iteration{i:02}") for i in range(n_iterations)]
    for iteration_dir in iteration_dirs:
        os.makedirs(iteration_dir, exist_ok=True)

    if use_masks:
        print("The iterative prompting will make use of logits masks from previous iterations.")

    progress = tqdm(
        zip(image_paths, gt_paths), total=len(image_paths),
        desc="Run inference with iterative prompting for all images",
    )
    for image_path, gt_path in progress:
        image_name = os.path.basename(image_path)

        # Nothing to do if this image was segmented in all iterations already.
        prediction_paths = [os.path.join(iteration_dir, image_name) for iteration_dir in iteration_dirs]
        already_segmented = all(os.path.exists(path) for path in prediction_paths)
        if already_segmented:
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        image = imageio.imread(image_path)
        gt = relabel_sequential(imageio.imread(gt_path).astype("uint32"))[0]

        embedding_path = (
            None if embedding_dir is None
            else os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
        )

        _run_inference_with_iterative_prompting_for_image(
            predictor,
            image,
            gt,
            start_with_box_prompt=start_with_box_prompt,
            dilation=dilation,
            batch_size=batch_size,
            embedding_path=embedding_path,
            n_iterations=n_iterations,
            prediction_paths=prediction_paths,
            use_masks=use_masks,
        )
528
529
530#
531# AMG FUNCTION
532#
533
534
def run_amg(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    iou_thresh_values: Optional[List[float]] = None,
    stability_score_values: Optional[List[float]] = None,
    peft_kwargs: Optional[Dict] = None,
    cache_embeddings: bool = False,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> str:
    """Run Segment Anything inference for multiple images using automatic mask generation (AMG).

    Runs a grid search on the validation images, followed by inference on the test images.
    The results are saved in the sub-folders 'amg/grid_search' and 'amg/inference' of the experiment folder.

    Args:
        checkpoint: The filepath to model checkpoints.
        model_type: The segment anything model choice.
        experiment_folder: The directory where the relevant files are saved.
        val_image_paths: The list of filepaths of input images for grid-search.
        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
        iou_thresh_values: Optional choice of values for grid search of `iou_thresh` parameter.
        stability_score_values: Optional choice of values for grid search of `stability_score` parameter.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        cache_embeddings: Whether to cache embeddings in experiment folder.
        tiling_window_params: The parameters for tiled prediction. If given, this must be a dictionary
            with the keys 'tile_shape' and 'halo', and the tiled mask generator is used.

    Returns:
        Filepath where the predictions have been saved.

    Raises:
        RuntimeError: If `tiling_window_params` is not a dictionary or misses a required key.
    """
    # The precomputed embeddings are only saved if caching is requested.
    embedding_folder = None
    if cache_embeddings:
        embedding_folder = os.path.join(experiment_folder, "embeddings")
        os.makedirs(embedding_folder, exist_ok=True)

    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs)

    # Select the AMG class, after validating the tiling parameters if they are given.
    amg_class = AutomaticMaskGenerator
    if tiling_window_params:
        if not isinstance(tiling_window_params, dict):
            raise RuntimeError("The tiling window parameters are expected to be provided as a dictionary of params.")

        for required_key in ("tile_shape", "halo"):
            if required_key not in tiling_window_params:
                raise RuntimeError(f"'{required_key}' parameter is missing from the provided parameters.")

        amg_class = TiledAutomaticMaskGenerator

    amg = amg_class(predictor)
    output_root = os.path.join(experiment_folder, "amg")

    # The folder for the predictions.
    prediction_folder = os.path.join(output_root, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # The folder for the grid-search results.
    gs_result_folder = os.path.join(output_root, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    grid_search_values = instance_segmentation.default_grid_search_values_amg(
        iou_thresh_values=iou_thresh_values,
        stability_score_values=stability_score_values,
    )

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter=amg,
        grid_search_values=grid_search_values,
        val_image_paths=val_image_paths,
        val_gt_paths=val_gt_paths,
        test_image_paths=test_image_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
        experiment_folder=experiment_folder,
        tiling_window_params=tiling_window_params,
    )
    return prediction_folder
619
620
def run_apg(
    checkpoint: Optional[Union[str, os.PathLike]],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    peft_kwargs: Optional[Dict] = None,
    cache_embeddings: bool = False,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> str:
    """Run Segment Anything inference for multiple images using automatic prompt generation (APG).

    Runs a grid search on the validation images, followed by inference on the test images.
    The results are saved in the sub-folders 'apg/grid_search' and 'apg/inference' of the experiment folder.

    Args:
        checkpoint: The filepath to model checkpoints.
        model_type: The segment anything model choice.
        experiment_folder: The directory where the relevant files are saved.
        val_image_paths: The list of filepaths of input images for grid-search.
        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        cache_embeddings: Whether to cache embeddings in experiment folder.
        tiling_window_params: The parameters for tiled prediction. Tiling is not supported for APG,
            so this has to be None (or empty).

    Returns:
        Filepath where the predictions have been saved.

    Raises:
        NotImplementedError: If `tiling_window_params` is given.
    """
    # Tiling is not supported. Fail before creating folders and loading the model.
    if tiling_window_params:
        raise NotImplementedError("Tiling is not supported for automatic prompt generation (APG).")

    if cache_embeddings:
        embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
        os.makedirs(embedding_folder, exist_ok=True)
    else:
        embedding_folder = None

    predictor, decoder = get_predictor_and_decoder(
        model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs,
    )

    segmenter = AutomaticPromptGenerator(predictor, decoder)
    seg_prefix = "apg"

    # where the predictions are saved
    prediction_folder = os.path.join(experiment_folder, seg_prefix, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # where the grid-search results are saved
    gs_result_folder = os.path.join(experiment_folder, seg_prefix, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    grid_search_values = instance_segmentation.default_grid_search_values_apg()

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter=segmenter,
        grid_search_values=grid_search_values,
        val_image_paths=val_image_paths,
        val_gt_paths=val_gt_paths,
        test_image_paths=test_image_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
        experiment_folder=experiment_folder,
        tiling_window_params=tiling_window_params,
    )
    return prediction_folder
682
683
684#
685# INSTANCE SEGMENTATION FUNCTION
686#
687
688
def run_instance_segmentation_with_decoder(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    peft_kwargs: Optional[Dict] = None,
    cache_embeddings: bool = False,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> str:
    """Run Segment Anything inference for multiple images using additional automatic instance segmentation (AIS).

    Runs a grid search on the validation images, followed by inference on the test images.
    The results are saved in the sub-folders 'instance_segmentation_with_decoder/grid_search'
    and 'instance_segmentation_with_decoder/inference' of the experiment folder.

    Args:
        checkpoint: The filepath to model checkpoints.
        model_type: The segment anything model choice.
        experiment_folder: The directory where the relevant files are saved.
        val_image_paths: The list of filepaths of input images for grid-search.
        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        cache_embeddings: Whether to cache embeddings in experiment folder.
        tiling_window_params: The parameters for tiled prediction. If given, this must be a dictionary
            with the keys 'tile_shape' and 'halo', and the tiled segmentation class is used.

    Returns:
        Filepath where the predictions have been saved.

    Raises:
        RuntimeError: If `tiling_window_params` is not a dictionary or misses a required key.
    """
    # The precomputed embeddings are only saved if caching is requested.
    embedding_folder = None
    if cache_embeddings:
        embedding_folder = os.path.join(experiment_folder, "embeddings")
        os.makedirs(embedding_folder, exist_ok=True)

    predictor, decoder = get_predictor_and_decoder(
        model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs,
    )

    # Select the AIS class, after validating the tiling parameters if they are given.
    ais_class = InstanceSegmentationWithDecoder
    if tiling_window_params:
        if not isinstance(tiling_window_params, dict):
            raise RuntimeError("The tiling window parameters are expected to be provided as a dictionary of params.")

        for required_key in ("tile_shape", "halo"):
            if required_key not in tiling_window_params:
                raise RuntimeError(f"'{required_key}' parameter is missing from the provided parameters.")

        ais_class = TiledInstanceSegmentationWithDecoder

    segmenter = ais_class(predictor, decoder)
    output_root = os.path.join(experiment_folder, "instance_segmentation_with_decoder")

    # The folder for the predictions.
    prediction_folder = os.path.join(output_root, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # The folder for the grid-search results.
    gs_result_folder = os.path.join(output_root, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    grid_search_values = instance_segmentation.default_grid_search_values_instance_segmentation_with_decoder()

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter=segmenter,
        grid_search_values=grid_search_values,
        val_image_paths=val_image_paths,
        val_gt_paths=val_gt_paths,
        test_image_paths=test_image_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
        experiment_folder=experiment_folder,
        tiling_window_params=tiling_window_params,
    )
    return prediction_folder
def precompute_all_embeddings( predictor: segment_anything.predictor.SamPredictor, image_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike]) -> None:
def precompute_all_embeddings(
    predictor: SamPredictor, image_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike],
) -> None:
    """Compute and store the image embeddings for all given images.

    Having the embeddings available up front enables running different
    inference tasks in parallel afterwards.

    Args:
        predictor: The Segment Anything predictor.
        image_paths: The image file paths.
        embedding_dir: The directory where the embeddings will be saved.
    """
    for path in tqdm(image_paths, desc="Precompute embeddings"):
        file_name = os.path.basename(path)
        image = imageio.imread(path)

        # The embeddings are stored per image, named after the image file
        # (without its extension) and with a '.zarr' suffix.
        stem = os.path.splitext(file_name)[0]
        save_path = os.path.join(embedding_dir, f"{stem}.zarr")
        util.precompute_image_embeddings(predictor, image, save_path, ndim=2)

Precompute all image embeddings.

To enable running different inference tasks in parallel afterwards.

Arguments:
  • predictor: The Segment Anything predictor.
  • image_paths: The image file paths.
  • embedding_dir: The directory where the embeddings will be saved.
def precompute_all_prompts( gt_paths: List[Union[str, os.PathLike]], prompt_save_dir: Union[str, os.PathLike], prompt_settings: List[Dict[str, Any]]) -> None:
def precompute_all_prompts(
    gt_paths: List[Union[str, os.PathLike]],
    prompt_save_dir: Union[str, os.PathLike],
    prompt_settings: List[Dict[str, Any]],
) -> None:
    """Precompute the prompts for all given settings and store them as pickle files.

    To enable running different inference tasks in parallel afterwards.
    Settings for which the prompt file already exists are skipped.

    Args:
        gt_paths: The file paths to the ground-truth segmentations.
        prompt_save_dir: The directory where the prompt files will be saved.
        prompt_settings: The settings for which the prompts will be computed.
            Each entry must contain the keys 'use_points', 'use_boxes', 'n_positives'
            and 'n_negatives'; 'dilation' is optional and defaults to 5.
    """
    os.makedirs(prompt_save_dir, exist_ok=True)

    for config in tqdm(prompt_settings, desc="Precompute prompts"):
        use_points = config["use_points"]
        use_boxes = config["use_boxes"]
        n_positives = config["n_positives"]
        n_negatives = config["n_negatives"]
        dilation = config.get("dilation", 5)

        # Box-only settings share one file, all other settings get a file
        # named after the number of positive and negative points.
        box_only = use_boxes and not use_points
        file_name = "boxes.pkl" if box_only else f"points-p{n_positives}-n{n_negatives}.pkl"
        prompt_save_path = os.path.join(prompt_save_dir, file_name)

        # Nothing to do if the prompts for this setting were already computed.
        if os.path.exists(prompt_save_path):
            continue

        computed = [
            _precompute_prompts(
                gt_path,
                use_points=use_points,
                use_boxes=use_boxes,
                n_positives=n_positives,
                n_negatives=n_negatives,
                dilation=dilation,
            )
            for gt_path in tqdm(gt_paths, desc=f"Precompute prompts for p{n_positives}-n{n_negatives}")
        ]

        # The first entry of each result is used as the key and the second entry as the stored value.
        saved_prompts = {entry[0]: entry[1] for entry in computed}
        with open(prompt_save_path, "wb") as f:
            pickle.dump(saved_prompts, f)

Precompute all prompts (point and/or box prompts).

To enable running different inference tasks in parallel afterwards.

Arguments:
  • gt_paths: The file paths to the ground-truth segmentations.
  • prompt_save_dir: The directory where the prompt files will be saved.
  • prompt_settings: The settings for which the prompts will be computed.
def run_inference_with_prompts( predictor: segment_anything.predictor.SamPredictor, image_paths: List[Union[str, os.PathLike]], gt_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike], prediction_dir: Union[str, os.PathLike], use_points: bool, use_boxes: bool, n_positives: int, n_negatives: int, dilation: int = 5, prompt_save_dir: Union[str, os.PathLike, NoneType] = None, batch_size: int = 512) -> None:
def run_inference_with_prompts(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Union[str, os.PathLike],
    prediction_dir: Union[str, os.PathLike],
    use_points: bool,
    use_boxes: bool,
    n_positives: int,
    n_negatives: int,
    dilation: int = 5,
    prompt_save_dir: Optional[Union[str, os.PathLike]] = None,
    batch_size: int = 512,
) -> None:
    """Run Segment Anything inference for multiple images using prompts derived from ground-truth.

    Images for which a prediction already exists in `prediction_dir` are skipped.

    Args:
        predictor: The Segment Anything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
        prediction_dir: The directory where the predicted segmentations will be saved.
        use_points: Whether to use point prompts.
        use_boxes: Whether to use box prompts.
        n_positives: The number of positive point prompts that will be sampled.
        n_negatives: The number of negative point prompts that will be sampled.
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        prompt_save_dir: The directory where point prompts will be saved or are already saved.
            This enables running multiple experiments in a reproducible manner.
        batch_size: The batch size used for batched prediction.

    Raises:
        ValueError: If neither point nor box prompts are selected, or if the number
            of images and ground-truth files differs.
    """
    if not use_points and not use_boxes:
        raise ValueError("You need to use at least one of point or box prompts.")

    n_images, n_labels = len(image_paths), len(gt_paths)
    if n_images != n_labels:
        raise ValueError(f"Expect same number of images and gt images, got {n_images}, {n_labels}")

    (
        cached_point_prompts, save_point_prompts, point_prompt_save_path,
        cached_box_prompts, save_box_prompts, box_prompt_save_path,
    ) = _get_prompt_caching(prompt_save_dir, use_points, use_boxes, n_positives, n_negatives)

    os.makedirs(prediction_dir, exist_ok=True)
    path_pairs = zip(image_paths, gt_paths)
    for image_path, gt_path in tqdm(path_pairs, total=n_images, desc="Run inference with prompts"):
        image_name, label_name = os.path.basename(image_path), os.path.basename(gt_path)

        # Images that have already been segmented are not processed again.
        prediction_path = os.path.join(prediction_dir, image_name)
        if os.path.exists(prediction_path):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        image = imageio.imread(image_path)
        labels = relabel_sequential(imageio.imread(gt_path).astype("uint32"))[0]

        embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")

        # The cached prompts are looked up by the name of the ground-truth file.
        prompts, cached_point_prompts, cached_box_prompts = _load_prompts(
            cached_point_prompts, save_point_prompts, cached_box_prompts, save_box_prompts, label_name
        )
        instances, prompts = _run_inference_with_prompts_for_image(
            predictor, image, labels,
            n_positives=n_positives, n_negatives=n_negatives, dilation=dilation,
            use_points=use_points, use_boxes=use_boxes, batch_size=batch_size,
            cached_prompts=prompts, embedding_path=embedding_path,
        )

        # The first two entries of the returned prompts are stored as point prompts,
        # the last entry as box prompts.
        if save_point_prompts:
            cached_point_prompts[label_name] = prompts[:2]
        if save_box_prompts:
            cached_box_prompts[label_name] = prompts[-1]

        # Compression is important here, the predictions would take up a lot of space otherwise.
        imageio.imwrite(prediction_path, instances, compression=5)

    # Write the prompts to disk if prompt caching is used and they were computed for the first time.
    if save_point_prompts:
        with open(point_prompt_save_path, "wb") as f:
            pickle.dump(cached_point_prompts, f)
    if save_box_prompts:
        with open(box_prompt_save_path, "wb") as f:
            pickle.dump(cached_box_prompts, f)

Run segment anything inference for multiple images using prompts derived from groundtruth.

Arguments:
  • predictor: The Segment Anything predictor.
  • image_paths: The image file paths.
  • gt_paths: The ground-truth segmentation file paths.
  • embedding_dir: The directory where the image embeddings will be saved or are already saved.
  • use_points: Whether to use point prompts.
  • use_boxes: Whether to use box prompts
  • n_positives: The number of positive point prompts that will be sampled.
  • n_negatives: The number of negative point prompts that will be sampled.
  • dilation: The dilation factor for the radius around the ground-truth object around which points will not be sampled.
  • prompt_save_dir: The directory where point prompts will be saved or are already saved. This enables running multiple experiments in a reproducible manner.
  • batch_size: The batch size used for batched prediction.
def run_inference_with_iterative_prompting( predictor: segment_anything.predictor.SamPredictor, image_paths: List[Union[str, os.PathLike]], gt_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike], prediction_dir: Union[str, os.PathLike], start_with_box_prompt: bool = True, dilation: int = 5, batch_size: int = 32, n_iterations: int = 8, use_masks: bool = False) -> None:
def run_inference_with_iterative_prompting(
    predictor: SamPredictor,
    image_paths: List[Union[str, os.PathLike]],
    gt_paths: List[Union[str, os.PathLike]],
    embedding_dir: Union[str, os.PathLike],
    prediction_dir: Union[str, os.PathLike],
    start_with_box_prompt: bool = True,
    dilation: int = 5,
    batch_size: int = 32,
    n_iterations: int = 8,
    use_masks: bool = False
) -> None:
    """Run Segment Anything inference for multiple images using prompts iteratively
    derived from model outputs and ground-truth.

    An image is skipped only if its predictions for all iterations already exist.

    Args:
        predictor: The Segment Anything predictor.
        image_paths: The image file paths.
        gt_paths: The ground-truth segmentation file paths.
        embedding_dir: The directory where the image embeddings will be saved or are already saved.
            If None is passed, no embedding path is handed on to the per-image inference.
        prediction_dir: The directory where the predictions from Segment Anything will be saved per iteration.
        start_with_box_prompt: Whether to use the first prompt as bounding box or a single point.
        dilation: The dilation factor for the radius around the ground-truth object
            around which points will not be sampled.
        batch_size: The batch size used for batched predictions.
        n_iterations: The number of iterations for iterative prompting.
        use_masks: Whether to make use of logits from previous prompt-based segmentation.

    Raises:
        ValueError: If the number of images and ground-truth files differs.
    """
    if len(image_paths) != len(gt_paths):
        raise ValueError(f"Expect same number of images and gt images, got {len(image_paths)}, {len(gt_paths)}")

    # One output folder per iteration; they are created up front and reused for the per-image paths.
    iteration_dirs = [os.path.join(prediction_dir, f"iteration{i:02}") for i in range(n_iterations)]
    for iteration_dir in iteration_dirs:
        os.makedirs(iteration_dir, exist_ok=True)

    if use_masks:
        print("The iterative prompting will make use of logits masks from previous iterations.")

    path_pairs = zip(image_paths, gt_paths)
    for image_path, gt_path in tqdm(
        path_pairs, total=len(image_paths), desc="Run inference with iterative prompting for all images",
    ):
        image_name = os.path.basename(image_path)

        # Images that have been segmented for every iteration are not processed again.
        prediction_paths = [os.path.join(iteration_dir, image_name) for iteration_dir in iteration_dirs]
        if all(os.path.exists(path) for path in prediction_paths):
            continue

        assert os.path.exists(image_path), image_path
        assert os.path.exists(gt_path), gt_path

        image = imageio.imread(image_path)
        labels = relabel_sequential(imageio.imread(gt_path).astype("uint32"))[0]

        embedding_path = (
            None if embedding_dir is None
            else os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
        )

        _run_inference_with_iterative_prompting_for_image(
            predictor, image, labels,
            start_with_box_prompt=start_with_box_prompt, dilation=dilation, batch_size=batch_size,
            embedding_path=embedding_path, n_iterations=n_iterations,
            prediction_paths=prediction_paths, use_masks=use_masks,
        )

Run Segment Anything inference for multiple images using prompts iteratively derived from model outputs and ground-truth.

Arguments:
  • predictor: The Segment Anything predictor.
  • image_paths: The image file paths.
  • gt_paths: The ground-truth segmentation file paths.
  • embedding_dir: The directory where the image embeddings will be saved or are already saved.
  • prediction_dir: The directory where the predictions from Segment Anything will be saved per iteration.
  • start_with_box_prompt: Whether to use the first prompt as bounding box or a single point
  • dilation: The dilation factor for the radius around the ground-truth object around which points will not be sampled.
  • batch_size: The batch size used for batched predictions.
  • n_iterations: The number of iterations for iterative prompting.
  • use_masks: Whether to make use of logits from previous prompt-based segmentation.
def run_amg( checkpoint: Union[str, os.PathLike], model_type: str, experiment_folder: Union[str, os.PathLike], val_image_paths: List[Union[str, os.PathLike]], val_gt_paths: List[Union[str, os.PathLike]], test_image_paths: List[Union[str, os.PathLike]], iou_thresh_values: Optional[List[float]] = None, stability_score_values: Optional[List[float]] = None, peft_kwargs: Optional[Dict] = None, cache_embeddings: bool = False, tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None) -> str:
def run_amg(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    iou_thresh_values: Optional[List[float]] = None,
    stability_score_values: Optional[List[float]] = None,
    peft_kwargs: Optional[Dict] = None,
    cache_embeddings: bool = False,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> str:
    """Run Segment Anything inference for multiple images using automatic mask generation (AMG).

    Args:
        checkpoint: The filepath to model checkpoints.
        model_type: The segment anything model choice.
        experiment_folder: The directory where the relevant files are saved.
        val_image_paths: The list of filepaths of input images for grid-search.
        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
        iou_thresh_values: Optional choice of values for grid search of `iou_thresh` parameter.
        stability_score_values: Optional choice of values for grid search of `stability_score` parameter.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        cache_embeddings: Whether to cache embeddings in experiment folder.
        tiling_window_params: The parameters to decide whether to use tiling window operation for AMG.
            If given (and not empty), it must be a dictionary with the keys 'tile_shape' and 'halo'.

    Returns:
        Filepath where the predictions have been saved.

    Raises:
        RuntimeError: If `tiling_window_params` is not a dictionary or misses 'tile_shape' or 'halo'.
    """
    # Validate the tiling parameters first, so that invalid arguments are rejected
    # before any folder is created and before the model is loaded.
    if tiling_window_params:
        if not isinstance(tiling_window_params, dict):
            raise RuntimeError("The tiling window parameters are expected to be provided as a dictionary of params.")

        if "tile_shape" not in tiling_window_params:
            raise RuntimeError("'tile_shape' parameter is missing from the provided parameters.")

        if "halo" not in tiling_window_params:
            raise RuntimeError("'halo' parameter is missing from the provided parameters.")

    if cache_embeddings:
        embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
        os.makedirs(embedding_folder, exist_ok=True)
    else:
        embedding_folder = None

    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs)

    # Get the AMG class: the tiled variant is used whenever tiling parameters are given.
    amg_class = TiledAutomaticMaskGenerator if tiling_window_params else AutomaticMaskGenerator
    amg = amg_class(predictor)
    amg_prefix = "amg"

    # where the predictions are saved
    prediction_folder = os.path.join(experiment_folder, amg_prefix, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # where the grid-search results are saved
    gs_result_folder = os.path.join(experiment_folder, amg_prefix, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    grid_search_values = instance_segmentation.default_grid_search_values_amg(
        iou_thresh_values=iou_thresh_values,
        stability_score_values=stability_score_values,
    )

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter=amg,
        grid_search_values=grid_search_values,
        val_image_paths=val_image_paths,
        val_gt_paths=val_gt_paths,
        test_image_paths=test_image_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
        experiment_folder=experiment_folder,
        tiling_window_params=tiling_window_params,
    )
    return prediction_folder

Run Segment Anything inference for multiple images using automatic mask generation (AMG).

Arguments:
  • checkpoint: The filepath to model checkpoints.
  • model_type: The segment anything model choice.
  • experiment_folder: The directory where the relevant files are saved.
  • val_image_paths: The list of filepaths of input images for grid-search.
  • val_gt_paths: The list of filepaths of corresponding labels for grid-search.
  • test_image_paths: The list of filepaths of input images for automatic instance segmentation.
  • iou_thresh_values: Optional choice of values for grid search of iou_thresh parameter.
  • stability_score_values: Optional choice of values for grid search of stability_score parameter.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class.
  • cache_embeddings: Whether to cache embeddings in experiment folder.
  • tiling_window_params: The parameters to decide whether to use tiling window operation for AMG.
Returns:

Filepath where the predictions have been saved.

def run_apg( checkpoint: Union[str, os.PathLike, NoneType], model_type: str, experiment_folder: Union[str, os.PathLike], val_image_paths: List[Union[str, os.PathLike]], val_gt_paths: List[Union[str, os.PathLike]], test_image_paths: List[Union[str, os.PathLike]], peft_kwargs: Optional[Dict] = None, cache_embeddings: bool = False, tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None) -> str:
def run_apg(
    checkpoint: Optional[Union[str, os.PathLike]],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    peft_kwargs: Optional[Dict] = None,
    cache_embeddings: bool = False,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> str:
    """Run Segment Anything inference for multiple images using automatic prompt generation (APG).

    Args:
        checkpoint: The filepath to model checkpoints. It is passed on to
            `get_predictor_and_decoder` and may be None.
        model_type: The segment anything model choice.
        experiment_folder: The directory where the relevant files are saved.
        val_image_paths: The list of filepaths of input images for grid-search.
        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        cache_embeddings: Whether to cache embeddings in experiment folder.
        tiling_window_params: The parameters for tiling window operation.
            Tiling is not supported for APG yet, so this must be None or empty.

    Returns:
        Filepath where the predictions have been saved.

    Raises:
        NotImplementedError: If `tiling_window_params` is given, since tiling is not supported for APG.
    """
    # Reject the unsupported tiling mode first, before any folder is created
    # and before the model is loaded.
    if tiling_window_params:
        raise NotImplementedError("Tiling window operation is not supported for automatic prompt generation.")

    if cache_embeddings:
        embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
        os.makedirs(embedding_folder, exist_ok=True)
    else:
        embedding_folder = None

    predictor, decoder = get_predictor_and_decoder(
        model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs,
    )

    segmenter = AutomaticPromptGenerator(predictor, decoder)
    seg_prefix = "apg"

    # where the predictions are saved
    prediction_folder = os.path.join(experiment_folder, seg_prefix, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # where the grid-search results are saved
    gs_result_folder = os.path.join(experiment_folder, seg_prefix, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    grid_search_values = instance_segmentation.default_grid_search_values_apg()

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter=segmenter,
        grid_search_values=grid_search_values,
        val_image_paths=val_image_paths,
        val_gt_paths=val_gt_paths,
        test_image_paths=test_image_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
        experiment_folder=experiment_folder,
        tiling_window_params=tiling_window_params,
    )
    return prediction_folder

Run Segment Anything inference for multiple images using automatic prompt generation (APG).

Arguments:
  • ...
Returns:

Filepath where the predictions have been saved.

def run_instance_segmentation_with_decoder( checkpoint: Union[str, os.PathLike], model_type: str, experiment_folder: Union[str, os.PathLike], val_image_paths: List[Union[str, os.PathLike]], val_gt_paths: List[Union[str, os.PathLike]], test_image_paths: List[Union[str, os.PathLike]], peft_kwargs: Optional[Dict] = None, cache_embeddings: bool = False, tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None) -> str:
def run_instance_segmentation_with_decoder(
    checkpoint: Union[str, os.PathLike],
    model_type: str,
    experiment_folder: Union[str, os.PathLike],
    val_image_paths: List[Union[str, os.PathLike]],
    val_gt_paths: List[Union[str, os.PathLike]],
    test_image_paths: List[Union[str, os.PathLike]],
    peft_kwargs: Optional[Dict] = None,
    cache_embeddings: bool = False,
    tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
) -> str:
    """Run Segment Anything inference for multiple images using additional automatic instance segmentation (AIS).

    Args:
        checkpoint: The filepath to model checkpoints.
        model_type: The segment anything model choice.
        experiment_folder: The directory where the relevant files are saved.
        val_image_paths: The list of filepaths of input images for grid-search.
        val_gt_paths: The list of filepaths of corresponding labels for grid-search.
        test_image_paths: The list of filepaths of input images for automatic instance segmentation.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        cache_embeddings: Whether to cache embeddings in experiment folder.
        tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
            If given (and not empty), it must be a dictionary with the keys 'tile_shape' and 'halo'.

    Returns:
        Filepath where the predictions have been saved.

    Raises:
        RuntimeError: If `tiling_window_params` is not a dictionary or misses 'tile_shape' or 'halo'.
    """
    # Validate the tiling parameters first, so that invalid arguments are rejected
    # before any folder is created and before the model is loaded.
    if tiling_window_params:
        if not isinstance(tiling_window_params, dict):
            raise RuntimeError("The tiling window parameters are expected to be provided as a dictionary of params.")

        if "tile_shape" not in tiling_window_params:
            raise RuntimeError("'tile_shape' parameter is missing from the provided parameters.")

        if "halo" not in tiling_window_params:
            raise RuntimeError("'halo' parameter is missing from the provided parameters.")

    if cache_embeddings:
        embedding_folder = os.path.join(experiment_folder, "embeddings")  # where the precomputed embeddings are saved
        os.makedirs(embedding_folder, exist_ok=True)
    else:
        embedding_folder = None

    predictor, decoder = get_predictor_and_decoder(
        model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs,
    )

    # Get the AIS class: the tiled variant is used whenever tiling parameters are given.
    ais_class = TiledInstanceSegmentationWithDecoder if tiling_window_params else InstanceSegmentationWithDecoder
    segmenter = ais_class(predictor, decoder)
    seg_prefix = "instance_segmentation_with_decoder"

    # where the predictions are saved
    prediction_folder = os.path.join(experiment_folder, seg_prefix, "inference")
    os.makedirs(prediction_folder, exist_ok=True)

    # where the grid-search results are saved
    gs_result_folder = os.path.join(experiment_folder, seg_prefix, "grid_search")
    os.makedirs(gs_result_folder, exist_ok=True)

    grid_search_values = instance_segmentation.default_grid_search_values_instance_segmentation_with_decoder()

    instance_segmentation.run_instance_segmentation_grid_search_and_inference(
        segmenter=segmenter,
        grid_search_values=grid_search_values,
        val_image_paths=val_image_paths,
        val_gt_paths=val_gt_paths,
        test_image_paths=test_image_paths,
        embedding_dir=embedding_folder,
        prediction_dir=prediction_folder,
        result_dir=gs_result_folder,
        experiment_folder=experiment_folder,
        tiling_window_params=tiling_window_params,
    )
    return prediction_folder

Run Segment Anything inference for multiple images using additional automatic instance segmentation (AIS).

Arguments:
  • checkpoint: The filepath to model checkpoints.
  • model_type: The segment anything model choice.
  • experiment_folder: The directory where the relevant files are saved.
  • val_image_paths: The list of filepaths of input images for grid-search.
  • val_gt_paths: The list of filepaths of corresponding labels for grid-search.
  • test_image_paths: The list of filepaths of input images for automatic instance segmentation.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class.
  • cache_embeddings: Whether to cache embeddings in experiment folder.
  • tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
Returns:

Filepath where the predictions have been saved.