micro_sam.prompt_generators

Classes for generating prompts from ground-truth segmentation masks. For training or evaluation of prompt-based segmentation.

  1"""
  2Classes for generating prompts from ground-truth segmentation masks.
  3For training or evaluation of prompt-based segmentation.
  4"""
  5
  6from typing import List, Optional, Tuple
  7
  8import numpy as np
  9from kornia import morphology
 10
 11import torch
 12
 13
 14class PromptGeneratorBase:
 15    """PromptGeneratorBase is an interface to implement specific prompt generators.
 16    """
 17    def __call__(
 18        self,
 19        segmentation: torch.Tensor,
 20        prediction: Optional[torch.Tensor] = None,
 21        bbox_coordinates: Optional[List[tuple]] = None,
 22        center_coordinates: Optional[List[np.ndarray]] = None
 23    ) -> Tuple[
 24        Optional[torch.Tensor],  # the point coordinates
 25        Optional[torch.Tensor],  # the point labels
 26        Optional[torch.Tensor],  # the bounding boxes
 27        Optional[torch.Tensor],  # the mask prompts
 28    ]:
 29        """Return the point prompts given segmentation masks and optional other inputs.
 30
 31        Args:
 32            segmentation: The object masks derived from instance segmentation groundtruth.
 33                Expects a float tensor of shape NUM_OBJECTS x 1 x H x W.
 34                The first axis corresponds to the binary object masks.
 35            prediction: The predicted object masks corresponding to the segmentation.
 36                Expects the same shape as the segmentation
 37            bbox_coordinates: Precomputed bounding boxes for the segmentation.
 38                Expects a list of length NUM_OBJECTS.
 39            center_coordinates: Precomputed center coordinates for the segmentation.
 40                Expects a list of length NUM_OBJECTS.
 41
 42        Returns:
 43            The point prompt coordinates. Int tensor of shape NUM_OBJECTS x NUM_POINTS x 2.
 44                The point coordinates are retuned in XY axis order. This means they are reversed compared
 45                to the standard YX axis order used by numpy.
 46            The point prompt labels. Int tensor of shape NUM_OBJECTS x NUM_POINTS.
 47            The box prompts. Int tensor of shape NUM_OBJECTS x 4.
 48                The box coordinates are retunred as MIN_X, MIN_Y, MAX_X, MAX_Y.
 49            The mask prompts. Float tensor of shape NUM_OBJECTS x 1 x H' x W'.
 50                With H' = W'= 256.
 51        """
 52        raise NotImplementedError(
 53            "PromptGeneratorBase is just a class template. "
 54            "Use a child class that implements the specific generator instead"
 55        )
 56
 57
 58class PointAndBoxPromptGenerator(PromptGeneratorBase):
 59    """Generate point and/or box prompts from an instance segmentation.
 60
 61    You can use this class to derive prompts from an instance segmentation, either for
 62    evaluation purposes or for training Segment Anything on custom data.
 63    In order to use this generator you need to precompute the bounding boxes and center
 64    coordiantes of the instance segmentation, using e.g. `util.get_centers_and_bounding_boxes`.
 65
 66    Here's an example for how to use this class:
 67    ```python
 68    # Initialize generator for 1 positive and 4 negative point prompts.
 69    prompt_generator = PointAndBoxPromptGenerator(1, 4, dilation_strength=8)
 70
 71    # Precompute the bounding boxes for the given segmentation
 72    bounding_boxes, _ = util.get_centers_and_bounding_boxes(segmentation)
 73
 74    # generate point prompts for the objects with ids 1, 2 and 3
 75    seg_ids = (1, 2, 3)
 76    object_mask = np.stack([segmentation == seg_id for seg_id in seg_ids])[:, None]
 77    this_bounding_boxes = [bounding_boxes[seg_id] for seg_id in seg_ids]
 78    point_coords, point_labels, _, _ = prompt_generator(object_mask, this_bounding_boxes)
 79    ```
 80
 81    Args:
 82        n_positive_points: The number of positive point prompts to generate per mask.
 83        n_negative_points: The number of negative point prompts to generate per mask.
 84        dilation_strength: The factor by which the mask is dilated before generating prompts.
 85        get_point_prompts: Whether to generate point prompts.
 86        get_box_prompts: Whether to generate box prompts.
 87    """
 88    def __init__(
 89        self,
 90        n_positive_points: int,
 91        n_negative_points: int,
 92        dilation_strength: int,
 93        get_point_prompts: bool = True,
 94        get_box_prompts: bool = False
 95    ) -> None:
 96        self.n_positive_points = n_positive_points
 97        self.n_negative_points = n_negative_points
 98        self.dilation_strength = dilation_strength
 99        self.get_box_prompts = get_box_prompts
100        self.get_point_prompts = get_point_prompts
101
102        if self.get_point_prompts is False and self.get_box_prompts is False:
103            raise ValueError("You need to request box prompts, point prompts or both.")
104
105    def _sample_positive_points(self, object_mask, center_coordinates, coord_list, label_list):
106        if center_coordinates is not None:
107            # getting the center coordinate as the first positive point (OPTIONAL)
108            coord_list.append(tuple(map(int, center_coordinates)))  # to get int coords instead of float
109
110            # getting the additional positive points by randomly sampling points
111            # from this mask except the center coordinate
112            n_positive_remaining = self.n_positive_points - 1
113
114        else:
115            # need to sample "self.n_positive_points" number of points
116            n_positive_remaining = self.n_positive_points
117
118        if n_positive_remaining > 0:
119            object_coordinates = torch.where(object_mask)
120            n_coordinates = len(object_coordinates[0])
121
122            # randomly sampling n_positive_remaining_points from these coordinates
123            indices = np.random.choice(
124                n_coordinates, size=n_positive_remaining,
125                # Allow replacing if we can't sample enough coordinates otherwise
126                replace=True if n_positive_remaining > n_coordinates else False,
127            )
128            coord_list.extend([
129                [object_coordinates[0][idx], object_coordinates[1][idx]] for idx in indices
130            ])
131
132        label_list.extend([1] * self.n_positive_points)
133        assert len(coord_list) == len(label_list) == self.n_positive_points
134        return coord_list, label_list
135
136    def _sample_negative_points(self, object_mask, bbox_coordinates, coord_list, label_list):
137        if self.n_negative_points == 0:
138            return coord_list, label_list
139
140        # getting the negative points
141        # for this we do the opposite and we set the mask to the bounding box - the object mask
142        # we need to dilate the object mask before doing this: we use kornia.morphology.dilation for this
143        dilated_object = object_mask[None, None]
144        for _ in range(self.dilation_strength):
145            dilated_object = morphology.dilation(dilated_object, torch.ones(3, 3), engine="convolution")
146        dilated_object = dilated_object.squeeze()
147
148        background_mask = torch.zeros(object_mask.shape, device=object_mask.device)
149        _ds = self.dilation_strength
150        background_mask[
151            max(bbox_coordinates[0] - _ds, 0): min(bbox_coordinates[2] + _ds, object_mask.shape[-2]),
152            max(bbox_coordinates[1] - _ds, 0): min(bbox_coordinates[3] + _ds, object_mask.shape[-1])
153        ] = 1
154        background_mask = torch.abs(background_mask - dilated_object)
155
156        # the valid background coordinates
157        background_coordinates = torch.where(background_mask)
158        n_coordinates = len(background_coordinates[0])
159
160        # randomly sample the negative points from these coordinates
161        indices = np.random.choice(
162            n_coordinates, replace=False,
163            size=min(self.n_negative_points, n_coordinates)  # handles the cases with insufficient bg pixels
164        )
165        coord_list.extend([
166            [background_coordinates[0][idx], background_coordinates[1][idx]] for idx in indices
167        ])
168        label_list.extend([0] * len(indices))
169
170        return coord_list, label_list
171
172    def _ensure_num_points(self, object_mask, coord_list, label_list):
173        num_points = self.n_positive_points + self.n_negative_points
174
175        # fill up to the necessary number of points if we did not sample enough of them
176        if len(coord_list) != num_points:
177            # to stay consistent, we add random points in the background of an object
178            # if there's no neg region around the object - usually happens with small rois
179            needed_points = num_points - len(coord_list)
180            more_neg_points = torch.where(object_mask == 0)
181            indices = np.random.choice(len(more_neg_points[0]), size=needed_points, replace=False)
182
183            coord_list.extend([
184                (more_neg_points[0][idx], more_neg_points[1][idx]) for idx in indices
185            ])
186            label_list.extend([0] * needed_points)
187
188        assert len(coord_list) == len(label_list) == num_points
189        return coord_list, label_list
190
191    # Can we batch this properly?
192    def _sample_points(self, segmentation, bbox_coordinates, center_coordinates):
193        all_coords, all_labels = [], []
194
195        center_coordinates = [None] * len(segmentation) if center_coordinates is None else center_coordinates
196        for object_mask, bbox_coords, center_coords in zip(segmentation, bbox_coordinates, center_coordinates):
197            coord_list, label_list = [], []
198            coord_list, label_list = self._sample_positive_points(object_mask[0], center_coords, coord_list, label_list)
199            coord_list, label_list = self._sample_negative_points(object_mask[0], bbox_coords, coord_list, label_list)
200            coord_list, label_list = self._ensure_num_points(object_mask[0], coord_list, label_list)
201
202            all_coords.append(coord_list)
203            all_labels.append(label_list)
204
205        return all_coords, all_labels
206
207    def __call__(
208        self,
209        segmentation: torch.Tensor,
210        bbox_coordinates: List[Tuple],
211        center_coordinates: Optional[List[np.ndarray]] = None,
212        **kwargs,
213    ) -> Tuple[
214        Optional[torch.Tensor],
215        Optional[torch.Tensor],
216        Optional[torch.Tensor],
217        None
218    ]:
219        """Generate the prompts for one object in the segmentation.
220
221        Args:
222            The groundtruth segmentation. Expects a float tensor of shape NUM_OBJECTS x 1 x H x W.
223            bbox_coordinates: The precomputed bounding boxes of particular object in the segmentation.
224            center_coordinates: The precomputed center coordinates of particular object in the segmentation.
225                If passed, these coordinates will be used as the first positive point prompt.
226                If not passed a random point from within the object mask will be used.
227
228        Returns:
229            Coordinates of point prompts. Returns None, if get_point_prompts is false.
230            Point prompt labels. Returns None, if get_point_prompts is false.
231            Bounding box prompts. Returns None, if get_box_prompts is false.
232        """
233        if self.get_point_prompts:
234            coord_list, label_list = self._sample_points(segmentation, bbox_coordinates, center_coordinates)
235            # change the axis convention of the point coordinates to match the expected coordinate order of SAM
236            coord_list = np.array(coord_list)[:, :, ::-1].copy()
237            coord_list = torch.from_numpy(coord_list)
238            label_list = torch.tensor(label_list)
239        else:
240            coord_list, label_list = None, None
241
242        if self.get_box_prompts:
243            # change the axis convention of the box coordinates to match the expected coordinate order of SAM
244            bbox_list = np.array(bbox_coordinates)[:, [1, 0, 3, 2]]
245            bbox_list = torch.from_numpy(bbox_list)
246        else:
247            bbox_list = None
248
249        return coord_list, label_list, bbox_list, None
250
251
252class IterativePromptGenerator(PromptGeneratorBase):
253    """Generate point prompts from an instance segmentation iteratively.
254    """
255    def _get_positive_points(self, pos_region, overlap_region, is_3d):
256        positive_locations = [torch.where(pos_reg) for pos_reg in pos_region]
257        # we may have objects without a positive region (= missing true foreground)
258        # in this case we just sample a positive point where the model was already correct
259        positive_locations = [
260            torch.where(ovlp_reg) if len(pos_loc[0]) == 0 else pos_loc
261            for pos_loc, ovlp_reg in zip(positive_locations, overlap_region)
262        ]
263        # we sample one positive location for each object in the batch
264        sampled_indices = [np.random.choice(len(pos_loc[0])) for pos_loc in positive_locations]
265        # get the corresponding coordinates (NOTE: we flip the axis order here due to the expected order of SAM)
266        if is_3d:
267            pos_coordinates = [
268                [pos_loc[-1][idx], pos_loc[-2][idx], pos_loc[-3][idx]]
269                for pos_loc, idx in zip(positive_locations, sampled_indices)
270            ]
271        else:
272            pos_coordinates = [
273                [pos_loc[-1][idx], pos_loc[-2][idx]] for pos_loc, idx in zip(positive_locations, sampled_indices)
274            ]
275
276        # make sure that we still have the correct batch size
277        assert len(pos_coordinates) == pos_region.shape[0]
278        pos_labels = [1] * len(pos_coordinates)
279
280        return pos_coordinates, pos_labels
281
282    def _get_negative_locations_in_obj_bbox(self, true_object, custom_df=3):
283        true_loc = torch.where(true_object)
284        bbox = torch.stack(
285            [torch.min(true_loc[1]), torch.min(true_loc[2]), torch.max(true_loc[1]) + 1, torch.max(true_loc[2]) + 1]
286        )
287
288        # custom dilation factor to perform dilation by expanding the pixels of bbox
289        bbox_mask = torch.zeros_like(true_object).squeeze(0)
290        bbox_mask[
291            max(bbox[0] - custom_df, 0): min(bbox[2] + custom_df, true_object.shape[-2]),
292            max(bbox[1] - custom_df, 0): min(bbox[3] + custom_df, true_object.shape[-1])
293        ] = 1
294        bbox_mask = bbox_mask[None].to(true_object.device)
295        background_mask = torch.abs(bbox_mask - true_object)
296        return torch.where(background_mask)
297
298    def _get_negative_points(self, neg_region, true_object, is_3d):
299        # we have a valid negative region (i.e. a valid region where the model could not generate prediction)
300        negative_locations = [torch.where(neg_reg) for neg_reg in neg_region]
301        # we may have objects without a negative region (= no rectifications required)
302        # in this case we sample a negative point in outer periphery of the object inside the bounding box.
303        negative_locations = [
304            self._get_negative_locations_in_obj_bbox(true_obj) if len(neg_loc[0]) == 0 else neg_loc
305            for neg_loc, true_obj in zip(negative_locations, true_object)
306        ]
307        # there is a chance that the object is small to not return a decent-sized bounding box
308        # hence we might not find points sometimes there as well. therefore, we sample points from true background.
309        negative_locations = [
310            torch.where(true_obj == 0) if len(neg_loc[0]) == 0 else neg_loc
311            for neg_loc, true_obj in zip(negative_locations, true_object)
312        ]
313        # we sample one negative location for each object in the batch
314        sampled_indices = [np.random.choice(len(neg_loc[0])) for neg_loc in negative_locations]
315        # get the corresponding coordinates (NOTE: we flip the axis order here due to the expected order of SAM)
316        if is_3d:
317            neg_coordinates = [
318                [neg_loc[-1][idx], neg_loc[-2][idx], neg_loc[-3][idx]]
319                for neg_loc, idx in zip(negative_locations, sampled_indices)
320            ]
321        else:
322            neg_coordinates = [
323                [neg_loc[-1][idx], neg_loc[-2][idx]] for neg_loc, idx in zip(negative_locations, sampled_indices)
324            ]
325
326        # make sure that we still have the correct batch size
327        assert len(neg_coordinates) == neg_region.shape[0]
328        neg_labels = [0] * len(neg_coordinates)
329
330        return neg_coordinates, neg_labels
331
332    def __call__(
333        self,
334        segmentation: torch.Tensor,
335        prediction: torch.Tensor,
336        **kwargs,
337    ) -> Tuple[torch.Tensor, torch.Tensor, None, None]:
338        """Generate the prompts for each object iteratively in the segmentation.
339
340        Args:
341            segmentation: The groundtruth segmentation.
342                Expects a float tensor of shape (NUM_OBJECTS x 1 x H x W) or (NUM_OBJECTS x 1 x Z x H x W).
343            prediction: The predicted objects. Epects a float tensor of the same shape as the segmentation.
344
345        Returns:
346            The updated point prompt coordinates.
347            The updated point prompt labels.
348        """
349        device = prediction.device
350        assert segmentation.shape == prediction.shape, \
351            "The segmentation and prediction tensors should have the same shape."
352
353        if segmentation.ndim == 5:  # masks in 3d must be tensors of shape NUM_OBJECTS x 1 x Z x H x W
354            is_3d = True
355        elif segmentation.ndim == 4:  # masks in 2d must be tensors of shape NUM_OBJECTS x 1 x H x W
356            is_3d = False
357        else:
358            raise ValueError("The segmentation and prediction tensors should have either '4' or '5' dimensions.")
359
360        true_object = segmentation.to(device)
361        expected_diff = (prediction - true_object)
362        neg_region = (expected_diff == 1).to(torch.float32)
363        pos_region = (expected_diff == -1)
364        overlap_region = torch.logical_and(prediction == 1, true_object == 1).to(torch.float32)
365
366        pos_coordinates, pos_labels = self._get_positive_points(pos_region, overlap_region, is_3d)
367        neg_coordinates, neg_labels = self._get_negative_points(neg_region, true_object, is_3d)
368        assert len(pos_coordinates) == len(pos_labels) == len(neg_coordinates) == len(neg_labels)
369
370        pos_coordinates = torch.tensor(pos_coordinates)[:, None]
371        neg_coordinates = torch.tensor(neg_coordinates)[:, None]
372        pos_labels, neg_labels = torch.tensor(pos_labels)[:, None], torch.tensor(neg_labels)[:, None]
373
374        net_coords = torch.cat([pos_coordinates, neg_coordinates], dim=1)
375        net_labels = torch.cat([pos_labels, neg_labels], dim=1)
376
377        return net_coords, net_labels, None, None
class PromptGeneratorBase:

PromptGeneratorBase is an interface to implement specific prompt generators.
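
A concrete generator subclasses this interface and returns the four-tuple described in `__call__`, using `None` for any prompt type it does not produce. Below is a minimal sketch of such a subclass; the `CenterPointPromptGenerator` name and its center-of-mass heuristic are illustrative only and not part of micro_sam:

```python
from typing import List, Optional, Tuple

import numpy as np
import torch

from micro_sam.prompt_generators import PromptGeneratorBase


class CenterPointPromptGenerator(PromptGeneratorBase):
    """Hypothetical generator returning one positive point per object: the center of mass of its mask."""

    def __call__(
        self,
        segmentation: torch.Tensor,  # float tensor of shape NUM_OBJECTS x 1 x H x W
        prediction: Optional[torch.Tensor] = None,
        bbox_coordinates: Optional[List[tuple]] = None,
        center_coordinates: Optional[List[np.ndarray]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], None, None]:
        coords = []
        for mask in segmentation[:, 0]:
            y, x = torch.where(mask > 0)
            # Note the flip to the XY order expected by SAM.
            coords.append([int(x.float().mean()), int(y.float().mean())])
        point_coords = torch.tensor(coords)[:, None]  # NUM_OBJECTS x 1 x 2
        point_labels = torch.ones(len(coords), 1, dtype=torch.int64)  # all points are positive
        # This generator produces neither box nor mask prompts.
        return point_coords, point_labels, None, None
```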

class PointAndBoxPromptGenerator(PromptGeneratorBase):

Generate point and/or box prompts from an instance segmentation.

You can use this class to derive prompts from an instance segmentation, either for evaluation purposes or for training Segment Anything on custom data. In order to use this generator you need to precompute the bounding boxes and center coordinates of the instance segmentation, using e.g. util.get_centers_and_bounding_boxes.

Here's an example of how to use this class:

```python
# Initialize the generator for 1 positive and 4 negative point prompts.
prompt_generator = PointAndBoxPromptGenerator(1, 4, dilation_strength=8)

# Precompute the bounding boxes for the given segmentation.
bounding_boxes, _ = util.get_centers_and_bounding_boxes(segmentation)

# Generate point prompts for the objects with ids 1, 2 and 3.
seg_ids = (1, 2, 3)
object_mask = np.stack([segmentation == seg_id for seg_id in seg_ids])[:, None]
object_mask = torch.from_numpy(object_mask).to(torch.float32)
this_bounding_boxes = [bounding_boxes[seg_id] for seg_id in seg_ids]
point_coords, point_labels, _, _ = prompt_generator(object_mask, this_bounding_boxes)
```
Arguments:
  • n_positive_points: The number of positive point prompts to generate per mask.
  • n_negative_points: The number of negative point prompts to generate per mask.
  • dilation_strength: The factor by which the mask is dilated before generating prompts.
  • get_point_prompts: Whether to generate point prompts.
  • get_box_prompts: Whether to generate box prompts.
PointAndBoxPromptGenerator(n_positive_points: int, n_negative_points: int, dilation_strength: int, get_point_prompts: bool = True, get_box_prompts: bool = False)
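
For box prompts the generator does not sample any points; it only reorders the precomputed bounding boxes into the MIN_X, MIN_Y, MAX_X, MAX_Y convention expected by SAM. A short sketch under the same assumptions as the example above (a 2d instance segmentation `segmentation` given as a numpy array of object ids):

```python
import numpy as np
import torch

from micro_sam import util
from micro_sam.prompt_generators import PointAndBoxPromptGenerator

# Request box prompts only; the point counts are irrelevant in this configuration.
box_generator = PointAndBoxPromptGenerator(
    n_positive_points=0, n_negative_points=0, dilation_strength=8,
    get_point_prompts=False, get_box_prompts=True,
)

bounding_boxes, _ = util.get_centers_and_bounding_boxes(segmentation)
seg_ids = (1, 2, 3)
object_mask = np.stack([segmentation == seg_id for seg_id in seg_ids])[:, None]
object_mask = torch.from_numpy(object_mask).to(torch.float32)
this_bounding_boxes = [bounding_boxes[seg_id] for seg_id in seg_ids]

# Only the third return value is set; it holds one box per object as MIN_X, MIN_Y, MAX_X, MAX_Y.
# (The object masks are not used for box-only prompts, but the signature requires them.)
_, _, box_prompts, _ = box_generator(object_mask, this_bounding_boxes)
```
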
class IterativePromptGenerator(PromptGeneratorBase):

Generate point prompts from an instance segmentation iteratively.
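
A typical use is to mimic interactive correction during training or evaluation: given the ground truth and the current model prediction, the generator samples one positive point per object in a region the model missed, and one negative point in a region it falsely predicted. A minimal self-contained sketch, with random tensors standing in for real masks and predictions:

```python
import torch

from micro_sam.prompt_generators import IterativePromptGenerator

prompt_generator = IterativePromptGenerator()

# Toy example: 2 objects with binary 64x64 masks (NUM_OBJECTS x 1 x H x W).
segmentation = (torch.rand(2, 1, 64, 64) > 0.5).to(torch.float32)
prediction = (torch.rand(2, 1, 64, 64) > 0.5).to(torch.float32)

# One positive and one negative corrective point per object;
# per object, the first point is positive and the second negative.
coords, labels, _, _ = prompt_generator(segmentation, prediction)
print(coords.shape)  # torch.Size([2, 2, 2]) -> NUM_OBJECTS x 2 points x XY
print(labels.shape)  # torch.Size([2, 2])    -> 1 for positive, 0 for negative
```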