micro_sam.inference

import os
from typing import Optional, Union

import numpy as np

import torch

from segment_anything import SamPredictor
import segment_anything.utils.amg as amg_utils
from segment_anything.utils.transforms import ResizeLongestSide

from . import util
from .instance_segmentation import mask_data_to_segmentation
from ._vendored import batched_mask_to_box


@torch.no_grad()
def batched_inference(
    predictor: SamPredictor,
    image: np.ndarray,
    batch_size: int,
    boxes: Optional[np.ndarray] = None,
    points: Optional[np.ndarray] = None,
    point_labels: Optional[np.ndarray] = None,
    multimasking: bool = False,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    return_instance_segmentation: bool = True,
    segmentation_ids: Optional[list] = None,
    reduce_multimasking: bool = True,
    logits_masks: Optional[torch.Tensor] = None,
    verbose_embeddings: bool = True,
):
    """Run batched inference for input prompts.

    Args:
        predictor: The segment anything predictor.
        image: The input image.
        batch_size: The batch size to use for inference.
        boxes: The box prompts. Array of shape N_PROMPTS x 4.
            The bounding boxes are represented by [MIN_X, MIN_Y, MAX_X, MAX_Y].
        points: The point prompt coordinates. Array of shape N_PROMPTS x 1 x 2.
            The points are represented by their coordinates [X, Y], which are given
            in the last dimension.
        point_labels: The point prompt labels. Array of shape N_PROMPTS x 1.
            The labels are either 0 (negative prompt) or 1 (positive prompt).
        multimasking: Whether to predict with 3 masks or with a single mask.
        embedding_path: Cache path for the image embeddings.
        return_instance_segmentation: Whether to return an instance segmentation
            or the individual mask data.
        segmentation_ids: Fixed segmentation ids to assign to the masks
            derived from the prompts.
        reduce_multimasking: Whether to choose the most likely masks
            (with the highest predicted IoUs) from multimasking.
        logits_masks: The mask logits from a previous segmentation, which can be
            passed as additional prompts. Array of shape N_PROMPTS x 1 x 256 x 256.
        verbose_embeddings: Whether to show progress outputs of computing image embeddings.

    Returns:
        The predicted segmentation masks.
    """
    if multimasking and (segmentation_ids is not None) and (not return_instance_segmentation):
        raise NotImplementedError

    if (points is None) != (point_labels is None):
        raise ValueError(
            "If you have point prompts, both `points` and `point_labels` have to be passed, "
            "but you passed only one of them."
        )

    have_points = points is not None
    have_boxes = boxes is not None
    have_logits = logits_masks is not None
    if (not have_points) and (not have_boxes):
        raise ValueError("Point and/or box prompts have to be passed, but you passed neither.")

    if have_points and (len(point_labels) != len(points)):
        raise ValueError(
            "The number of point coordinates and labels does not match: "
            f"{len(point_labels)} != {len(points)}"
        )

    if (have_points and have_boxes) and (len(points) != len(boxes)):
        raise ValueError(
            "The number of point and box prompts does not match: "
            f"{len(points)} != {len(boxes)}"
        )

    if have_logits:
        if have_points and (len(logits_masks) != len(point_labels)):
            raise ValueError(
                "The number of points and logits does not match: "
                f"{len(points)} != {len(logits_masks)}"
            )
        elif have_boxes and (len(logits_masks) != len(boxes)):
            raise ValueError(
                "The number of boxes and logits does not match: "
                f"{len(boxes)} != {len(logits_masks)}"
            )

    n_prompts = boxes.shape[0] if have_boxes else points.shape[0]

    if (segmentation_ids is not None) and (len(segmentation_ids) != n_prompts):
        raise ValueError(
            "The number of segmentation ids and prompts does not match: "
            f"{len(segmentation_ids)} != {n_prompts}"
        )

    # Compute the image embeddings.
    image_embeddings = util.precompute_image_embeddings(
        predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
    )
    util.set_precomputed(predictor, image_embeddings)

    # Determine the number of batches.
    n_batches = int(np.ceil(float(n_prompts) / batch_size))

    # Preprocess the prompts.
    device = predictor.device
    transform_function = ResizeLongestSide(1024)
    image_shape = predictor.original_size
    if have_boxes:
        boxes = transform_function.apply_boxes(boxes, image_shape)
        boxes = torch.tensor(boxes, dtype=torch.float32).to(device)
    if have_points:
        points = transform_function.apply_coords(points, image_shape)
        points = torch.tensor(points, dtype=torch.float32).to(device)
        point_labels = torch.tensor(point_labels, dtype=torch.float32).to(device)

    masks = amg_utils.MaskData()
    for batch_idx in range(n_batches):
        batch_start = batch_idx * batch_size
        batch_stop = min((batch_idx + 1) * batch_size, n_prompts)

        batch_boxes = boxes[batch_start:batch_stop] if have_boxes else None
        batch_points = points[batch_start:batch_stop] if have_points else None
        batch_labels = point_labels[batch_start:batch_stop] if have_points else None
        batch_logits = logits_masks[batch_start:batch_stop] if have_logits else None

        batch_masks, batch_ious, batch_logits = predictor.predict_torch(
            point_coords=batch_points,
            point_labels=batch_labels,
            boxes=batch_boxes,
            mask_input=batch_logits,
            multimask_output=multimasking
        )

        # If we use multimasking and want to reduce it to a single mask per prompt,
        # then we select the most likely mask (according to the predicted IoU) here.
        if reduce_multimasking and multimasking:
            _, max_index = batch_ious.max(axis=1)
            batch_masks = torch.cat([batch_masks[i, max_id][None] for i, max_id in enumerate(max_index)]).unsqueeze(1)
            batch_ious = torch.cat([batch_ious[i, max_id][None] for i, max_id in enumerate(max_index)]).unsqueeze(1)
            batch_logits = torch.cat([batch_logits[i, max_id][None] for i, max_id in enumerate(max_index)]).unsqueeze(1)

        batch_data = amg_utils.MaskData(masks=batch_masks.flatten(0, 1), iou_preds=batch_ious.flatten(0, 1))
        batch_data["masks"] = (batch_data["masks"] > predictor.model.mask_threshold).type(torch.bool)
        batch_data["boxes"] = batched_mask_to_box(batch_data["masks"])
        batch_data["logits"] = batch_logits

        masks.cat(batch_data)

    # Convert the mask data to records.
    masks = [
        {
            "segmentation": masks["masks"][idx],
            "area": masks["masks"][idx].sum(),
            "bbox": amg_utils.box_xyxy_to_xywh(masks["boxes"][idx]).tolist(),
            "predicted_iou": masks["iou_preds"][idx].item(),
            "seg_id": idx + 1 if segmentation_ids is None else int(segmentation_ids[idx]),
            "logits": masks["logits"][idx]
        }
        for idx in range(len(masks["masks"]))
    ]

    if return_instance_segmentation:
        masks = mask_data_to_segmentation(masks, with_background=False, min_object_size=0)

    return masks
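
As a quick reference for the prompt formats documented in the function above (boxes of shape N_PROMPTS x 4, points of shape N_PROMPTS x 1 x 2 with matching labels of shape N_PROMPTS x 1), here is a minimal sketch of constructing such arrays; the coordinate values are placeholders for illustration only:

    import numpy as np

    # Two box prompts, one per object, each given as [MIN_X, MIN_Y, MAX_X, MAX_Y].
    boxes = np.array([
        [10, 12, 58, 60],
        [70, 80, 120, 135],
    ], dtype="float32")  # shape: N_PROMPTS x 4

    # Alternatively, one positive point prompt per object, given as [X, Y].
    points = np.array([
        [[34, 36]],
        [[95, 107]],
    ], dtype="float32")  # shape: N_PROMPTS x 1 x 2
    point_labels = np.ones((2, 1), dtype="int32")  # shape: N_PROMPTS x 1, 1 = positive prompt
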
@torch.no_grad()
def batched_inference( predictor: segment_anything.predictor.SamPredictor, image: numpy.ndarray, batch_size: int, boxes: Optional[numpy.ndarray] = None, points: Optional[numpy.ndarray] = None, point_labels: Optional[numpy.ndarray] = None, multimasking: bool = False, embedding_path: Union[str, os.PathLike, NoneType] = None, return_instance_segmentation: bool = True, segmentation_ids: Optional[list] = None, reduce_multimasking: bool = True, logits_masks: Optional[torch.Tensor] = None, verbose_embeddings: bool = True):

Run batched inference for input prompts.

Arguments:
  • predictor: The segment anything predictor.
  • image: The input image.
  • batch_size: The batch size to use for inference.
  • boxes: The box prompts. Array of shape N_PROMPTS x 4. The bounding boxes are represented by [MIN_X, MIN_Y, MAX_X, MAX_Y].
  • points: The point prompt coordinates. Array of shape N_PROMPTS x 1 x 2. The points are represented by their coordinates [X, Y], which are given in the last dimension.
  • point_labels: The point prompt labels. Array of shape N_PROMPTS x 1. The labels are either 0 (negative prompt) or 1 (positive prompt).
  • multimasking: Whether to predict with 3 masks or with a single mask.
  • embedding_path: Cache path for the image embeddings.
  • return_instance_segmentation: Whether to return an instance segmentation or the individual mask data.
  • segmentation_ids: Fixed segmentation ids to assign to the masks derived from the prompts.
  • reduce_multimasking: Whether to choose the most likely masks (with the highest predicted IoUs) from multimasking.
  • logits_masks: The mask logits from a previous segmentation, which can be passed as additional prompts. Array of shape N_PROMPTS x 1 x 256 x 256.
  • verbose_embeddings: Whether to show progress outputs of computing image embeddings.
Returns:
  The predicted segmentation masks.
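
A minimal usage sketch is shown below. It assumes the predictor is created with micro_sam.util.get_sam_model; the file name, model type, and box coordinates are placeholders:

    import numpy as np
    import imageio.v3 as imageio

    from micro_sam.util import get_sam_model
    from micro_sam.inference import batched_inference

    # Load the input image and a Segment Anything predictor.
    image = imageio.imread("example_image.tif")  # placeholder path
    predictor = get_sam_model(model_type="vit_b")  # example model choice

    # One box prompt per object to segment, in [MIN_X, MIN_Y, MAX_X, MAX_Y] format.
    boxes = np.array([
        [10, 12, 58, 60],
        [70, 80, 120, 135],
    ], dtype="float32")

    # Run batched inference. With return_instance_segmentation=True (the default)
    # the result is an instance segmentation with one id per prompt.
    segmentation = batched_inference(predictor, image, batch_size=16, boxes=boxes)

With return_instance_segmentation=False the function instead returns one record per prompt, with the keys "segmentation", "area", "bbox", "predicted_iou", "seg_id", and "logits", as built at the end of the source above.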