micro_sam.inference
import os
from typing import Optional, Union

import numpy as np

import torch

from segment_anything import SamPredictor
import segment_anything.utils.amg as amg_utils
from segment_anything.utils.transforms import ResizeLongestSide

from . import util
from .instance_segmentation import mask_data_to_segmentation
from ._vendored import batched_mask_to_box


@torch.no_grad()
def batched_inference(
    predictor: SamPredictor,
    image: np.ndarray,
    batch_size: int,
    boxes: Optional[np.ndarray] = None,
    points: Optional[np.ndarray] = None,
    point_labels: Optional[np.ndarray] = None,
    multimasking: bool = False,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    return_instance_segmentation: bool = True,
    segmentation_ids: Optional[list] = None,
    reduce_multimasking: bool = True,
    logits_masks: Optional[torch.Tensor] = None,
    verbose_embeddings: bool = True,
):
    """Run batched inference for input prompts.

    Args:
        predictor: The segment anything predictor.
        image: The input image.
        batch_size: The number of prompts processed per predictor call. Must be at least 1.
        boxes: The box prompts. Array of shape N_PROMPTS x 4.
            The bounding boxes are represented by [MIN_X, MIN_Y, MAX_X, MAX_Y].
        points: The point prompt coordinates. Array of shape N_PROMPTS x 1 x 2.
            The points are represented by their coordinates [X, Y], which are given in the last dimension.
        point_labels: The point prompt labels. Array of shape N_PROMPTS x 1.
            The labels are either 0 (negative prompt) or 1 (positive prompt).
        multimasking: Whether to predict with 3 or 1 mask.
        embedding_path: Cache path for the image embeddings.
        return_instance_segmentation: Whether to return an instance segmentation
            or the individual mask data.
        segmentation_ids: Fixed segmentation ids to assign to the masks
            derived from the prompts.
        reduce_multimasking: Whether to keep, for each prompt, only the mask with the
            highest predicted IoU when multimasking is enabled.
        logits_masks: Logits masks from a previous segmentation, passed to the predictor
            as mask input. Array of shape N_PROMPTS x 1 x 256 x 256.
        verbose_embeddings: Whether to show progress outputs of computing image embeddings.

    Returns:
        The predicted segmentation masks: the result of `mask_data_to_segmentation` if
        `return_instance_segmentation` is True, otherwise a list with one dictionary per mask
        (keys "segmentation", "area", "bbox", "predicted_iou", "seg_id" and "logits").

    Raises:
        NotImplementedError: If `multimasking` is combined with `segmentation_ids` and
            `return_instance_segmentation=False`.
        ValueError: If `batch_size` is smaller than 1, if no prompts are passed, if only one of
            `points` and `point_labels` is passed, or if the numbers of prompts, logits masks
            or segmentation ids do not match.
    """
    if multimasking and (segmentation_ids is not None) and (not return_instance_segmentation):
        raise NotImplementedError

    # A non-positive batch size would otherwise fail later with an unrelated error.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    if (points is None) != (point_labels is None):
        raise ValueError(
            "If you have point prompts both `points` and `point_labels` have to be passed, "
            "but you passed only one of them."
        )

    have_points = points is not None
    have_boxes = boxes is not None
    have_logits = logits_masks is not None
    if (not have_points) and (not have_boxes):
        raise ValueError("Point and/or box prompts have to be passed, you passed neither.")

    if have_points and (len(point_labels) != len(points)):
        raise ValueError(
            "The number of point coordinates and labels does not match: "
            f"{len(point_labels)} != {len(points)}"
        )

    if (have_points and have_boxes) and (len(points) != len(boxes)):
        raise ValueError(
            "The number of point and box prompts does not match: "
            f"{len(points)} != {len(boxes)}"
        )

    if have_logits:
        if have_points and (len(logits_masks) != len(point_labels)):
            raise ValueError(
                "The number of point and logits does not match: "
                f"{len(points)} != {len(logits_masks)}"
            )
        elif have_boxes and (len(logits_masks) != len(boxes)):
            raise ValueError(
                "The number of boxes and logits does not match: "
                f"{len(boxes)} != {len(logits_masks)}"
            )

    n_prompts = boxes.shape[0] if have_boxes else points.shape[0]

    if (segmentation_ids is not None) and (len(segmentation_ids) != n_prompts):
        raise ValueError(
            "The number of segmentation ids and prompts does not match: "
            f"{len(segmentation_ids)} != {n_prompts}"
        )

    # Compute the image embeddings (cached at `embedding_path` if given) and set them in the predictor.
    image_embeddings = util.precompute_image_embeddings(
        predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
    )
    util.set_precomputed(predictor, image_embeddings)

    # Preprocess the prompts: rescale them with ResizeLongestSide and move them to the predictor's device.
    device = predictor.device
    transform_function = ResizeLongestSide(1024)
    image_shape = predictor.original_size
    if have_boxes:
        boxes = transform_function.apply_boxes(boxes, image_shape)
        boxes = torch.tensor(boxes, dtype=torch.float32).to(device)
    if have_points:
        points = transform_function.apply_coords(points, image_shape)
        points = torch.tensor(points, dtype=torch.float32).to(device)
        point_labels = torch.tensor(point_labels, dtype=torch.float32).to(device)

    mask_data = amg_utils.MaskData()
    for batch_start in range(0, n_prompts, batch_size):
        batch_stop = min(batch_start + batch_size, n_prompts)

        batch_boxes = boxes[batch_start:batch_stop] if have_boxes else None
        batch_points = points[batch_start:batch_stop] if have_points else None
        batch_labels = point_labels[batch_start:batch_stop] if have_points else None
        batch_mask_input = logits_masks[batch_start:batch_stop] if have_logits else None

        batch_masks, batch_ious, batch_logits = predictor.predict_torch(
            point_coords=batch_points,
            point_labels=batch_labels,
            boxes=batch_boxes,
            mask_input=batch_mask_input,
            multimask_output=multimasking
        )

        # If we expect to reduce the masks from multimasking and use multi-masking,
        # then we need to select the most likely mask (according to the predicted IOU) here.
        if reduce_multimasking and multimasking:
            _, max_index = batch_ious.max(dim=1)
            rows = torch.arange(max_index.shape[0], device=max_index.device)
            # Gather the best mask per prompt and restore the singleton mask axis.
            batch_masks = batch_masks[rows, max_index].unsqueeze(1)
            batch_ious = batch_ious[rows, max_index].unsqueeze(1)
            batch_logits = batch_logits[rows, max_index].unsqueeze(1)

        batch_data = amg_utils.MaskData(masks=batch_masks.flatten(0, 1), iou_preds=batch_ious.flatten(0, 1))
        batch_data["masks"] = (batch_data["masks"] > predictor.model.mask_threshold).type(torch.bool)
        batch_data["boxes"] = batched_mask_to_box(batch_data["masks"])
        # NOTE(review): the logits are not flattened like the masks; with multimasking and
        # reduce_multimasking=False the per-record indices may not line up - TODO confirm.
        batch_data["logits"] = batch_logits

        mask_data.cat(batch_data)

    # Mask data to records.
    records = [
        {
            "segmentation": mask_data["masks"][idx],
            "area": mask_data["masks"][idx].sum(),
            "bbox": amg_utils.box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
            "predicted_iou": mask_data["iou_preds"][idx].item(),
            "seg_id": idx + 1 if segmentation_ids is None else int(segmentation_ids[idx]),
            "logits": mask_data["logits"][idx]
        }
        for idx in range(len(mask_data["masks"]))
    ]

    if return_instance_segmentation:
        return mask_data_to_segmentation(records, with_background=False, min_object_size=0)

    return records
@torch.no_grad()
def
batched_inference( predictor: segment_anything.predictor.SamPredictor, image: numpy.ndarray, batch_size: int, boxes: Optional[numpy.ndarray] = None, points: Optional[numpy.ndarray] = None, point_labels: Optional[numpy.ndarray] = None, multimasking: bool = False, embedding_path: Union[str, os.PathLike, NoneType] = None, return_instance_segmentation: bool = True, segmentation_ids: Optional[list] = None, reduce_multimasking: bool = True, logits_masks: Optional[torch.Tensor] = None, verbose_embeddings: bool = True):
@torch.no_grad()
def batched_inference(
    predictor: SamPredictor,
    image: np.ndarray,
    batch_size: int,
    boxes: Optional[np.ndarray] = None,
    points: Optional[np.ndarray] = None,
    point_labels: Optional[np.ndarray] = None,
    multimasking: bool = False,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    return_instance_segmentation: bool = True,
    segmentation_ids: Optional[list] = None,
    reduce_multimasking: bool = True,
    logits_masks: Optional[torch.Tensor] = None,
    verbose_embeddings: bool = True,
):
    """Run batched inference for input prompts.

    Args:
        predictor: The segment anything predictor.
        image: The input image.
        batch_size: The number of prompts processed per predictor call. Must be at least 1.
        boxes: The box prompts. Array of shape N_PROMPTS x 4.
            The bounding boxes are represented by [MIN_X, MIN_Y, MAX_X, MAX_Y].
        points: The point prompt coordinates. Array of shape N_PROMPTS x 1 x 2.
            The points are represented by their coordinates [X, Y], which are given in the last dimension.
        point_labels: The point prompt labels. Array of shape N_PROMPTS x 1.
            The labels are either 0 (negative prompt) or 1 (positive prompt).
        multimasking: Whether to predict with 3 or 1 mask.
        embedding_path: Cache path for the image embeddings.
        return_instance_segmentation: Whether to return an instance segmentation
            or the individual mask data.
        segmentation_ids: Fixed segmentation ids to assign to the masks
            derived from the prompts.
        reduce_multimasking: Whether to keep, for each prompt, only the mask with the
            highest predicted IoU when multimasking is enabled.
        logits_masks: Logits masks from a previous segmentation, passed to the predictor
            as mask input. Array of shape N_PROMPTS x 1 x 256 x 256.
        verbose_embeddings: Whether to show progress outputs of computing image embeddings.

    Returns:
        The predicted segmentation masks: the result of `mask_data_to_segmentation` if
        `return_instance_segmentation` is True, otherwise a list with one dictionary per mask
        (keys "segmentation", "area", "bbox", "predicted_iou", "seg_id" and "logits").

    Raises:
        NotImplementedError: If `multimasking` is combined with `segmentation_ids` and
            `return_instance_segmentation=False`.
        ValueError: If `batch_size` is smaller than 1, if no prompts are passed, if only one of
            `points` and `point_labels` is passed, or if the numbers of prompts, logits masks
            or segmentation ids do not match.
    """
    if multimasking and (segmentation_ids is not None) and (not return_instance_segmentation):
        raise NotImplementedError

    # A non-positive batch size would otherwise fail later with an unrelated error.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    if (points is None) != (point_labels is None):
        raise ValueError(
            "If you have point prompts both `points` and `point_labels` have to be passed, "
            "but you passed only one of them."
        )

    have_points = points is not None
    have_boxes = boxes is not None
    have_logits = logits_masks is not None
    if (not have_points) and (not have_boxes):
        raise ValueError("Point and/or box prompts have to be passed, you passed neither.")

    if have_points and (len(point_labels) != len(points)):
        raise ValueError(
            "The number of point coordinates and labels does not match: "
            f"{len(point_labels)} != {len(points)}"
        )

    if (have_points and have_boxes) and (len(points) != len(boxes)):
        raise ValueError(
            "The number of point and box prompts does not match: "
            f"{len(points)} != {len(boxes)}"
        )

    if have_logits:
        if have_points and (len(logits_masks) != len(point_labels)):
            raise ValueError(
                "The number of point and logits does not match: "
                f"{len(points)} != {len(logits_masks)}"
            )
        elif have_boxes and (len(logits_masks) != len(boxes)):
            raise ValueError(
                "The number of boxes and logits does not match: "
                f"{len(boxes)} != {len(logits_masks)}"
            )

    n_prompts = boxes.shape[0] if have_boxes else points.shape[0]

    if (segmentation_ids is not None) and (len(segmentation_ids) != n_prompts):
        raise ValueError(
            "The number of segmentation ids and prompts does not match: "
            f"{len(segmentation_ids)} != {n_prompts}"
        )

    # Compute the image embeddings (cached at `embedding_path` if given) and set them in the predictor.
    image_embeddings = util.precompute_image_embeddings(
        predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
    )
    util.set_precomputed(predictor, image_embeddings)

    # Preprocess the prompts: rescale them with ResizeLongestSide and move them to the predictor's device.
    device = predictor.device
    transform_function = ResizeLongestSide(1024)
    image_shape = predictor.original_size
    if have_boxes:
        boxes = transform_function.apply_boxes(boxes, image_shape)
        boxes = torch.tensor(boxes, dtype=torch.float32).to(device)
    if have_points:
        points = transform_function.apply_coords(points, image_shape)
        points = torch.tensor(points, dtype=torch.float32).to(device)
        point_labels = torch.tensor(point_labels, dtype=torch.float32).to(device)

    mask_data = amg_utils.MaskData()
    for batch_start in range(0, n_prompts, batch_size):
        batch_stop = min(batch_start + batch_size, n_prompts)

        batch_boxes = boxes[batch_start:batch_stop] if have_boxes else None
        batch_points = points[batch_start:batch_stop] if have_points else None
        batch_labels = point_labels[batch_start:batch_stop] if have_points else None
        batch_mask_input = logits_masks[batch_start:batch_stop] if have_logits else None

        batch_masks, batch_ious, batch_logits = predictor.predict_torch(
            point_coords=batch_points,
            point_labels=batch_labels,
            boxes=batch_boxes,
            mask_input=batch_mask_input,
            multimask_output=multimasking
        )

        # If we expect to reduce the masks from multimasking and use multi-masking,
        # then we need to select the most likely mask (according to the predicted IOU) here.
        if reduce_multimasking and multimasking:
            _, max_index = batch_ious.max(dim=1)
            rows = torch.arange(max_index.shape[0], device=max_index.device)
            # Gather the best mask per prompt and restore the singleton mask axis.
            batch_masks = batch_masks[rows, max_index].unsqueeze(1)
            batch_ious = batch_ious[rows, max_index].unsqueeze(1)
            batch_logits = batch_logits[rows, max_index].unsqueeze(1)

        batch_data = amg_utils.MaskData(masks=batch_masks.flatten(0, 1), iou_preds=batch_ious.flatten(0, 1))
        batch_data["masks"] = (batch_data["masks"] > predictor.model.mask_threshold).type(torch.bool)
        batch_data["boxes"] = batched_mask_to_box(batch_data["masks"])
        # NOTE(review): the logits are not flattened like the masks; with multimasking and
        # reduce_multimasking=False the per-record indices may not line up - TODO confirm.
        batch_data["logits"] = batch_logits

        mask_data.cat(batch_data)

    # Mask data to records.
    records = [
        {
            "segmentation": mask_data["masks"][idx],
            "area": mask_data["masks"][idx].sum(),
            "bbox": amg_utils.box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
            "predicted_iou": mask_data["iou_preds"][idx].item(),
            "seg_id": idx + 1 if segmentation_ids is None else int(segmentation_ids[idx]),
            "logits": mask_data["logits"][idx]
        }
        for idx in range(len(mask_data["masks"]))
    ]

    if return_instance_segmentation:
        return mask_data_to_segmentation(records, with_background=False, min_object_size=0)

    return records
Run batched inference for input prompts.
Arguments:
- predictor: The segment anything predictor.
- image: The input image.
- batch_size: The batch size to use for inference.
- boxes: The box prompts. Array of shape N_PROMPTS x 4. The bounding boxes are represented by [MIN_X, MIN_Y, MAX_X, MAX_Y].
- points: The point prompt coordinates. Array of shape N_PROMPTS x 1 x 2. The points are represented by their coordinates [X, Y], which are given in the last dimension.
- point_labels: The point prompt labels. Array of shape N_PROMPTS x 1. The labels are either 0 (negative prompt) or 1 (positive prompt).
- multimasking: Whether to predict three masks per prompt (multimasking enabled) or a single mask.
- embedding_path: Cache path for the image embeddings.
- return_instance_segmentation: Whether to return an instance segmentation or the individual mask data.
- segmentation_ids: Fixed segmentation ids to assign to the masks derived from the prompts.
- reduce_multimasking: Whether to keep, for each prompt, only the mask with the highest predicted IoU when multimasking is enabled.
- logits_masks: Logits masks from a previous segmentation, passed to the predictor as mask input. Array of shape N_PROMPTS x 1 x 256 x 256.
- verbose_embeddings: Whether to show progress outputs of computing image embeddings.
Returns:
The predicted segmentation masks.