micro_sam.prompt_generators

Classes for generating prompts from ground-truth segmentation masks. For training or evaluation of prompt-based segmentation.

  1"""
  2Classes for generating prompts from ground-truth segmentation masks.
  3For training or evaluation of prompt-based segmentation.
  4"""
  5
  6from typing import List, Optional, Tuple
  7
  8import numpy as np
  9from kornia import morphology
 10
 11import torch
 12
 13
class PromptGeneratorBase:
    """PromptGeneratorBase is an interface to implement specific prompt generators.
    """
    def __call__(
            self,
            segmentation: torch.Tensor,
            prediction: Optional[torch.Tensor] = None,
            bbox_coordinates: Optional[List[tuple]] = None,
            center_coordinates: Optional[List[np.ndarray]] = None
    ) -> Tuple[
        Optional[torch.Tensor],  # the point coordinates
        Optional[torch.Tensor],  # the point labels
        Optional[torch.Tensor],  # the bounding boxes
        Optional[torch.Tensor],  # the mask prompts
    ]:
        """Return the point prompts given segmentation masks and optional other inputs.

        Args:
            segmentation: The object masks derived from instance segmentation ground-truth.
                Expects a float tensor of shape NUM_OBJECTS x 1 x H x W.
                The first axis corresponds to the binary object masks.
            prediction: The predicted object masks corresponding to the segmentation.
                Expects the same shape as the segmentation.
            bbox_coordinates: Precomputed bounding boxes for the segmentation.
                Expects a list of length NUM_OBJECTS.
            center_coordinates: Precomputed center coordinates for the segmentation.
                Expects a list of length NUM_OBJECTS.

        Returns:
            The point prompt coordinates. Int tensor of shape NUM_OBJECTS x NUM_POINTS x 2.
                The point coordinates are returned in XY axis order. This means they are reversed compared
                to the standard YX axis order used by numpy.
            The point prompt labels. Int tensor of shape NUM_OBJECTS x NUM_POINTS.
            The box prompts. Int tensor of shape NUM_OBJECTS x 4.
                The box coordinates are returned as MIN_X, MIN_Y, MAX_X, MAX_Y.
            The mask prompts. Float tensor of shape NUM_OBJECTS x 1 x H' x W'.
                With H' = W' = 256.
        """
        raise NotImplementedError(
            "PromptGeneratorBase is just a class template. "
            "Use a child class that implements the specific generator instead."
        )

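Subclasses implement `__call__` with this signature and return the four-tuple, using `None` for the prompt types they do not generate. As a minimal sketch (the `BoxFromMaskGenerator` name and its box computation are illustrative, not part of micro_sam), a generator that derives only box prompts directly from the masks could look like this:

```python
class BoxFromMaskGenerator(PromptGeneratorBase):
    """Illustrative only: derive box prompts directly from the object masks."""
    def __call__(self, segmentation, prediction=None, bbox_coordinates=None, center_coordinates=None):
        boxes = []
        for mask in segmentation[:, 0]:  # iterate over the NUM_OBJECTS axis
            # Assumes every mask contains at least one foreground pixel.
            ys, xs = torch.where(mask > 0)
            # Box in MIN_X, MIN_Y, MAX_X, MAX_Y order, following the convention documented above.
            boxes.append(torch.stack([xs.min(), ys.min(), xs.max() + 1, ys.max() + 1]))
        # No point or mask prompts are generated, so those entries are None.
        return None, None, torch.stack(boxes).to(torch.int64), None
```
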
class PointAndBoxPromptGenerator(PromptGeneratorBase):
    """Generate point and/or box prompts from an instance segmentation.

    You can use this class to derive prompts from an instance segmentation, either for
    evaluation purposes or for training Segment Anything on custom data.
    In order to use this generator you need to precompute the bounding boxes and center
    coordinates of the instance segmentation, using e.g. `util.get_centers_and_bounding_boxes`.

    Here's an example of how to use this class:
    ```python
    # Initialize the generator for 1 positive and 4 negative point prompts.
    prompt_generator = PointAndBoxPromptGenerator(1, 4, dilation_strength=8)

    # Precompute the bounding boxes for the given segmentation.
    # Note: util.get_centers_and_bounding_boxes returns the centers first and the boxes second.
    _, bounding_boxes = util.get_centers_and_bounding_boxes(segmentation)

    # Generate point prompts for the objects with ids 1, 2 and 3.
    # The generator expects a float tensor of shape NUM_OBJECTS x 1 x H x W.
    seg_ids = (1, 2, 3)
    object_mask = torch.from_numpy(
        np.stack([segmentation == seg_id for seg_id in seg_ids])[:, None]
    ).to(torch.float32)
    this_bounding_boxes = [bounding_boxes[seg_id] for seg_id in seg_ids]
    point_coords, point_labels, _, _ = prompt_generator(object_mask, this_bounding_boxes)
    ```

    Args:
        n_positive_points: The number of positive point prompts to generate per mask.
        n_negative_points: The number of negative point prompts to generate per mask.
        dilation_strength: The factor by which the mask is dilated before generating prompts.
        get_point_prompts: Whether to generate point prompts.
        get_box_prompts: Whether to generate box prompts.
    """
    def __init__(
        self,
        n_positive_points: int,
        n_negative_points: int,
        dilation_strength: int,
        get_point_prompts: bool = True,
        get_box_prompts: bool = False
    ) -> None:
        self.n_positive_points = n_positive_points
        self.n_negative_points = n_negative_points
        self.dilation_strength = dilation_strength
        self.get_box_prompts = get_box_prompts
        self.get_point_prompts = get_point_prompts

        if not self.get_point_prompts and not self.get_box_prompts:
            raise ValueError("You need to request box prompts, point prompts or both.")

    def _sample_positive_points(self, object_mask, center_coordinates, coord_list, label_list):
        if center_coordinates is not None:
            # Use the center coordinate as the first positive point (optional).
            coord_list.append(tuple(map(int, center_coordinates)))  # to get int coords instead of float

            # Sample the additional positive points randomly from the mask,
            # excluding the center coordinate.
            n_positive_remaining = self.n_positive_points - 1

        else:
            # We need to sample all "self.n_positive_points" points.
            n_positive_remaining = self.n_positive_points

        if n_positive_remaining > 0:
            object_coordinates = torch.where(object_mask)
            n_coordinates = len(object_coordinates[0])

            # Randomly sample "n_positive_remaining" points from these coordinates.
            indices = np.random.choice(
                n_coordinates, size=n_positive_remaining,
                # Allow replacement if we cannot sample enough coordinates otherwise.
                replace=n_positive_remaining > n_coordinates,
            )
            coord_list.extend([
                [object_coordinates[0][idx], object_coordinates[1][idx]] for idx in indices
            ])

        label_list.extend([1] * self.n_positive_points)
        assert len(coord_list) == len(label_list) == self.n_positive_points
        return coord_list, label_list

    def _sample_negative_points(self, object_mask, bbox_coordinates, coord_list, label_list):
        if self.n_negative_points == 0:
            return coord_list, label_list

        # For the negative points we invert the setting: the sampling region is
        # the (expanded) bounding box minus the dilated object mask.
        # We dilate the object mask first, using kornia.morphology.dilation, so that
        # negative points are not placed directly on the object boundary.
        dilated_object = object_mask[None, None]
        for _ in range(self.dilation_strength):
            dilated_object = morphology.dilation(
                dilated_object, torch.ones(3, 3, device=object_mask.device), engine="convolution"
            )
        dilated_object = dilated_object.squeeze()

        background_mask = torch.zeros(object_mask.shape, device=object_mask.device)
        _ds = self.dilation_strength
        background_mask[max(bbox_coordinates[0] - _ds, 0): min(bbox_coordinates[2] + _ds, object_mask.shape[-2]),
                        max(bbox_coordinates[1] - _ds, 0): min(bbox_coordinates[3] + _ds, object_mask.shape[-1])] = 1
        background_mask = torch.abs(background_mask - dilated_object)

        # The valid background coordinates.
        background_coordinates = torch.where(background_mask)
        n_coordinates = len(background_coordinates[0])

        # Randomly sample the negative points from these coordinates.
        indices = np.random.choice(
            n_coordinates, replace=False,
            size=min(self.n_negative_points, n_coordinates)  # handles the cases with insufficient bg pixels
        )
        coord_list.extend([
            [background_coordinates[0][idx], background_coordinates[1][idx]] for idx in indices
        ])
        label_list.extend([0] * len(indices))

        return coord_list, label_list

    def _ensure_num_points(self, object_mask, coord_list, label_list):
        num_points = self.n_positive_points + self.n_negative_points

        # Fill up to the necessary number of points if we did not sample enough of them.
        if len(coord_list) != num_points:
            # To stay consistent, we add random points in the background of an object
            # if there is no negative region around the object. This usually happens for small rois.
            needed_points = num_points - len(coord_list)
            more_neg_points = torch.where(object_mask == 0)
            indices = np.random.choice(len(more_neg_points[0]), size=needed_points, replace=False)

            coord_list.extend([
                (more_neg_points[0][idx], more_neg_points[1][idx]) for idx in indices
            ])
            label_list.extend([0] * needed_points)

        assert len(coord_list) == len(label_list) == num_points
        return coord_list, label_list

    # Can we batch this properly?
    def _sample_points(self, segmentation, bbox_coordinates, center_coordinates):
        all_coords, all_labels = [], []

        center_coordinates = [None] * len(segmentation) if center_coordinates is None else center_coordinates
        for object_mask, bbox_coords, center_coords in zip(segmentation, bbox_coordinates, center_coordinates):
            coord_list, label_list = [], []
            coord_list, label_list = self._sample_positive_points(object_mask[0], center_coords, coord_list, label_list)
            coord_list, label_list = self._sample_negative_points(object_mask[0], bbox_coords, coord_list, label_list)
            coord_list, label_list = self._ensure_num_points(object_mask[0], coord_list, label_list)

            all_coords.append(coord_list)
            all_labels.append(label_list)

        return all_coords, all_labels

    def __call__(
        self,
        segmentation: torch.Tensor,
        bbox_coordinates: List[Tuple],
        center_coordinates: Optional[List[np.ndarray]] = None,
        **kwargs,
    ) -> Tuple[
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        None
    ]:
        """Generate the prompts for the objects in the segmentation.

        Args:
            segmentation: The ground-truth segmentation. Expects a float tensor of shape NUM_OBJECTS x 1 x H x W.
            bbox_coordinates: The precomputed bounding boxes of the objects in the segmentation.
            center_coordinates: The precomputed center coordinates of the objects in the segmentation.
                If passed, these coordinates will be used as the first positive point prompt.
                If not passed, a random point from within the object mask will be used.

        Returns:
            Coordinates of point prompts. Returns None, if get_point_prompts is false.
            Point prompt labels. Returns None, if get_point_prompts is false.
            Bounding box prompts. Returns None, if get_box_prompts is false.
            The mask prompts. Always None, since this generator does not produce mask prompts.
        """
        if self.get_point_prompts:
            coord_list, label_list = self._sample_points(segmentation, bbox_coordinates, center_coordinates)
            # Change the axis convention of the point coordinates to match the expected coordinate order of SAM.
            coord_list = np.array(coord_list)[:, :, ::-1].copy()
            coord_list = torch.from_numpy(coord_list)
            label_list = torch.tensor(label_list)
        else:
            coord_list, label_list = None, None

        if self.get_box_prompts:
            # Change the axis convention of the box coordinates to match the expected coordinate order of SAM.
            bbox_list = np.array(bbox_coordinates)[:, [1, 0, 3, 2]]
            bbox_list = torch.from_numpy(bbox_list)
        else:
            bbox_list = None

        return coord_list, label_list, bbox_list, None

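If only box prompts are needed, point sampling can be switched off entirely. A short sketch, reusing the `object_mask` and `this_bounding_boxes` variables from the docstring example above:

```python
# Request box prompts only; point sampling is skipped entirely.
box_generator = PointAndBoxPromptGenerator(
    n_positive_points=0, n_negative_points=0, dilation_strength=0,
    get_point_prompts=False, get_box_prompts=True,
)
_, _, boxes, _ = box_generator(object_mask, this_bounding_boxes)
# boxes is a tensor of shape NUM_OBJECTS x 4 in MIN_X, MIN_Y, MAX_X, MAX_Y order,
# i.e. the box convention expected by Segment Anything.
```
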
class IterativePromptGenerator(PromptGeneratorBase):
    """Generate point prompts from an instance segmentation iteratively.
    """
    def _get_positive_points(self, pos_region, overlap_region, is_3d):
        positive_locations = [torch.where(pos_reg) for pos_reg in pos_region]
        # We may have objects without a positive region (= no missing true foreground).
        # In this case we just sample a positive point where the model was already correct.
        positive_locations = [
            torch.where(ovlp_reg) if len(pos_loc[0]) == 0 else pos_loc
            for pos_loc, ovlp_reg in zip(positive_locations, overlap_region)
        ]
        # We sample one positive location for each object in the batch.
        sampled_indices = [np.random.choice(len(pos_loc[0])) for pos_loc in positive_locations]
        # Get the corresponding coordinates. NOTE: we flip the axis order here due to the expected order of SAM.
        if is_3d:
            pos_coordinates = [
                [pos_loc[-1][idx], pos_loc[-2][idx], pos_loc[-3][idx]]
                for pos_loc, idx in zip(positive_locations, sampled_indices)
            ]
        else:
            pos_coordinates = [
                [pos_loc[-1][idx], pos_loc[-2][idx]] for pos_loc, idx in zip(positive_locations, sampled_indices)
            ]

        # Make sure that we still have the correct batch size.
        assert len(pos_coordinates) == pos_region.shape[0]
        pos_labels = [1] * len(pos_coordinates)

        return pos_coordinates, pos_labels

    def _get_negative_locations_in_obj_bbox(self, true_object, custom_df=3):
        true_loc = torch.where(true_object)
        bbox = torch.stack(
            [torch.min(true_loc[1]), torch.min(true_loc[2]), torch.max(true_loc[1]) + 1, torch.max(true_loc[2]) + 1]
        )

        # Expand the bounding box by the custom dilation factor "custom_df".
        bbox_mask = torch.zeros_like(true_object).squeeze(0)
        bbox_mask[
            max(bbox[0] - custom_df, 0): min(bbox[2] + custom_df, true_object.shape[-2]),
            max(bbox[1] - custom_df, 0): min(bbox[3] + custom_df, true_object.shape[-1])
        ] = 1
        bbox_mask = bbox_mask[None].to(true_object.device)
        background_mask = torch.abs(bbox_mask - true_object)
        return torch.where(background_mask)

    def _get_negative_points(self, neg_region, true_object, is_3d):
        # The negative region corresponds to the false positives,
        # i.e. the region where the model predicted foreground that is not in the ground-truth.
        negative_locations = [torch.where(neg_reg) for neg_reg in neg_region]
        # We may have objects without a negative region (= no corrections required).
        # In this case we sample a negative point from the outer periphery of the object, inside the bounding box.
        negative_locations = [
            self._get_negative_locations_in_obj_bbox(true_obj) if len(neg_loc[0]) == 0 else neg_loc
            for neg_loc, true_obj in zip(negative_locations, true_object)
        ]
        # The object may be too small to yield a decent-sized bounding box, so we may still
        # not find any points there. In that case we fall back to sampling from the true background.
        negative_locations = [
            torch.where(true_obj == 0) if len(neg_loc[0]) == 0 else neg_loc
            for neg_loc, true_obj in zip(negative_locations, true_object)
        ]
        # We sample one negative location for each object in the batch.
        sampled_indices = [np.random.choice(len(neg_loc[0])) for neg_loc in negative_locations]
        # Get the corresponding coordinates. NOTE: we flip the axis order here due to the expected order of SAM.
        if is_3d:
            neg_coordinates = [
                [neg_loc[-1][idx], neg_loc[-2][idx], neg_loc[-3][idx]]
                for neg_loc, idx in zip(negative_locations, sampled_indices)
            ]
        else:
            neg_coordinates = [
                [neg_loc[-1][idx], neg_loc[-2][idx]] for neg_loc, idx in zip(negative_locations, sampled_indices)
            ]

        # Make sure that we still have the correct batch size.
        assert len(neg_coordinates) == neg_region.shape[0]
        neg_labels = [0] * len(neg_coordinates)

        return neg_coordinates, neg_labels

    def __call__(
        self,
        segmentation: torch.Tensor,
        prediction: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor, None, None]:
        """Generate the prompts for each object in the segmentation iteratively,
        based on where the current prediction disagrees with the ground-truth.

        Args:
            segmentation: The ground-truth segmentation.
                Expects a float tensor of shape (NUM_OBJECTS x 1 x H x W) or (NUM_OBJECTS x 1 x Z x H x W).
            prediction: The predicted objects. Expects a float tensor of the same shape as the segmentation.

        Returns:
            The updated point prompt coordinates.
            The updated point prompt labels.
        """
        device = prediction.device
        assert segmentation.shape == prediction.shape, \
            "The segmentation and prediction tensors should have the same shape."

        if segmentation.ndim == 5:  # masks in 3d must be tensors of shape NUM_OBJECTS x 1 x Z x H x W
            is_3d = True
        elif segmentation.ndim == 4:  # masks in 2d must be tensors of shape NUM_OBJECTS x 1 x H x W
            is_3d = False
        else:
            raise ValueError("The segmentation and prediction tensors should have either 4 or 5 dimensions.")

        true_object = segmentation.to(device)
        expected_diff = (prediction - true_object)
        neg_region = (expected_diff == 1).to(torch.float32)  # false positives
        pos_region = (expected_diff == -1)  # false negatives
        overlap_region = torch.logical_and(prediction == 1, true_object == 1).to(torch.float32)

        pos_coordinates, pos_labels = self._get_positive_points(pos_region, overlap_region, is_3d)
        neg_coordinates, neg_labels = self._get_negative_points(neg_region, true_object, is_3d)
        assert len(pos_coordinates) == len(pos_labels) == len(neg_coordinates) == len(neg_labels)

        pos_coordinates = torch.tensor(pos_coordinates)[:, None]
        neg_coordinates = torch.tensor(neg_coordinates)[:, None]
        pos_labels, neg_labels = torch.tensor(pos_labels)[:, None], torch.tensor(neg_labels)[:, None]

        net_coords = torch.cat([pos_coordinates, neg_coordinates], dim=1)
        net_labels = torch.cat([pos_labels, neg_labels], dim=1)

        return net_coords, net_labels, None, None
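
A short, self-contained sketch of how this generator can be driven; the toy `segmentation` and `prediction` tensors below are made up for illustration, and in practice the prediction comes from the model's previous prompting iteration:

```python
import torch

# Toy ground-truth masks for two objects; float tensor of shape NUM_OBJECTS x 1 x H x W.
segmentation = torch.zeros(2, 1, 64, 64)
segmentation[0, 0, 10:20, 10:20] = 1
segmentation[1, 0, 30:45, 30:45] = 1

# A deliberately imperfect prediction, e.g. from the previous prompting iteration.
prediction = torch.zeros_like(segmentation)
prediction[0, 0, 12:22, 12:22] = 1  # shifted: has false positives and false negatives
prediction[1, 0, 30:45, 30:45] = 1  # perfect overlap: falls back to overlap / bbox periphery sampling

prompt_generator = IterativePromptGenerator()
coords, labels, _, _ = prompt_generator(segmentation, prediction)
# coords: NUM_OBJECTS x 2 x 2 with one positive and one negative point per object, in XY order.
# labels: NUM_OBJECTS x 2 with label 1 for the positive and 0 for the negative point.
```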