micro_sam.inference
import os
from typing import Optional, Union

import numpy as np

import torch

from segment_anything import SamPredictor
import segment_anything.utils.amg as amg_utils
from segment_anything.utils.transforms import ResizeLongestSide

from . import util
from .instance_segmentation import mask_data_to_segmentation
from ._vendored import batched_mask_to_box


@torch.no_grad()
def batched_inference(
    predictor: SamPredictor,
    image: np.ndarray,
    batch_size: int,
    boxes: Optional[np.ndarray] = None,
    points: Optional[np.ndarray] = None,
    point_labels: Optional[np.ndarray] = None,
    multimasking: bool = False,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    return_instance_segmentation: bool = True,
    segmentation_ids: Optional[list] = None,
    reduce_multimasking: bool = True,
    logits_masks: Optional[torch.Tensor] = None,
    verbose_embeddings: bool = True,
):
    """Run batched inference for input prompts.

    Args:
        predictor: The segment anything predictor.
        image: The input image.
        batch_size: The batch size to use for inference. Must be a positive integer.
        boxes: The box prompts. Array of shape N_PROMPTS x 4.
            The bounding boxes are represented by [MIN_X, MIN_Y, MAX_X, MAX_Y].
        points: The point prompt coordinates. Array of shape N_PROMPTS x 1 x 2.
            The points are represented by their coordinates [X, Y], which are given
            in the last dimension.
        point_labels: The point prompt labels. Array of shape N_PROMPTS x 1.
            The labels are either 0 (negative prompt) or 1 (positive prompt).
        multimasking: Whether to predict multiple masks per prompt instead of a single mask.
        embedding_path: Cache path for the image embeddings.
        return_instance_segmentation: Whether to return an instance segmentation
            or the individual mask data.
        segmentation_ids: Fixed segmentation ids to assign to the masks
            derived from the prompts. Must contain one id per prompt.
        reduce_multimasking: Whether to keep only the most likely mask per prompt
            (the one with the highest predicted IoU) when multimasking is used.
        logits_masks: The logits masks. Array of shape N_PROMPTS x 1 x 256 x 256.
            Logits from a previous segmentation, passed to the model as mask prompts.
        verbose_embeddings: Whether to show progress outputs of computing image embeddings.

    Returns:
        If `return_instance_segmentation` is True, the result of `mask_data_to_segmentation`
        for the predicted masks. Otherwise a list with one record (dict) per predicted mask,
        with the keys "segmentation", "area", "bbox" (XYWH), "predicted_iou", "seg_id" and
        "logits" (the mask logits with a leading singleton dimension).

    Raises:
        NotImplementedError: If `segmentation_ids` are combined with multimasking and either
            no instance segmentation is requested or the masks are not reduced to one per prompt.
        ValueError: If no prompts are given, the prompt inputs are inconsistent with each
            other, or `batch_size` is not positive.
    """
    if multimasking and (segmentation_ids is not None) and (not return_instance_segmentation):
        raise NotImplementedError
    # Without reduction every prompt yields several masks, so a per-prompt id list cannot be
    # mapped onto them (this previously failed later with an IndexError).
    if multimasking and (not reduce_multimasking) and (segmentation_ids is not None):
        raise NotImplementedError(
            "Passing `segmentation_ids` together with `multimasking` requires `reduce_multimasking=True`."
        )

    if (points is None) != (point_labels is None):
        raise ValueError(
            "If you have point prompts both `points` and `point_labels` have to be passed, "
            "but you passed only one of them."
        )

    have_points = points is not None
    have_boxes = boxes is not None
    have_logits = logits_masks is not None
    if (not have_points) and (not have_boxes):
        raise ValueError("Point and/or box prompts have to be passed, you passed neither.")

    if have_points and (len(point_labels) != len(points)):
        raise ValueError(
            "The number of point coordinates and labels does not match: "
            f"{len(point_labels)} != {len(points)}"
        )

    if (have_points and have_boxes) and (len(points) != len(boxes)):
        raise ValueError(
            "The number of point and box prompts does not match: "
            f"{len(points)} != {len(boxes)}"
        )

    if have_logits:
        if have_points and (len(logits_masks) != len(point_labels)):
            raise ValueError(
                "The number of point and logits does not match: "
                f"{len(points)} != {len(logits_masks)}"
            )
        elif have_boxes and (len(logits_masks) != len(boxes)):
            raise ValueError(
                "The number of boxes and logits does not match: "
                f"{len(boxes)} != {len(logits_masks)}"
            )

    n_prompts = boxes.shape[0] if have_boxes else points.shape[0]

    if (segmentation_ids is not None) and (len(segmentation_ids) != n_prompts):
        raise ValueError(
            "The number of segmentation ids and prompts does not match: "
            f"{len(segmentation_ids)} != {n_prompts}"
        )

    # Validate before the (potentially expensive) embedding computation.
    if batch_size < 1:
        raise ValueError(f"The batch size has to be a positive integer, got {batch_size}.")

    # Compute the image embeddings.
    image_embeddings = util.precompute_image_embeddings(
        predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
    )
    util.set_precomputed(predictor, image_embeddings)

    # Preprocess the prompts: rescale coordinates to the model input size and move to the model device.
    device = predictor.device
    transform_function = ResizeLongestSide(1024)
    image_shape = predictor.original_size
    if have_boxes:
        boxes = transform_function.apply_boxes(boxes, image_shape)
        boxes = torch.tensor(boxes, dtype=torch.float32).to(device)
    if have_points:
        points = transform_function.apply_coords(points, image_shape)
        points = torch.tensor(points, dtype=torch.float32).to(device)
        point_labels = torch.tensor(point_labels, dtype=torch.float32).to(device)
    if have_logits:
        # The mask prompts have to live on the same device as the model, like the other prompts.
        logits_masks = logits_masks.to(device)

    masks = amg_utils.MaskData()
    for batch_start in range(0, n_prompts, batch_size):
        batch_stop = min(batch_start + batch_size, n_prompts)

        batch_boxes = boxes[batch_start:batch_stop] if have_boxes else None
        batch_points = points[batch_start:batch_stop] if have_points else None
        batch_labels = point_labels[batch_start:batch_stop] if have_points else None
        batch_mask_input = logits_masks[batch_start:batch_stop] if have_logits else None

        batch_masks, batch_ious, batch_logits = predictor.predict_torch(
            point_coords=batch_points,
            point_labels=batch_labels,
            boxes=batch_boxes,
            mask_input=batch_mask_input,
            multimask_output=multimasking,
        )

        # If we expect to reduce the masks from multimasking and use multi-masking,
        # then we need to select the most likely mask (according to the predicted IOU) here.
        if reduce_multimasking and multimasking:
            _, max_index = batch_ious.max(dim=1)
            prompt_index = torch.arange(len(max_index), device=max_index.device)
            # Pick mask `max_index[i]` for prompt `i`, keeping a singleton mask dimension.
            batch_masks = batch_masks[prompt_index, max_index].unsqueeze(1)
            batch_ious = batch_ious[prompt_index, max_index].unsqueeze(1)
            batch_logits = batch_logits[prompt_index, max_index].unsqueeze(1)

        batch_data = amg_utils.MaskData(masks=batch_masks.flatten(0, 1), iou_preds=batch_ious.flatten(0, 1))
        batch_data["masks"] = (batch_data["masks"] > predictor.model.mask_threshold).type(torch.bool)
        batch_data["boxes"] = batched_mask_to_box(batch_data["masks"])
        # Flatten the logits exactly like the masks (one entry per mask) so all fields stay
        # aligned, keeping a singleton channel so they can be passed back as `logits_masks`.
        batch_data["logits"] = batch_logits.flatten(0, 1).unsqueeze(1)

        masks.cat(batch_data)

    # Mask data to records.
    masks = [
        {
            "segmentation": masks["masks"][idx],
            "area": masks["masks"][idx].sum(),
            "bbox": amg_utils.box_xyxy_to_xywh(masks["boxes"][idx]).tolist(),
            "predicted_iou": masks["iou_preds"][idx].item(),
            "seg_id": idx + 1 if segmentation_ids is None else int(segmentation_ids[idx]),
            "logits": masks["logits"][idx]
        }
        for idx in range(len(masks["masks"]))
    ]

    if return_instance_segmentation:
        masks = mask_data_to_segmentation(masks, with_background=False, min_object_size=0)

    return masks
@torch.no_grad()
def
batched_inference( predictor: segment_anything.predictor.SamPredictor, image: numpy.ndarray, batch_size: int, boxes: Optional[numpy.ndarray] = None, points: Optional[numpy.ndarray] = None, point_labels: Optional[numpy.ndarray] = None, multimasking: bool = False, embedding_path: Union[str, os.PathLike, NoneType] = None, return_instance_segmentation: bool = True, segmentation_ids: Optional[list] = None, reduce_multimasking: bool = True, logits_masks: Optional[torch.Tensor] = None, verbose_embeddings: bool = True):
@torch.no_grad()
def batched_inference(
    predictor: SamPredictor,
    image: np.ndarray,
    batch_size: int,
    boxes: Optional[np.ndarray] = None,
    points: Optional[np.ndarray] = None,
    point_labels: Optional[np.ndarray] = None,
    multimasking: bool = False,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    return_instance_segmentation: bool = True,
    segmentation_ids: Optional[list] = None,
    reduce_multimasking: bool = True,
    logits_masks: Optional[torch.Tensor] = None,
    verbose_embeddings: bool = True,
):
    """Run batched inference for input prompts.

    Args:
        predictor: The segment anything predictor.
        image: The input image.
        batch_size: The batch size to use for inference. Must be a positive integer.
        boxes: The box prompts. Array of shape N_PROMPTS x 4.
            The bounding boxes are represented by [MIN_X, MIN_Y, MAX_X, MAX_Y].
        points: The point prompt coordinates. Array of shape N_PROMPTS x 1 x 2.
            The points are represented by their coordinates [X, Y], which are given
            in the last dimension.
        point_labels: The point prompt labels. Array of shape N_PROMPTS x 1.
            The labels are either 0 (negative prompt) or 1 (positive prompt).
        multimasking: Whether to predict multiple masks per prompt instead of a single mask.
        embedding_path: Cache path for the image embeddings.
        return_instance_segmentation: Whether to return an instance segmentation
            or the individual mask data.
        segmentation_ids: Fixed segmentation ids to assign to the masks
            derived from the prompts. Must contain one id per prompt.
        reduce_multimasking: Whether to keep only the most likely mask per prompt
            (the one with the highest predicted IoU) when multimasking is used.
        logits_masks: The logits masks. Array of shape N_PROMPTS x 1 x 256 x 256.
            Logits from a previous segmentation, passed to the model as mask prompts.
        verbose_embeddings: Whether to show progress outputs of computing image embeddings.

    Returns:
        If `return_instance_segmentation` is True, the result of `mask_data_to_segmentation`
        for the predicted masks. Otherwise a list with one record (dict) per predicted mask,
        with the keys "segmentation", "area", "bbox" (XYWH), "predicted_iou", "seg_id" and
        "logits" (the mask logits with a leading singleton dimension).

    Raises:
        NotImplementedError: If `segmentation_ids` are combined with multimasking and either
            no instance segmentation is requested or the masks are not reduced to one per prompt.
        ValueError: If no prompts are given, the prompt inputs are inconsistent with each
            other, or `batch_size` is not positive.
    """
    if multimasking and (segmentation_ids is not None) and (not return_instance_segmentation):
        raise NotImplementedError
    # Without reduction every prompt yields several masks, so a per-prompt id list cannot be
    # mapped onto them (this previously failed later with an IndexError).
    if multimasking and (not reduce_multimasking) and (segmentation_ids is not None):
        raise NotImplementedError(
            "Passing `segmentation_ids` together with `multimasking` requires `reduce_multimasking=True`."
        )

    if (points is None) != (point_labels is None):
        raise ValueError(
            "If you have point prompts both `points` and `point_labels` have to be passed, "
            "but you passed only one of them."
        )

    have_points = points is not None
    have_boxes = boxes is not None
    have_logits = logits_masks is not None
    if (not have_points) and (not have_boxes):
        raise ValueError("Point and/or box prompts have to be passed, you passed neither.")

    if have_points and (len(point_labels) != len(points)):
        raise ValueError(
            "The number of point coordinates and labels does not match: "
            f"{len(point_labels)} != {len(points)}"
        )

    if (have_points and have_boxes) and (len(points) != len(boxes)):
        raise ValueError(
            "The number of point and box prompts does not match: "
            f"{len(points)} != {len(boxes)}"
        )

    if have_logits:
        if have_points and (len(logits_masks) != len(point_labels)):
            raise ValueError(
                "The number of point and logits does not match: "
                f"{len(points)} != {len(logits_masks)}"
            )
        elif have_boxes and (len(logits_masks) != len(boxes)):
            raise ValueError(
                "The number of boxes and logits does not match: "
                f"{len(boxes)} != {len(logits_masks)}"
            )

    n_prompts = boxes.shape[0] if have_boxes else points.shape[0]

    if (segmentation_ids is not None) and (len(segmentation_ids) != n_prompts):
        raise ValueError(
            "The number of segmentation ids and prompts does not match: "
            f"{len(segmentation_ids)} != {n_prompts}"
        )

    # Validate before the (potentially expensive) embedding computation.
    if batch_size < 1:
        raise ValueError(f"The batch size has to be a positive integer, got {batch_size}.")

    # Compute the image embeddings.
    image_embeddings = util.precompute_image_embeddings(
        predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
    )
    util.set_precomputed(predictor, image_embeddings)

    # Preprocess the prompts: rescale coordinates to the model input size and move to the model device.
    device = predictor.device
    transform_function = ResizeLongestSide(1024)
    image_shape = predictor.original_size
    if have_boxes:
        boxes = transform_function.apply_boxes(boxes, image_shape)
        boxes = torch.tensor(boxes, dtype=torch.float32).to(device)
    if have_points:
        points = transform_function.apply_coords(points, image_shape)
        points = torch.tensor(points, dtype=torch.float32).to(device)
        point_labels = torch.tensor(point_labels, dtype=torch.float32).to(device)
    if have_logits:
        # The mask prompts have to live on the same device as the model, like the other prompts.
        logits_masks = logits_masks.to(device)

    masks = amg_utils.MaskData()
    for batch_start in range(0, n_prompts, batch_size):
        batch_stop = min(batch_start + batch_size, n_prompts)

        batch_boxes = boxes[batch_start:batch_stop] if have_boxes else None
        batch_points = points[batch_start:batch_stop] if have_points else None
        batch_labels = point_labels[batch_start:batch_stop] if have_points else None
        batch_mask_input = logits_masks[batch_start:batch_stop] if have_logits else None

        batch_masks, batch_ious, batch_logits = predictor.predict_torch(
            point_coords=batch_points,
            point_labels=batch_labels,
            boxes=batch_boxes,
            mask_input=batch_mask_input,
            multimask_output=multimasking,
        )

        # If we expect to reduce the masks from multimasking and use multi-masking,
        # then we need to select the most likely mask (according to the predicted IOU) here.
        if reduce_multimasking and multimasking:
            _, max_index = batch_ious.max(dim=1)
            prompt_index = torch.arange(len(max_index), device=max_index.device)
            # Pick mask `max_index[i]` for prompt `i`, keeping a singleton mask dimension.
            batch_masks = batch_masks[prompt_index, max_index].unsqueeze(1)
            batch_ious = batch_ious[prompt_index, max_index].unsqueeze(1)
            batch_logits = batch_logits[prompt_index, max_index].unsqueeze(1)

        batch_data = amg_utils.MaskData(masks=batch_masks.flatten(0, 1), iou_preds=batch_ious.flatten(0, 1))
        batch_data["masks"] = (batch_data["masks"] > predictor.model.mask_threshold).type(torch.bool)
        batch_data["boxes"] = batched_mask_to_box(batch_data["masks"])
        # Flatten the logits exactly like the masks (one entry per mask) so all fields stay
        # aligned, keeping a singleton channel so they can be passed back as `logits_masks`.
        batch_data["logits"] = batch_logits.flatten(0, 1).unsqueeze(1)

        masks.cat(batch_data)

    # Mask data to records.
    masks = [
        {
            "segmentation": masks["masks"][idx],
            "area": masks["masks"][idx].sum(),
            "bbox": amg_utils.box_xyxy_to_xywh(masks["boxes"][idx]).tolist(),
            "predicted_iou": masks["iou_preds"][idx].item(),
            "seg_id": idx + 1 if segmentation_ids is None else int(segmentation_ids[idx]),
            "logits": masks["logits"][idx]
        }
        for idx in range(len(masks["masks"]))
    ]

    if return_instance_segmentation:
        masks = mask_data_to_segmentation(masks, with_background=False, min_object_size=0)

    return masks
Run batched inference for input prompts.
Arguments:
- predictor: The segment anything predictor.
- image: The input image.
- batch_size: The batch size to use for inference.
- boxes: The box prompts. Array of shape N_PROMPTS x 4. The bounding boxes are represented by [MIN_X, MIN_Y, MAX_X, MAX_Y].
- points: The point prompt coordinates. Array of shape N_PROMPTS x 1 x 2. The points are represented by their coordinates [X, Y], which are given in the last dimension.
- point_labels: The point prompt labels. Array of shape N_PROMPTS x 1. The labels are either 0 (negative prompt) or 1 (positive prompt).
- multimasking: Whether to predict multiple (3) masks per prompt instead of a single mask.
- embedding_path: Cache path for the image embeddings.
- return_instance_segmentation: Whether to return an instance segmentation or the individual mask data.
- segmentation_ids: Fixed segmentation ids to assign to the masks derived from the prompts.
- reduce_multimasking: Whether to keep only the most likely mask per prompt (the one with the highest predicted IoU) when multimasking is used.
- logits_masks: The logits masks. Array of shape N_PROMPTS x 1 x 256 x 256. Logits from a previous segmentation, passed to the model as additional mask prompts.
- verbose_embeddings: Whether to show progress outputs of computing image embeddings.
Returns:
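The predicted segmentation masks: an instance segmentation if return_instance_segmentation is True, otherwise a list with one record per predicted mask.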