micro_sam.prompt_generators
Classes for generating prompts from ground-truth segmentation masks. For training or evaluation of prompt-based segmentation.
1""" 2Classes for generating prompts from ground-truth segmentation masks. 3For training or evaluation of prompt-based segmentation. 4""" 5 6from typing import List, Optional, Tuple 7 8import numpy as np 9from kornia import morphology 10 11import torch 12 13 14class PromptGeneratorBase: 15 """PromptGeneratorBase is an interface to implement specific prompt generators. 16 """ 17 def __call__( 18 self, 19 segmentation: torch.Tensor, 20 prediction: Optional[torch.Tensor] = None, 21 bbox_coordinates: Optional[List[tuple]] = None, 22 center_coordinates: Optional[List[np.ndarray]] = None 23 ) -> Tuple[ 24 Optional[torch.Tensor], # the point coordinates 25 Optional[torch.Tensor], # the point labels 26 Optional[torch.Tensor], # the bounding boxes 27 Optional[torch.Tensor], # the mask prompts 28 ]: 29 """Return the point prompts given segmentation masks and optional other inputs. 30 31 Args: 32 segmentation: The object masks derived from instance segmentation groundtruth. 33 Expects a float tensor of shape NUM_OBJECTS x 1 x H x W. 34 The first axis corresponds to the binary object masks. 35 prediction: The predicted object masks corresponding to the segmentation. 36 Expects the same shape as the segmentation 37 bbox_coordinates: Precomputed bounding boxes for the segmentation. 38 Expects a list of length NUM_OBJECTS. 39 center_coordinates: Precomputed center coordinates for the segmentation. 40 Expects a list of length NUM_OBJECTS. 41 42 Returns: 43 The point prompt coordinates. Int tensor of shape NUM_OBJECTS x NUM_POINTS x 2. 44 The point coordinates are retuned in XY axis order. This means they are reversed compared 45 to the standard YX axis order used by numpy. 46 The point prompt labels. Int tensor of shape NUM_OBJECTS x NUM_POINTS. 47 The box prompts. Int tensor of shape NUM_OBJECTS x 4. 48 The box coordinates are retunred as MIN_X, MIN_Y, MAX_X, MAX_Y. 49 The mask prompts. Float tensor of shape NUM_OBJECTS x 1 x H' x W'. 50 With H' = W'= 256. 51 """ 52 raise NotImplementedError("PromptGeneratorBase is just a class template. \ 53 Use a child class that implements the specific generator instead") 54 55 56class PointAndBoxPromptGenerator(PromptGeneratorBase): 57 """Generate point and/or box prompts from an instance segmentation. 58 59 You can use this class to derive prompts from an instance segmentation, either for 60 evaluation purposes or for training Segment Anything on custom data. 61 In order to use this generator you need to precompute the bounding boxes and center 62 coordiantes of the instance segmentation, using e.g. `util.get_centers_and_bounding_boxes`. 63 64 Here's an example for how to use this class: 65 ```python 66 # Initialize generator for 1 positive and 4 negative point prompts. 67 prompt_generator = PointAndBoxPromptGenerator(1, 4, dilation_strength=8) 68 69 # Precompute the bounding boxes for the given segmentation 70 bounding_boxes, _ = util.get_centers_and_bounding_boxes(segmentation) 71 72 # generate point prompts for the objects with ids 1, 2 and 3 73 seg_ids = (1, 2, 3) 74 object_mask = np.stack([segmentation == seg_id for seg_id in seg_ids])[:, None] 75 this_bounding_boxes = [bounding_boxes[seg_id] for seg_id in seg_ids] 76 point_coords, point_labels, _, _ = prompt_generator(object_mask, this_bounding_boxes) 77 ``` 78 79 Args: 80 n_positive_points: The number of positive point prompts to generate per mask. 81 n_negative_points: The number of negative point prompts to generate per mask. 
82 dilation_strength: The factor by which the mask is dilated before generating prompts. 83 get_point_prompts: Whether to generate point prompts. 84 get_box_prompts: Whether to generate box prompts. 85 """ 86 def __init__( 87 self, 88 n_positive_points: int, 89 n_negative_points: int, 90 dilation_strength: int, 91 get_point_prompts: bool = True, 92 get_box_prompts: bool = False 93 ) -> None: 94 self.n_positive_points = n_positive_points 95 self.n_negative_points = n_negative_points 96 self.dilation_strength = dilation_strength 97 self.get_box_prompts = get_box_prompts 98 self.get_point_prompts = get_point_prompts 99 100 if self.get_point_prompts is False and self.get_box_prompts is False: 101 raise ValueError("You need to request box prompts, point prompts or both.") 102 103 def _sample_positive_points(self, object_mask, center_coordinates, coord_list, label_list): 104 if center_coordinates is not None: 105 # getting the center coordinate as the first positive point (OPTIONAL) 106 coord_list.append(tuple(map(int, center_coordinates))) # to get int coords instead of float 107 108 # getting the additional positive points by randomly sampling points 109 # from this mask except the center coordinate 110 n_positive_remaining = self.n_positive_points - 1 111 112 else: 113 # need to sample "self.n_positive_points" number of points 114 n_positive_remaining = self.n_positive_points 115 116 if n_positive_remaining > 0: 117 object_coordinates = torch.where(object_mask) 118 n_coordinates = len(object_coordinates[0]) 119 120 # randomly sampling n_positive_remaining_points from these coordinates 121 indices = np.random.choice( 122 n_coordinates, size=n_positive_remaining, 123 # Allow replacing if we can't sample enough coordinates otherwise 124 replace=True if n_positive_remaining > n_coordinates else False, 125 ) 126 coord_list.extend([ 127 [object_coordinates[0][idx], object_coordinates[1][idx]] for idx in indices 128 ]) 129 130 label_list.extend([1] * self.n_positive_points) 131 assert len(coord_list) == len(label_list) == self.n_positive_points 132 return coord_list, label_list 133 134 def _sample_negative_points(self, object_mask, bbox_coordinates, coord_list, label_list): 135 if self.n_negative_points == 0: 136 return coord_list, label_list 137 138 # getting the negative points 139 # for this we do the opposite and we set the mask to the bounding box - the object mask 140 # we need to dilate the object mask before doing this: we use kornia.morphology.dilation for this 141 dilated_object = object_mask[None, None] 142 for _ in range(self.dilation_strength): 143 dilated_object = morphology.dilation(dilated_object, torch.ones(3, 3), engine="convolution") 144 dilated_object = dilated_object.squeeze() 145 146 background_mask = torch.zeros(object_mask.shape, device=object_mask.device) 147 _ds = self.dilation_strength 148 background_mask[max(bbox_coordinates[0] - _ds, 0): min(bbox_coordinates[2] + _ds, object_mask.shape[-2]), 149 max(bbox_coordinates[1] - _ds, 0): min(bbox_coordinates[3] + _ds, object_mask.shape[-1])] = 1 150 background_mask = torch.abs(background_mask - dilated_object) 151 152 # the valid background coordinates 153 background_coordinates = torch.where(background_mask) 154 n_coordinates = len(background_coordinates[0]) 155 156 # randomly sample the negative points from these coordinates 157 indices = np.random.choice( 158 n_coordinates, replace=False, 159 size=min(self.n_negative_points, n_coordinates) # handles the cases with insufficient bg pixels 160 ) 161 coord_list.extend([ 162 
[background_coordinates[0][idx], background_coordinates[1][idx]] for idx in indices 163 ]) 164 label_list.extend([0] * len(indices)) 165 166 return coord_list, label_list 167 168 def _ensure_num_points(self, object_mask, coord_list, label_list): 169 num_points = self.n_positive_points + self.n_negative_points 170 171 # fill up to the necessary number of points if we did not sample enough of them 172 if len(coord_list) != num_points: 173 # to stay consistent, we add random points in the background of an object 174 # if there's no neg region around the object - usually happens with small rois 175 needed_points = num_points - len(coord_list) 176 more_neg_points = torch.where(object_mask == 0) 177 indices = np.random.choice(len(more_neg_points[0]), size=needed_points, replace=False) 178 179 coord_list.extend([ 180 (more_neg_points[0][idx], more_neg_points[1][idx]) for idx in indices 181 ]) 182 label_list.extend([0] * needed_points) 183 184 assert len(coord_list) == len(label_list) == num_points 185 return coord_list, label_list 186 187 # Can we batch this properly? 188 def _sample_points(self, segmentation, bbox_coordinates, center_coordinates): 189 all_coords, all_labels = [], [] 190 191 center_coordinates = [None] * len(segmentation) if center_coordinates is None else center_coordinates 192 for object_mask, bbox_coords, center_coords in zip(segmentation, bbox_coordinates, center_coordinates): 193 coord_list, label_list = [], [] 194 coord_list, label_list = self._sample_positive_points(object_mask[0], center_coords, coord_list, label_list) 195 coord_list, label_list = self._sample_negative_points(object_mask[0], bbox_coords, coord_list, label_list) 196 coord_list, label_list = self._ensure_num_points(object_mask[0], coord_list, label_list) 197 198 all_coords.append(coord_list) 199 all_labels.append(label_list) 200 201 return all_coords, all_labels 202 203 def __call__( 204 self, 205 segmentation: torch.Tensor, 206 bbox_coordinates: List[Tuple], 207 center_coordinates: Optional[List[np.ndarray]] = None, 208 **kwargs, 209 ) -> Tuple[ 210 Optional[torch.Tensor], 211 Optional[torch.Tensor], 212 Optional[torch.Tensor], 213 None 214 ]: 215 """Generate the prompts for one object in the segmentation. 216 217 Args: 218 The groundtruth segmentation. Expects a float tensor of shape NUM_OBJECTS x 1 x H x W. 219 bbox_coordinates: The precomputed bounding boxes of particular object in the segmentation. 220 center_coordinates: The precomputed center coordinates of particular object in the segmentation. 221 If passed, these coordinates will be used as the first positive point prompt. 222 If not passed a random point from within the object mask will be used. 223 224 Returns: 225 Coordinates of point prompts. Returns None, if get_point_prompts is false. 226 Point prompt labels. Returns None, if get_point_prompts is false. 227 Bounding box prompts. Returns None, if get_box_prompts is false. 
228 """ 229 if self.get_point_prompts: 230 coord_list, label_list = self._sample_points(segmentation, bbox_coordinates, center_coordinates) 231 # change the axis convention of the point coordinates to match the expected coordinate order of SAM 232 coord_list = np.array(coord_list)[:, :, ::-1].copy() 233 coord_list = torch.from_numpy(coord_list) 234 label_list = torch.tensor(label_list) 235 else: 236 coord_list, label_list = None, None 237 238 if self.get_box_prompts: 239 # change the axis convention of the box coordinates to match the expected coordinate order of SAM 240 bbox_list = np.array(bbox_coordinates)[:, [1, 0, 3, 2]] 241 bbox_list = torch.from_numpy(bbox_list) 242 else: 243 bbox_list = None 244 245 return coord_list, label_list, bbox_list, None 246 247 248class IterativePromptGenerator(PromptGeneratorBase): 249 """Generate point prompts from an instance segmentation iteratively. 250 """ 251 def _get_positive_points(self, pos_region, overlap_region, is_3d): 252 positive_locations = [torch.where(pos_reg) for pos_reg in pos_region] 253 # we may have objects without a positive region (= missing true foreground) 254 # in this case we just sample a positive point where the model was already correct 255 positive_locations = [ 256 torch.where(ovlp_reg) if len(pos_loc[0]) == 0 else pos_loc 257 for pos_loc, ovlp_reg in zip(positive_locations, overlap_region) 258 ] 259 # we sample one positive location for each object in the batch 260 sampled_indices = [np.random.choice(len(pos_loc[0])) for pos_loc in positive_locations] 261 # get the corresponding coordinates (NOTE: we flip the axis order here due to the expected order of SAM) 262 if is_3d: 263 pos_coordinates = [ 264 [pos_loc[-1][idx], pos_loc[-2][idx], pos_loc[-3][idx]] 265 for pos_loc, idx in zip(positive_locations, sampled_indices) 266 ] 267 else: 268 pos_coordinates = [ 269 [pos_loc[-1][idx], pos_loc[-2][idx]] for pos_loc, idx in zip(positive_locations, sampled_indices) 270 ] 271 272 # make sure that we still have the correct batch size 273 assert len(pos_coordinates) == pos_region.shape[0] 274 pos_labels = [1] * len(pos_coordinates) 275 276 return pos_coordinates, pos_labels 277 278 def _get_negative_locations_in_obj_bbox(self, true_object, custom_df=3): 279 true_loc = torch.where(true_object) 280 bbox = torch.stack( 281 [torch.min(true_loc[1]), torch.min(true_loc[2]), torch.max(true_loc[1]) + 1, torch.max(true_loc[2]) + 1] 282 ) 283 284 # custom dilation factor to perform dilation by expanding the pixels of bbox 285 bbox_mask = torch.zeros_like(true_object).squeeze(0) 286 bbox_mask[ 287 max(bbox[0] - custom_df, 0): min(bbox[2] + custom_df, true_object.shape[-2]), 288 max(bbox[1] - custom_df, 0): min(bbox[3] + custom_df, true_object.shape[-1]) 289 ] = 1 290 bbox_mask = bbox_mask[None].to(true_object.device) 291 background_mask = torch.abs(bbox_mask - true_object) 292 return torch.where(background_mask) 293 294 def _get_negative_points(self, neg_region, true_object, is_3d): 295 # we have a valid negative region (i.e. a valid region where the model could not generate prediction) 296 negative_locations = [torch.where(neg_reg) for neg_reg in neg_region] 297 # we may have objects without a negative region (= no rectifications required) 298 # in this case we sample a negative point in outer periphery of the object inside the bounding box. 
299 negative_locations = [ 300 self._get_negative_locations_in_obj_bbox(true_obj) if len(neg_loc[0]) == 0 else neg_loc 301 for neg_loc, true_obj in zip(negative_locations, true_object) 302 ] 303 # there is a chance that the object is small to not return a decent-sized bounding box 304 # hence we might not find points sometimes there as well. therefore, we sample points from true background. 305 negative_locations = [ 306 torch.where(true_obj == 0) if len(neg_loc[0]) == 0 else neg_loc 307 for neg_loc, true_obj in zip(negative_locations, true_object) 308 ] 309 # we sample one negative location for each object in the batch 310 sampled_indices = [np.random.choice(len(neg_loc[0])) for neg_loc in negative_locations] 311 # get the corresponding coordinates (NOTE: we flip the axis order here due to the expected order of SAM) 312 if is_3d: 313 neg_coordinates = [ 314 [neg_loc[-1][idx], neg_loc[-2][idx], neg_loc[-3][idx]] 315 for neg_loc, idx in zip(negative_locations, sampled_indices) 316 ] 317 else: 318 neg_coordinates = [ 319 [neg_loc[-1][idx], neg_loc[-2][idx]] for neg_loc, idx in zip(negative_locations, sampled_indices) 320 ] 321 322 # make sure that we still have the correct batch size 323 assert len(neg_coordinates) == neg_region.shape[0] 324 neg_labels = [0] * len(neg_coordinates) 325 326 return neg_coordinates, neg_labels 327 328 def __call__( 329 self, 330 segmentation: torch.Tensor, 331 prediction: torch.Tensor, 332 **kwargs, 333 ) -> Tuple[torch.Tensor, torch.Tensor, None, None]: 334 """Generate the prompts for each object iteratively in the segmentation. 335 336 Args: 337 segmentation: The groundtruth segmentation. 338 Expects a float tensor of shape (NUM_OBJECTS x 1 x H x W) or (NUM_OBJECTS x 1 x Z x H x W). 339 prediction: The predicted objects. Epects a float tensor of the same shape as the segmentation. 340 341 Returns: 342 The updated point prompt coordinates. 343 The updated point prompt labels. 344 """ 345 device = prediction.device 346 assert segmentation.shape == prediction.shape, \ 347 "The segmentation and prediction tensors should have the same shape." 348 349 if segmentation.ndim == 5: # masks in 3d must be tensors of shape NUM_OBJECTS x 1 x Z x H x W 350 is_3d = True 351 elif segmentation.ndim == 4: # masks in 2d must be tensors of shape NUM_OBJECTS x 1 x H x W 352 is_3d = False 353 else: 354 raise ValueError("The segmentation and prediction tensors should have either '4' or '5' dimensions.") 355 356 true_object = segmentation.to(device) 357 expected_diff = (prediction - true_object) 358 neg_region = (expected_diff == 1).to(torch.float32) 359 pos_region = (expected_diff == -1) 360 overlap_region = torch.logical_and(prediction == 1, true_object == 1).to(torch.float32) 361 362 pos_coordinates, pos_labels = self._get_positive_points(pos_region, overlap_region, is_3d) 363 neg_coordinates, neg_labels = self._get_negative_points(neg_region, true_object, is_3d) 364 assert len(pos_coordinates) == len(pos_labels) == len(neg_coordinates) == len(neg_labels) 365 366 pos_coordinates = torch.tensor(pos_coordinates)[:, None] 367 neg_coordinates = torch.tensor(neg_coordinates)[:, None] 368 pos_labels, neg_labels = torch.tensor(pos_labels)[:, None], torch.tensor(neg_labels)[:, None] 369 370 net_coords = torch.cat([pos_coordinates, neg_coordinates], dim=1) 371 net_labels = torch.cat([pos_labels, neg_labels], dim=1) 372 373 return net_coords, net_labels, None, None
class PromptGeneratorBase:
PromptGeneratorBase is an interface to implement specific prompt generators.
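Since `PromptGeneratorBase.__call__` only defines the contract, a custom generator subclasses it and returns the four-tuple of point coordinates, point labels, boxes and masks, using `None` for any prompt type it does not produce. Below is a minimal sketch of such a subclass; the class name and the center-of-mass sampling are illustrative, not part of the library:

```python
from typing import List, Optional, Tuple

import numpy as np
import torch

from micro_sam.prompt_generators import PromptGeneratorBase


class CenterPointPromptGenerator(PromptGeneratorBase):
    """Hypothetical generator returning one positive point per object:
    the (rounded) center of mass of each binary object mask."""

    def __call__(
        self,
        segmentation: torch.Tensor,  # float tensor of shape NUM_OBJECTS x 1 x H x W
        prediction: Optional[torch.Tensor] = None,
        bbox_coordinates: Optional[List[tuple]] = None,
        center_coordinates: Optional[List[np.ndarray]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        coords = []
        for mask in segmentation[:, 0]:  # assumes each object mask is non-empty
            y, x = torch.where(mask > 0)
            # note the XY order expected by Segment Anything (reversed from numpy's YX)
            coords.append([int(x.float().mean()), int(y.float().mean())])
        point_coords = torch.tensor(coords)[:, None]                   # NUM_OBJECTS x 1 x 2
        point_labels = torch.ones(len(coords), 1, dtype=torch.int64)   # all positive points
        # this generator produces no box or mask prompts
        return point_coords, point_labels, None, None
```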
class PointAndBoxPromptGenerator(PromptGeneratorBase):
Generate point and/or box prompts from an instance segmentation.
You can use this class to derive prompts from an instance segmentation, either for
evaluation purposes or for training Segment Anything on custom data.
In order to use this generator you need to precompute the bounding boxes and center coordinates of the instance segmentation, using e.g. `util.get_centers_and_bounding_boxes`.
Here's an example of how to use this class:
```python
# Initialize generator for 1 positive and 4 negative point prompts.
prompt_generator = PointAndBoxPromptGenerator(1, 4, dilation_strength=8)

# Precompute the bounding boxes for the given segmentation.
bounding_boxes, _ = util.get_centers_and_bounding_boxes(segmentation)

# Generate point prompts for the objects with ids 1, 2 and 3.
seg_ids = (1, 2, 3)
object_mask = np.stack([segmentation == seg_id for seg_id in seg_ids])[:, None]
this_bounding_boxes = [bounding_boxes[seg_id] for seg_id in seg_ids]
point_coords, point_labels, _, _ = prompt_generator(object_mask, this_bounding_boxes)
```
Arguments:
- n_positive_points: The number of positive point prompts to generate per mask.
- n_negative_points: The number of negative point prompts to generate per mask.
- dilation_strength: The factor by which the mask is dilated before generating prompts.
- get_point_prompts: Whether to generate point prompts.
- get_box_prompts: Whether to generate box prompts.
PointAndBoxPromptGenerator( n_positive_points: int, n_negative_points: int, dilation_strength: int, get_point_prompts: bool = True, get_box_prompts: bool = False)
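The docstring example above requests only point prompts; setting `get_box_prompts=True` additionally returns box prompts in SAM's MIN_X, MIN_Y, MAX_X, MAX_Y order. A self-contained sketch under stated assumptions: the synthetic label image is made up for illustration, and `util.get_centers_and_bounding_boxes` is unpacked exactly as in the docstring example.

```python
import numpy as np
import torch

from micro_sam import util
from micro_sam.prompt_generators import PointAndBoxPromptGenerator

# A tiny synthetic label image standing in for a real instance segmentation.
segmentation = np.zeros((64, 64), dtype="uint32")
segmentation[10:20, 10:20] = 1
segmentation[30:45, 5:25] = 2
segmentation[40:60, 40:60] = 3

bounding_boxes, _ = util.get_centers_and_bounding_boxes(segmentation)

# Request both point and box prompts from the same generator.
prompt_generator = PointAndBoxPromptGenerator(
    n_positive_points=1,
    n_negative_points=0,
    dilation_strength=8,
    get_point_prompts=True,
    get_box_prompts=True,
)

seg_ids = (1, 2, 3)
object_mask = torch.from_numpy(
    np.stack([segmentation == seg_id for seg_id in seg_ids])[:, None]
).to(torch.float32)
this_bounding_boxes = [bounding_boxes[seg_id] for seg_id in seg_ids]

point_coords, point_labels, boxes, _ = prompt_generator(object_mask, this_bounding_boxes)
# point_coords: NUM_OBJECTS x NUM_POINTS x 2, in XY order
# boxes:        NUM_OBJECTS x 4, as MIN_X, MIN_Y, MAX_X, MAX_Y
```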
class IterativePromptGenerator(PromptGeneratorBase):
Generate point prompts from an instance segmentation iteratively.
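The generator takes the ground-truth masks together with the model's current binary predictions and, per object, samples one positive point in the missed foreground and one negative point in the hallucinated foreground. A minimal shape-checking sketch, with random tensors standing in for real masks and predictions (not a training loop from the library):

```python
import torch

from micro_sam.prompt_generators import IterativePromptGenerator

prompt_generator = IterativePromptGenerator()

# Binary ground-truth masks and (thresholded) model predictions,
# both float tensors of shape NUM_OBJECTS x 1 x H x W.
segmentation = (torch.rand(3, 1, 256, 256) > 0.5).to(torch.float32)
prediction = (torch.rand(3, 1, 256, 256) > 0.5).to(torch.float32)

coords, labels, _, _ = prompt_generator(segmentation, prediction)
# coords: NUM_OBJECTS x 2 x 2, one positive and one negative point per object, in XY order
# labels: NUM_OBJECTS x 2, with 1 for the positive point and 0 for the negative point
```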