micro_sam.prompt_generators
Classes for generating prompts from ground-truth segmentation masks. For training or evaluation of prompt-based segmentation.
1""" 2Classes for generating prompts from ground-truth segmentation masks. 3For training or evaluation of prompt-based segmentation. 4""" 5 6from typing import List, Optional, Tuple 7 8import numpy as np 9from kornia import morphology 10 11import torch 12 13 14class PromptGeneratorBase: 15 """PromptGeneratorBase is an interface to implement specific prompt generators. 16 """ 17 def __call__( 18 self, 19 segmentation: torch.Tensor, 20 prediction: Optional[torch.Tensor] = None, 21 bbox_coordinates: Optional[List[tuple]] = None, 22 center_coordinates: Optional[List[np.ndarray]] = None 23 ) -> Tuple[ 24 Optional[torch.Tensor], # the point coordinates 25 Optional[torch.Tensor], # the point labels 26 Optional[torch.Tensor], # the bounding boxes 27 Optional[torch.Tensor], # the mask prompts 28 ]: 29 """Return the point prompts given segmentation masks and optional other inputs. 30 31 Args: 32 segmentation: The object masks derived from instance segmentation groundtruth. 33 Expects a float tensor of shape NUM_OBJECTS x 1 x H x W. 34 The first axis corresponds to the binary object masks. 35 prediction: The predicted object masks corresponding to the segmentation. 36 Expects the same shape as the segmentation 37 bbox_coordinates: Precomputed bounding boxes for the segmentation. 38 Expects a list of length NUM_OBJECTS. 39 center_coordinates: Precomputed center coordinates for the segmentation. 40 Expects a list of length NUM_OBJECTS. 41 42 Returns: 43 The point prompt coordinates. Int tensor of shape NUM_OBJECTS x NUM_POINTS x 2. 44 The point coordinates are retuned in XY axis order. This means they are reversed compared 45 to the standard YX axis order used by numpy. 46 The point prompt labels. Int tensor of shape NUM_OBJECTS x NUM_POINTS. 47 The box prompts. Int tensor of shape NUM_OBJECTS x 4. 48 The box coordinates are retunred as MIN_X, MIN_Y, MAX_X, MAX_Y. 49 The mask prompts. Float tensor of shape NUM_OBJECTS x 1 x H' x W'. 50 With H' = W'= 256. 51 """ 52 raise NotImplementedError( 53 "PromptGeneratorBase is just a class template. " 54 "Use a child class that implements the specific generator instead" 55 ) 56 57 58class PointAndBoxPromptGenerator(PromptGeneratorBase): 59 """Generate point and/or box prompts from an instance segmentation. 60 61 You can use this class to derive prompts from an instance segmentation, either for 62 evaluation purposes or for training Segment Anything on custom data. 63 In order to use this generator you need to precompute the bounding boxes and center 64 coordiantes of the instance segmentation, using e.g. `util.get_centers_and_bounding_boxes`. 65 66 Here's an example for how to use this class: 67 ```python 68 # Initialize generator for 1 positive and 4 negative point prompts. 69 prompt_generator = PointAndBoxPromptGenerator(1, 4, dilation_strength=8) 70 71 # Precompute the bounding boxes for the given segmentation 72 bounding_boxes, _ = util.get_centers_and_bounding_boxes(segmentation) 73 74 # generate point prompts for the objects with ids 1, 2 and 3 75 seg_ids = (1, 2, 3) 76 object_mask = np.stack([segmentation == seg_id for seg_id in seg_ids])[:, None] 77 this_bounding_boxes = [bounding_boxes[seg_id] for seg_id in seg_ids] 78 point_coords, point_labels, _, _ = prompt_generator(object_mask, this_bounding_boxes) 79 ``` 80 81 Args: 82 n_positive_points: The number of positive point prompts to generate per mask. 83 n_negative_points: The number of negative point prompts to generate per mask. 
84 dilation_strength: The factor by which the mask is dilated before generating prompts. 85 get_point_prompts: Whether to generate point prompts. 86 get_box_prompts: Whether to generate box prompts. 87 """ 88 def __init__( 89 self, 90 n_positive_points: int, 91 n_negative_points: int, 92 dilation_strength: int, 93 get_point_prompts: bool = True, 94 get_box_prompts: bool = False 95 ) -> None: 96 self.n_positive_points = n_positive_points 97 self.n_negative_points = n_negative_points 98 self.dilation_strength = dilation_strength 99 self.get_box_prompts = get_box_prompts 100 self.get_point_prompts = get_point_prompts 101 102 if self.get_point_prompts is False and self.get_box_prompts is False: 103 raise ValueError("You need to request box prompts, point prompts or both.") 104 105 def _sample_positive_points(self, object_mask, center_coordinates, coord_list, label_list): 106 if center_coordinates is not None: 107 # getting the center coordinate as the first positive point (OPTIONAL) 108 coord_list.append(tuple(map(int, center_coordinates))) # to get int coords instead of float 109 110 # getting the additional positive points by randomly sampling points 111 # from this mask except the center coordinate 112 n_positive_remaining = self.n_positive_points - 1 113 114 else: 115 # need to sample "self.n_positive_points" number of points 116 n_positive_remaining = self.n_positive_points 117 118 if n_positive_remaining > 0: 119 object_coordinates = torch.where(object_mask) 120 n_coordinates = len(object_coordinates[0]) 121 122 # randomly sampling n_positive_remaining_points from these coordinates 123 indices = np.random.choice( 124 n_coordinates, size=n_positive_remaining, 125 # Allow replacing if we can't sample enough coordinates otherwise 126 replace=True if n_positive_remaining > n_coordinates else False, 127 ) 128 coord_list.extend([ 129 [object_coordinates[0][idx], object_coordinates[1][idx]] for idx in indices 130 ]) 131 132 label_list.extend([1] * self.n_positive_points) 133 assert len(coord_list) == len(label_list) == self.n_positive_points 134 return coord_list, label_list 135 136 def _sample_negative_points(self, object_mask, bbox_coordinates, coord_list, label_list): 137 if self.n_negative_points == 0: 138 return coord_list, label_list 139 140 # getting the negative points 141 # for this we do the opposite and we set the mask to the bounding box - the object mask 142 # we need to dilate the object mask before doing this: we use kornia.morphology.dilation for this 143 dilated_object = object_mask[None, None] 144 for _ in range(self.dilation_strength): 145 dilated_object = morphology.dilation(dilated_object, torch.ones(3, 3), engine="convolution") 146 dilated_object = dilated_object.squeeze() 147 148 background_mask = torch.zeros(object_mask.shape, device=object_mask.device) 149 _ds = self.dilation_strength 150 background_mask[ 151 max(bbox_coordinates[0] - _ds, 0): min(bbox_coordinates[2] + _ds, object_mask.shape[-2]), 152 max(bbox_coordinates[1] - _ds, 0): min(bbox_coordinates[3] + _ds, object_mask.shape[-1]) 153 ] = 1 154 background_mask = torch.abs(background_mask - dilated_object) 155 156 # the valid background coordinates 157 background_coordinates = torch.where(background_mask) 158 n_coordinates = len(background_coordinates[0]) 159 160 # randomly sample the negative points from these coordinates 161 indices = np.random.choice( 162 n_coordinates, replace=False, 163 size=min(self.n_negative_points, n_coordinates) # handles the cases with insufficient bg pixels 164 ) 165 
coord_list.extend([ 166 [background_coordinates[0][idx], background_coordinates[1][idx]] for idx in indices 167 ]) 168 label_list.extend([0] * len(indices)) 169 170 return coord_list, label_list 171 172 def _ensure_num_points(self, object_mask, coord_list, label_list): 173 num_points = self.n_positive_points + self.n_negative_points 174 175 # fill up to the necessary number of points if we did not sample enough of them 176 if len(coord_list) != num_points: 177 # to stay consistent, we add random points in the background of an object 178 # if there's no neg region around the object - usually happens with small rois 179 needed_points = num_points - len(coord_list) 180 more_neg_points = torch.where(object_mask == 0) 181 indices = np.random.choice(len(more_neg_points[0]), size=needed_points, replace=False) 182 183 coord_list.extend([ 184 (more_neg_points[0][idx], more_neg_points[1][idx]) for idx in indices 185 ]) 186 label_list.extend([0] * needed_points) 187 188 assert len(coord_list) == len(label_list) == num_points 189 return coord_list, label_list 190 191 # Can we batch this properly? 192 def _sample_points(self, segmentation, bbox_coordinates, center_coordinates): 193 all_coords, all_labels = [], [] 194 195 center_coordinates = [None] * len(segmentation) if center_coordinates is None else center_coordinates 196 for object_mask, bbox_coords, center_coords in zip(segmentation, bbox_coordinates, center_coordinates): 197 coord_list, label_list = [], [] 198 coord_list, label_list = self._sample_positive_points(object_mask[0], center_coords, coord_list, label_list) 199 coord_list, label_list = self._sample_negative_points(object_mask[0], bbox_coords, coord_list, label_list) 200 coord_list, label_list = self._ensure_num_points(object_mask[0], coord_list, label_list) 201 202 all_coords.append(coord_list) 203 all_labels.append(label_list) 204 205 return all_coords, all_labels 206 207 def __call__( 208 self, 209 segmentation: torch.Tensor, 210 bbox_coordinates: List[Tuple], 211 center_coordinates: Optional[List[np.ndarray]] = None, 212 **kwargs, 213 ) -> Tuple[ 214 Optional[torch.Tensor], 215 Optional[torch.Tensor], 216 Optional[torch.Tensor], 217 None 218 ]: 219 """Generate the prompts for one object in the segmentation. 220 221 Args: 222 The groundtruth segmentation. Expects a float tensor of shape NUM_OBJECTS x 1 x H x W. 223 bbox_coordinates: The precomputed bounding boxes of particular object in the segmentation. 224 center_coordinates: The precomputed center coordinates of particular object in the segmentation. 225 If passed, these coordinates will be used as the first positive point prompt. 226 If not passed a random point from within the object mask will be used. 227 228 Returns: 229 Coordinates of point prompts. Returns None, if get_point_prompts is false. 230 Point prompt labels. Returns None, if get_point_prompts is false. 231 Bounding box prompts. Returns None, if get_box_prompts is false. 
232 """ 233 if self.get_point_prompts: 234 coord_list, label_list = self._sample_points(segmentation, bbox_coordinates, center_coordinates) 235 # change the axis convention of the point coordinates to match the expected coordinate order of SAM 236 coord_list = np.array(coord_list)[:, :, ::-1].copy() 237 coord_list = torch.from_numpy(coord_list) 238 label_list = torch.tensor(label_list) 239 else: 240 coord_list, label_list = None, None 241 242 if self.get_box_prompts: 243 # change the axis convention of the box coordinates to match the expected coordinate order of SAM 244 bbox_list = np.array(bbox_coordinates)[:, [1, 0, 3, 2]] 245 bbox_list = torch.from_numpy(bbox_list) 246 else: 247 bbox_list = None 248 249 return coord_list, label_list, bbox_list, None 250 251 252class IterativePromptGenerator(PromptGeneratorBase): 253 """Generate point prompts from an instance segmentation iteratively. 254 """ 255 def _get_positive_points(self, pos_region, overlap_region, is_3d): 256 positive_locations = [torch.where(pos_reg) for pos_reg in pos_region] 257 # we may have objects without a positive region (= missing true foreground) 258 # in this case we just sample a positive point where the model was already correct 259 positive_locations = [ 260 torch.where(ovlp_reg) if len(pos_loc[0]) == 0 else pos_loc 261 for pos_loc, ovlp_reg in zip(positive_locations, overlap_region) 262 ] 263 # we sample one positive location for each object in the batch 264 sampled_indices = [np.random.choice(len(pos_loc[0])) for pos_loc in positive_locations] 265 # get the corresponding coordinates (NOTE: we flip the axis order here due to the expected order of SAM) 266 if is_3d: 267 pos_coordinates = [ 268 [pos_loc[-1][idx], pos_loc[-2][idx], pos_loc[-3][idx]] 269 for pos_loc, idx in zip(positive_locations, sampled_indices) 270 ] 271 else: 272 pos_coordinates = [ 273 [pos_loc[-1][idx], pos_loc[-2][idx]] for pos_loc, idx in zip(positive_locations, sampled_indices) 274 ] 275 276 # make sure that we still have the correct batch size 277 assert len(pos_coordinates) == pos_region.shape[0] 278 pos_labels = [1] * len(pos_coordinates) 279 280 return pos_coordinates, pos_labels 281 282 def _get_negative_locations_in_obj_bbox(self, true_object, custom_df=3): 283 true_loc = torch.where(true_object) 284 bbox = torch.stack( 285 [torch.min(true_loc[1]), torch.min(true_loc[2]), torch.max(true_loc[1]) + 1, torch.max(true_loc[2]) + 1] 286 ) 287 288 # custom dilation factor to perform dilation by expanding the pixels of bbox 289 bbox_mask = torch.zeros_like(true_object).squeeze(0) 290 bbox_mask[ 291 max(bbox[0] - custom_df, 0): min(bbox[2] + custom_df, true_object.shape[-2]), 292 max(bbox[1] - custom_df, 0): min(bbox[3] + custom_df, true_object.shape[-1]) 293 ] = 1 294 bbox_mask = bbox_mask[None].to(true_object.device) 295 background_mask = torch.abs(bbox_mask - true_object) 296 return torch.where(background_mask) 297 298 def _get_negative_points(self, neg_region, true_object, is_3d): 299 # we have a valid negative region (i.e. a valid region where the model could not generate prediction) 300 negative_locations = [torch.where(neg_reg) for neg_reg in neg_region] 301 # we may have objects without a negative region (= no rectifications required) 302 # in this case we sample a negative point in outer periphery of the object inside the bounding box. 
303 negative_locations = [ 304 self._get_negative_locations_in_obj_bbox(true_obj) if len(neg_loc[0]) == 0 else neg_loc 305 for neg_loc, true_obj in zip(negative_locations, true_object) 306 ] 307 # there is a chance that the object is small to not return a decent-sized bounding box 308 # hence we might not find points sometimes there as well. therefore, we sample points from true background. 309 negative_locations = [ 310 torch.where(true_obj == 0) if len(neg_loc[0]) == 0 else neg_loc 311 for neg_loc, true_obj in zip(negative_locations, true_object) 312 ] 313 # we sample one negative location for each object in the batch 314 sampled_indices = [np.random.choice(len(neg_loc[0])) for neg_loc in negative_locations] 315 # get the corresponding coordinates (NOTE: we flip the axis order here due to the expected order of SAM) 316 if is_3d: 317 neg_coordinates = [ 318 [neg_loc[-1][idx], neg_loc[-2][idx], neg_loc[-3][idx]] 319 for neg_loc, idx in zip(negative_locations, sampled_indices) 320 ] 321 else: 322 neg_coordinates = [ 323 [neg_loc[-1][idx], neg_loc[-2][idx]] for neg_loc, idx in zip(negative_locations, sampled_indices) 324 ] 325 326 # make sure that we still have the correct batch size 327 assert len(neg_coordinates) == neg_region.shape[0] 328 neg_labels = [0] * len(neg_coordinates) 329 330 return neg_coordinates, neg_labels 331 332 def __call__( 333 self, 334 segmentation: torch.Tensor, 335 prediction: torch.Tensor, 336 **kwargs, 337 ) -> Tuple[torch.Tensor, torch.Tensor, None, None]: 338 """Generate the prompts for each object iteratively in the segmentation. 339 340 Args: 341 segmentation: The groundtruth segmentation. 342 Expects a float tensor of shape (NUM_OBJECTS x 1 x H x W) or (NUM_OBJECTS x 1 x Z x H x W). 343 prediction: The predicted objects. Epects a float tensor of the same shape as the segmentation. 344 345 Returns: 346 The updated point prompt coordinates. 347 The updated point prompt labels. 348 """ 349 device = prediction.device 350 assert segmentation.shape == prediction.shape, \ 351 "The segmentation and prediction tensors should have the same shape." 352 353 if segmentation.ndim == 5: # masks in 3d must be tensors of shape NUM_OBJECTS x 1 x Z x H x W 354 is_3d = True 355 elif segmentation.ndim == 4: # masks in 2d must be tensors of shape NUM_OBJECTS x 1 x H x W 356 is_3d = False 357 else: 358 raise ValueError("The segmentation and prediction tensors should have either '4' or '5' dimensions.") 359 360 true_object = segmentation.to(device) 361 expected_diff = (prediction - true_object) 362 neg_region = (expected_diff == 1).to(torch.float32) 363 pos_region = (expected_diff == -1) 364 overlap_region = torch.logical_and(prediction == 1, true_object == 1).to(torch.float32) 365 366 pos_coordinates, pos_labels = self._get_positive_points(pos_region, overlap_region, is_3d) 367 neg_coordinates, neg_labels = self._get_negative_points(neg_region, true_object, is_3d) 368 assert len(pos_coordinates) == len(pos_labels) == len(neg_coordinates) == len(neg_labels) 369 370 pos_coordinates = torch.tensor(pos_coordinates)[:, None] 371 neg_coordinates = torch.tensor(neg_coordinates)[:, None] 372 pos_labels, neg_labels = torch.tensor(pos_labels)[:, None], torch.tensor(neg_labels)[:, None] 373 374 net_coords = torch.cat([pos_coordinates, neg_coordinates], dim=1) 375 net_labels = torch.cat([pos_labels, neg_labels], dim=1) 376 377 return net_coords, net_labels, None, None
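All generators return point coordinates in XY axis order (SAM's convention), which is the reverse of numpy's YX indexing. A quick self-contained illustration of the convention, with made-up values:

```python
import numpy as np

# numpy indexes a 2d mask as (y, x); the prompt convention is [x, y]
mask = np.zeros((4, 6), dtype=bool)
mask[1, 4] = True                   # numpy indexing: row y=1, column x=4
y, x = np.where(mask)
point_xy = [int(x[0]), int(y[0])]   # prompt convention: [4, 1]
```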
class PromptGeneratorBase:
PromptGeneratorBase is an interface to implement specific prompt generators.
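Implementations override `__call__` and return the four-tuple described in the interface docstring. Below is a minimal sketch of a custom subclass that samples one random positive point per object; the class name `RandomPointPromptGenerator` is hypothetical and not part of micro_sam:

```python
import numpy as np
import torch

from micro_sam.prompt_generators import PromptGeneratorBase


class RandomPointPromptGenerator(PromptGeneratorBase):
    """Hypothetical example subclass: one random positive point per object."""
    def __call__(self, segmentation, prediction=None, bbox_coordinates=None, center_coordinates=None):
        coords, labels = [], []
        for object_mask in segmentation:  # object_mask has shape 1 x H x W
            y, x = torch.where(object_mask[0] > 0)
            idx = np.random.choice(len(y))
            coords.append([int(x[idx]), int(y[idx])])  # XY order, as the interface requires
            labels.append(1)  # 1 = positive point
        point_coords = torch.tensor(coords)[:, None]  # NUM_OBJECTS x 1 x 2
        point_labels = torch.tensor(labels)[:, None]  # NUM_OBJECTS x 1
        return point_coords, point_labels, None, None  # no box or mask prompts
```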
class PointAndBoxPromptGenerator(PromptGeneratorBase):
Generate point and/or box prompts from an instance segmentation.
You can use this class to derive prompts from an instance segmentation, either for
evaluation purposes or for training Segment Anything on custom data.
In order to use this generator you need to precompute the bounding boxes and center
coordinates of the instance segmentation, using e.g. `util.get_centers_and_bounding_boxes`.
Here's an example of how to use this class:

```python
# Initialize generator for 1 positive and 4 negative point prompts.
prompt_generator = PointAndBoxPromptGenerator(1, 4, dilation_strength=8)

# Precompute the bounding boxes for the given segmentation.
# get_centers_and_bounding_boxes returns the centers first, then the boxes.
_, bounding_boxes = util.get_centers_and_bounding_boxes(segmentation)

# Generate point prompts for the objects with ids 1, 2 and 3.
seg_ids = (1, 2, 3)
object_mask = torch.from_numpy(
    np.stack([segmentation == seg_id for seg_id in seg_ids])[:, None]
).to(torch.float32)
this_bounding_boxes = [bounding_boxes[seg_id] for seg_id in seg_ids]
point_coords, point_labels, _, _ = prompt_generator(object_mask, this_bounding_boxes)
```
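Continuing the example, here is a hedged sketch of requesting box prompts instead of point prompts, reusing `object_mask` and `this_bounding_boxes` from above:

```python
# Request only box prompts; at least one prompt type must be enabled.
box_generator = PointAndBoxPromptGenerator(
    n_positive_points=0, n_negative_points=0, dilation_strength=8,
    get_point_prompts=False, get_box_prompts=True,
)
_, _, boxes, _ = box_generator(object_mask, this_bounding_boxes)
# boxes is an int tensor of shape NUM_OBJECTS x 4 in MIN_X, MIN_Y, MAX_X, MAX_Y order.
```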
Arguments:
- n_positive_points: The number of positive point prompts to generate per mask.
- n_negative_points: The number of negative point prompts to generate per mask.
- dilation_strength: The factor by which the mask is dilated before generating prompts.
- get_point_prompts: Whether to generate point prompts.
- get_box_prompts: Whether to generate box prompts.
PointAndBoxPromptGenerator(n_positive_points: int, n_negative_points: int, dilation_strength: int, get_point_prompts: bool = True, get_box_prompts: bool = False)
class IterativePromptGenerator(PromptGeneratorBase):
Generate point prompts from an instance segmentation iteratively.
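A minimal sketch of one round of iterative prompting, with random dummy tensors standing in for real data (shapes follow the docstring): the generator compares the prediction against the ground truth and derives one new positive and one new negative point per object.

```python
import torch

from micro_sam.prompt_generators import IterativePromptGenerator

prompt_generator = IterativePromptGenerator()

# Dummy ground truth and model prediction for 3 objects on a 512 x 512 image.
num_objects, H, W = 3, 512, 512
segmentation = (torch.rand(num_objects, 1, H, W) > 0.5).to(torch.float32)
prediction = (torch.rand(num_objects, 1, H, W) > 0.5).to(torch.float32)

# One new positive and one new negative point per object, in XY order.
coords, labels, _, _ = prompt_generator(segmentation, prediction)
print(coords.shape, labels.shape)  # torch.Size([3, 2, 2]) torch.Size([3, 2])
```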