micro_sam.inference

import os
from typing import Optional, Union

import numpy as np

import torch

from segment_anything import SamPredictor
import segment_anything.utils.amg as amg_utils
from segment_anything.utils.transforms import ResizeLongestSide

from . import util
from .instance_segmentation import mask_data_to_segmentation
from ._vendored import batched_mask_to_box


@torch.no_grad()
def batched_inference(
    predictor: SamPredictor,
    image: np.ndarray,
    batch_size: int,
    boxes: Optional[np.ndarray] = None,
    points: Optional[np.ndarray] = None,
    point_labels: Optional[np.ndarray] = None,
    multimasking: bool = False,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    return_instance_segmentation: bool = True,
    segmentation_ids: Optional[list] = None,
    reduce_multimasking: bool = True,
    logits_masks: Optional[torch.Tensor] = None,
    verbose_embeddings: bool = True,
):
 33    """Run batched inference for input prompts.
 34
 35    Args:
 36        predictor: The segment anything predictor.
 37        image: The input image.
 38        batch_size: The batch size to use for inference.
 39        boxes: The box prompts. Array of shape N_PROMPTS x 4.
 40            The bounding boxes are represented by [MIN_X, MIN_Y, MAX_X, MAX_Y].
 41        points: The point prompt coordinates. Array of shape N_PROMPTS x 1 x 2.
 42            The points are represented by their coordinates [X, Y], which are given in the last dimension.
 43        point_labels: The point prompt labels. Array of shape N_PROMPTS x 1.
 44            The labels are either 0 (negative prompt) or 1 (positive prompt).
 45        multimasking: Whether to predict with 3 or 1 mask.
 46        embedding_path: Cache path for the image embeddings.
 47        return_instance_segmentation: Whether to return a instance segmentation
 48            or the individual mask data.
 49        segmentation_ids: Fixed segmentation ids to assign to the masks
 50            derived from the prompts.
 51        reduce_multimasking: Whether to choose the most likely masks with
 52            highest ious from multimasking
 53        logits_masks: The logits masks. Array of shape N_PROMPTS x 1 x 256 x 256.
 54            Whether to use the logits masks from previous segmentation.
 55        verbose_embeddings: Whether to show progress outputs of computing image embeddings.
 56
 57    Returns:
 58        The predicted segmentation masks.
 59    """
    if multimasking and (segmentation_ids is not None) and (not return_instance_segmentation):
        raise NotImplementedError

    if (points is None) != (point_labels is None):
        raise ValueError(
            "If you have point prompts, both `points` and `point_labels` have to be passed, "
            "but you passed only one of them."
        )

    have_points = points is not None
    have_boxes = boxes is not None
    have_logits = logits_masks is not None
    if (not have_points) and (not have_boxes):
        raise ValueError("Point and/or box prompts have to be passed, you passed neither.")

    if have_points and (len(point_labels) != len(points)):
        raise ValueError(
            "The number of point coordinates and labels does not match: "
            f"{len(point_labels)} != {len(points)}"
        )

    if (have_points and have_boxes) and (len(points) != len(boxes)):
        raise ValueError(
            "The number of point and box prompts does not match: "
            f"{len(points)} != {len(boxes)}"
        )

    if have_logits:
        if have_points and (len(logits_masks) != len(point_labels)):
            raise ValueError(
                "The number of points and logits does not match: "
                f"{len(points)} != {len(logits_masks)}"
            )
        elif have_boxes and (len(logits_masks) != len(boxes)):
            raise ValueError(
                "The number of boxes and logits does not match: "
                f"{len(boxes)} != {len(logits_masks)}"
            )

    n_prompts = boxes.shape[0] if have_boxes else points.shape[0]

    if (segmentation_ids is not None) and (len(segmentation_ids) != n_prompts):
        raise ValueError(
            "The number of segmentation ids and prompts does not match: "
            f"{len(segmentation_ids)} != {n_prompts}"
        )

    # Compute the image embeddings.
    image_embeddings = util.precompute_image_embeddings(
        predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
    )
    util.set_precomputed(predictor, image_embeddings)

    # Determine the number of batches.
    n_batches = int(np.ceil(float(n_prompts) / batch_size))

    # Preprocess the prompts.
    device = predictor.device
    transform_function = ResizeLongestSide(1024)
    image_shape = predictor.original_size
    if have_boxes:
        boxes = transform_function.apply_boxes(boxes, image_shape)
        boxes = torch.tensor(boxes, dtype=torch.float32).to(device)
    if have_points:
        points = transform_function.apply_coords(points, image_shape)
        points = torch.tensor(points, dtype=torch.float32).to(device)
        point_labels = torch.tensor(point_labels, dtype=torch.float32).to(device)

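    # Run the prediction over all prompts batch by batch and accumulate the mask data.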
    masks = amg_utils.MaskData()
    for batch_idx in range(n_batches):
        batch_start = batch_idx * batch_size
        batch_stop = min((batch_idx + 1) * batch_size, n_prompts)

        batch_boxes = boxes[batch_start:batch_stop] if have_boxes else None
        batch_points = points[batch_start:batch_stop] if have_points else None
        batch_labels = point_labels[batch_start:batch_stop] if have_points else None
        batch_logits = logits_masks[batch_start:batch_stop] if have_logits else None

        batch_masks, batch_ious, batch_logits = predictor.predict_torch(
            point_coords=batch_points,
            point_labels=batch_labels,
            boxes=batch_boxes,
            mask_input=batch_logits,
            multimask_output=multimasking
        )

        # If we expect to reduce the masks from multimasking and use multi-masking,
        # then we need to select the most likely mask (according to the predicted IOU) here.
        if reduce_multimasking and multimasking:
            _, max_index = batch_ious.max(axis=1)
            batch_masks = torch.cat([batch_masks[i, max_id][None] for i, max_id in enumerate(max_index)]).unsqueeze(1)
            batch_ious = torch.cat([batch_ious[i, max_id][None] for i, max_id in enumerate(max_index)]).unsqueeze(1)
            batch_logits = torch.cat([batch_logits[i, max_id][None] for i, max_id in enumerate(max_index)]).unsqueeze(1)

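        # Collect the batch results: binarize the mask logits and compute bounding boxes for the masks.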
        batch_data = amg_utils.MaskData(masks=batch_masks.flatten(0, 1), iou_preds=batch_ious.flatten(0, 1))
        batch_data["masks"] = (batch_data["masks"] > predictor.model.mask_threshold).type(torch.bool)
        batch_data["boxes"] = batched_mask_to_box(batch_data["masks"])
        batch_data["logits"] = batch_logits

        masks.cat(batch_data)

    # Mask data to records.
    masks = [
        {
            "segmentation": masks["masks"][idx],
            "area": masks["masks"][idx].sum(),
            "bbox": amg_utils.box_xyxy_to_xywh(masks["boxes"][idx]).tolist(),
            "predicted_iou": masks["iou_preds"][idx].item(),
            "seg_id": idx + 1 if segmentation_ids is None else int(segmentation_ids[idx]),
            "logits": masks["logits"][idx]
        }
        for idx in range(len(masks["masks"]))
    ]

    if return_instance_segmentation:
        masks = mask_data_to_segmentation(masks, with_background=False, min_object_size=0)

    return masks

@torch.no_grad()
def batched_inference(
    predictor: segment_anything.predictor.SamPredictor,
    image: numpy.ndarray,
    batch_size: int,
    boxes: Optional[numpy.ndarray] = None,
    points: Optional[numpy.ndarray] = None,
    point_labels: Optional[numpy.ndarray] = None,
    multimasking: bool = False,
    embedding_path: Union[str, os.PathLike, NoneType] = None,
    return_instance_segmentation: bool = True,
    segmentation_ids: Optional[list] = None,
    reduce_multimasking: bool = True,
    logits_masks: Optional[torch.Tensor] = None,
    verbose_embeddings: bool = True,
):

Run batched inference for input prompts.

Arguments:
  • predictor: The segment anything predictor.
  • image: The input image.
  • batch_size: The batch size to use for inference.
  • boxes: The box prompts. Array of shape N_PROMPTS x 4. The bounding boxes are represented by [MIN_X, MIN_Y, MAX_X, MAX_Y].
  • points: The point prompt coordinates. Array of shape N_PROMPTS x 1 x 2. The points are represented by their coordinates [X, Y], which are given in the last dimension.
  • point_labels: The point prompt labels. Array of shape N_PROMPTS x 1. The labels are either 0 (negative prompt) or 1 (positive prompt).
  • multimasking: Whether to predict with 3 or 1 mask.
  • embedding_path: Cache path for the image embeddings.
  • return_instance_segmentation: Whether to return an instance segmentation or the individual mask data.
  • segmentation_ids: Fixed segmentation ids to assign to the masks derived from the prompts.
  • reduce_multimasking: Whether to reduce the masks from multimasking to the most likely mask, i.e. the one with the highest predicted IOU.
  • logits_masks: The logits masks from a previous segmentation. Array of shape N_PROMPTS x 1 x 256 x 256. If passed, they are used as mask inputs to refine the predictions.
  • verbose_embeddings: Whether to show progress outputs of computing image embeddings.
Returns:
  The predicted segmentation masks.
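
A minimal usage sketch. It assumes that a SamPredictor is obtained via micro_sam.util.get_sam_model (not shown in this module) and uses made-up image data and prompts; adjust the model type, image, and prompts to your data.

import numpy as np

from micro_sam.inference import batched_inference
from micro_sam.util import get_sam_model

# Hypothetical inputs: a small 2D image and two box prompts,
# given as [MIN_X, MIN_Y, MAX_X, MAX_Y] (one row per object).
image = np.random.randint(0, 255, size=(512, 512), dtype="uint8")
boxes = np.array([
    [100, 120, 180, 200],
    [300, 310, 380, 400],
])

# Assumed helper for loading a predictor; replace with however you create your SamPredictor.
predictor = get_sam_model(model_type="vit_b")

# Box prompts: returns a label image, since return_instance_segmentation=True by default.
segmentation = batched_inference(predictor, image, batch_size=2, boxes=boxes)

# Point prompts: coordinates have shape N_PROMPTS x 1 x 2 ([X, Y])
# and labels have shape N_PROMPTS x 1 (1 = positive prompt).
points = np.array([[[140, 160]], [[340, 355]]], dtype="float32")
point_labels = np.ones((2, 1), dtype="int32")
mask_data = batched_inference(
    predictor, image, batch_size=2,
    points=points, point_labels=point_labels,
    return_instance_segmentation=False,
)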