micro_sam.prompt_based_segmentation

Functions for prompt-based segmentation with Segment Anything.

  1"""
  2Functions for prompt-based segmentation with Segment Anything.
  3"""
  4
  5import warnings
  6from typing import Optional, Tuple
  7
  8import numpy as np
  9from skimage.filters import gaussian
 10from skimage.feature import peak_local_max
 11from skimage.segmentation import find_boundaries
 12from scipy.ndimage import distance_transform_edt
 13
 14import torch
 15
 16from nifty.tools import blocking
 17
 18from segment_anything.predictor import SamPredictor
 19from segment_anything.utils.transforms import ResizeLongestSide
 20
 21from . import util
 22
 23
 24#
 25# helper functions for translating mask inputs into other prompts
 26#
 27
 28
 29# compute the bounding box from a mask. SAM expects the following input:
 30# box (np.ndarray or None): A length 4 array given a box prompt to the model, in XYXY format.
 31def _compute_box_from_mask(mask, original_size=None, box_extension=0):
 32    coords = np.where(mask == 1)
 33    min_y, min_x = coords[0].min(), coords[1].min()
 34    max_y, max_x = coords[0].max(), coords[1].max()
 35    box = np.array([min_y, min_x, max_y + 1, max_x + 1])
 36    return _process_box(box, mask.shape, original_size=original_size, box_extension=box_extension)
 37
 38
 39# sample points from a mask. SAM expects the following point inputs:
 40def _compute_points_from_mask(mask, original_size, box_extension, use_single_point=False):
 41    box = _compute_box_from_mask(mask, box_extension=box_extension)
 42
 43    # get slice and offset in python coordinate convention
 44    bb = (slice(box[1], box[3]), slice(box[0], box[2]))
 45    offset = np.array([box[1], box[0]])
 46
 47    # crop the mask and compute distances
 48    cropped_mask = mask[bb]
 49    object_boundaries = find_boundaries(cropped_mask, mode="outer")
 50    distances = gaussian(distance_transform_edt(object_boundaries == 0))
 51    inner_distances = distances.copy()
 52    cropped_mask = cropped_mask.astype("bool")
 53    inner_distances[~cropped_mask] = 0.0
 54    if use_single_point:
 55        center = inner_distances.argmax()
 56        center = np.unravel_index(center, inner_distances.shape)
 57        point_coords = (center + offset)[None]
 58        point_labels = np.ones(1, dtype="uint8")
 59        return point_coords[:, ::-1], point_labels
 60
 61    outer_distances = distances.copy()
 62    outer_distances[cropped_mask] = 0.0
 63
 64    # sample positives and negatives from the distance maxima
 65    inner_maxima = peak_local_max(inner_distances, exclude_border=False, min_distance=3)
 66    outer_maxima = peak_local_max(outer_distances, exclude_border=False, min_distance=5)
 67
 68    # derive the positive (=inner maxima) and negative (=outer maxima) points
 69    point_coords = np.concatenate([inner_maxima, outer_maxima]).astype("float64")
 70    point_coords += offset
 71
 72    if original_size is not None:
 73        scale_factor = np.array([
 74            original_size[0] / float(mask.shape[0]), original_size[1] / float(mask.shape[1])
 75        ])[None]
 76        point_coords *= scale_factor
 77
 78    # get the point labels
 79    point_labels = np.concatenate(
 80        [
 81            np.ones(len(inner_maxima), dtype="uint8"),
 82            np.zeros(len(outer_maxima), dtype="uint8"),
 83        ]
 84    )
 85    return point_coords[:, ::-1], point_labels
 86
 87
 88def _compute_logits_from_mask(mask, eps=1e-3):
 89
 90    def inv_sigmoid(x):
 91        return np.log(x / (1 - x))
 92
 93    logits = np.zeros(mask.shape, dtype="float32")
 94    logits[mask == 1] = 1 - eps
 95    logits[mask == 0] = eps
 96    logits = inv_sigmoid(logits)
 97
 98    # resize to the expected mask shape of SAM (256x256)
 99    assert logits.ndim == 2
100    expected_shape = (256, 256)
101
102    if logits.shape == expected_shape:  # shape matches, do nothing
103        pass
104
105    elif logits.shape[0] == logits.shape[1]:  # shape is square
106        trafo = ResizeLongestSide(expected_shape[0])
107        logits = trafo.apply_image_torch(torch.from_numpy(logits[None, None]))
108        logits = logits.numpy().squeeze()
109
110    else:  # shape is not square
111        # resize the longest side to expected shape
112        trafo = ResizeLongestSide(expected_shape[0])
113        logits = trafo.apply_image_torch(torch.from_numpy(logits[None, None]))
114        logits = logits.numpy().squeeze()
115
116        # pad the other side
117        h, w = logits.shape
118        padh = expected_shape[0] - h
119        padw = expected_shape[1] - w
120        # IMPORTANT: need to pad with zero, otherwise SAM doesn't understand the padding
121        pad_width = ((0, padh), (0, padw))
122        logits = np.pad(logits, pad_width, mode="constant", constant_values=0)
123
124    logits = logits[None]
125    assert logits.shape == (1, 256, 256), f"{logits.shape}"
126    return logits
127
128
#
# other helper functions
#


def _process_box(box, shape, original_size=None, box_extension=0):
    """Extend, clip and reorder a box prompt for SAM.

    Args:
        box: The box as (min_y, min_x, max_y, max_x).
        shape: The shape the box is clipped to.
        original_size: If given, the box is rescaled with `ResizeLongestSide(max(original_size))`,
            treating its coordinates as lying in a 256x256 frame.
        box_extension: 0 for no extension; a value >= 1 extends every side by that many pixels;
            any other value extends by that fraction of the box side length.

    Returns:
        The box in XYXY order as an integer array.
    """
    y_start, x_start, y_stop, x_stop = box

    if box_extension == 0:
        pad_y = pad_x = 0
    elif box_extension >= 1:
        pad_y = pad_x = box_extension
    else:
        pad_y = box_extension * (y_stop - y_start)
        pad_x = box_extension * (x_stop - x_start)

    # SAM expects XYXY, so the axes are swapped while clipping to the valid range.
    xyxy = np.array([
        max(x_start - pad_x, 0),
        max(y_start - pad_y, 0),
        min(x_stop + pad_x, shape[1]),
        min(y_stop + pad_y, shape[0]),
    ])

    if original_size is not None:
        resizer = ResizeLongestSide(max(original_size))
        xyxy = resizer.apply_boxes(xyxy[None], (256, 256)).squeeze()

    # round to the nearest integer (numpy rounds halves to even)
    return np.round(xyxy).astype(int)
156
157
def _points_to_tile(prompts, shape, tile_shape, halo):
    """Select the tile for point prompts and move the points into its coordinate system.

    The tile is chosen by the rounded mean of all points. Points that fall outside the
    tile (including its halo) are dropped, and a warning is emitted.

    Args:
        prompts: Tuple of points, in (row, column) order, and their labels.
        shape: The shape of the full image.
        tile_shape: The tile shape used for the tiled embeddings.
        halo: The halo around each tile.

    Returns:
        The tile id, the tile (outer block incl. halo) and the (points, labels) in tile coordinates.
    """
    points, labels = prompts

    grid = blocking([0, 0], shape, tile_shape)
    mean_point = np.round(np.mean(points, axis=0)).astype("int").tolist()
    tile_id = grid.coordinatesToBlockId(mean_point)
    tile = grid.getBlockWithHalo(tile_id, list(halo)).outerBlock

    local_points = points - np.array(tile.begin)
    extent = tile.shape
    inside = (
        (local_points >= 0).all(axis=1)
        & (local_points[:, 0] < extent[0])
        & (local_points[:, 1] < extent[1])
    )

    if inside.all():
        return tile_id, tile, (local_points, labels)

    kept_points = local_points[inside]
    kept_labels = labels[inside]
    warnings.warn(
        f"{(~inside).sum()} points were not in the tile and are dropped"
    )
    return tile_id, tile, (kept_points, kept_labels)
190
191
def _box_to_tile(box, shape, tile_shape, halo):
    """Select the tile for a box prompt and move the box into its coordinate system.

    Args:
        box: The box as (min_y, min_x, max_y, max_x) in full-image coordinates.
        shape: The shape of the full image.
        tile_shape: The tile shape used for the tiled embeddings.
        halo: The halo around each tile.

    Returns:
        The tile id, the tile (outer block incl. halo) and the box in tile coordinates,
        clipped to the tile extent.
    """
    grid = blocking([0, 0], shape, tile_shape)
    box_center = np.array([(box[0] + box[2]) / 2, (box[1] + box[3]) / 2])
    tile_id = grid.coordinatesToBlockId(np.round(box_center).astype("int").tolist())
    tile = grid.getBlockWithHalo(tile_id, list(halo)).outerBlock

    begin, extent = tile.begin, tile.shape
    lower = [max(box[axis] - begin[axis], 0) for axis in range(2)]
    upper = [min(box[axis + 2] - begin[axis], extent[axis]) for axis in range(2)]
    return tile_id, tile, np.array(lower + upper)
209
210
def _mask_to_tile(mask, shape, tile_shape, halo):
    """Select the tile containing the mask centroid and crop the mask to that tile.

    Args:
        mask: The mask in full-image coordinates. An empty mask is not guarded against here.
        shape: The shape of the full image.
        tile_shape: The tile shape used for the tiled embeddings.
        halo: The halo around each tile.

    Returns:
        The tile id, the tile (outer block incl. halo) and the mask cropped to the tile.
    """
    fg_coords = np.where(mask)
    centroid = np.array([np.mean(fg_coords[0]), np.mean(fg_coords[1])])
    centroid = centroid.round().astype("int").tolist()

    grid = blocking([0, 0], shape, tile_shape)
    tile_id = grid.coordinatesToBlockId(centroid)
    tile = grid.getBlockWithHalo(tile_id, list(halo)).outerBlock

    crop = tuple(slice(start, stop) for start, stop in zip(tile.begin, tile.end))
    return tile_id, tile, mask[crop]
223
224
def _initialize_predictor(predictor, image_embeddings, i, prompts, to_tile):
    """Set the precomputed predictor state and map the prompts to a tile if needed.

    Args:
        predictor: The segment anything predictor.
        image_embeddings: Optional precomputed embeddings. If `image_embeddings["input_size"]`
            is None the embeddings are tiled, and the tile is chosen from the prompts.
        i: Index for the image data, forwarded to `util.set_precomputed`.
        prompts: The prompts, in whatever form `to_tile` expects.
        to_tile: Callable `(prompts, shape, tile_shape, halo) -> (tile_id, tile, prompts)`.

    Returns:
        The predictor, the tile (None if not tiled), the (possibly tiled) prompts and the full image shape.
    """
    # No embeddings passed: the predictor is expected to be initialized already.
    if image_embeddings is None:
        return predictor, None, prompts, predictor.original_size

    # Embeddings of the whole image.
    if image_embeddings["input_size"] is not None:
        full_shape = image_embeddings["original_size"]
        util.set_precomputed(predictor, image_embeddings, i)
        return predictor, None, prompts, full_shape

    # Tiled embeddings: choose the tile from the prompts and load its state.
    attrs = image_embeddings["features"].attrs
    full_shape = attrs["shape"]
    tile_id, tile, tiled_prompts = to_tile(prompts, full_shape, attrs["tile_shape"], attrs["halo"])
    util.set_precomputed(predictor, image_embeddings, i, tile_id=tile_id)
    return predictor, tile, tiled_prompts, full_shape
244
245
def _tile_to_full_mask(mask, shape, tile):
    """Paste a per-tile mask stack of shape (C, h, w) into a zero array of shape (C, *shape)."""
    n_channels = mask.shape[0]
    full_mask = np.zeros((n_channels, *shape), dtype=mask.dtype)
    region = (slice(None),) + tuple(slice(start, stop) for start, stop in zip(tile.begin, tile.end))
    full_mask[region] = mask
    return full_mask
251
252
#
# functions for prompted segmentation:
# - segment_from_points: use point prompts as input
# - segment_from_mask: use binary mask as input, support conversion to mask, box and point prompts
# - segment_from_box: use box prompt as input
# - segment_from_box_and_points: use box and point prompts as input
#


def segment_from_points(
    predictor: SamPredictor,
    points: np.ndarray,
    labels: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    use_best_multimask: Optional[bool] = None,
):
    """Segmentation from point prompts.

    Args:
        predictor: The segment anything predictor.
        points: The point prompts given in the image coordinate system, i.e. in (row, column) order.
            They are flipped to SAM's (x, y) convention internally.
        labels: The labels (positive or negative) associated with the points.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask.
        return_all: Whether to return the score and logits in addition to the mask.
        use_best_multimask: Whether to use multimask output and then choose the best mask.
            By default this is used for a single positive point and not otherwise.
            When used, the returned scores and logits are restricted to the chosen mask as well.

    Returns:
        The binary segmentation mask, or the tuple (mask, scores, logits) if `return_all` is true.
    """
    predictor, tile, prompts, shape = _initialize_predictor(
        predictor, image_embeddings, i, (points, labels), _points_to_tile
    )
    points, labels = prompts

    if use_best_multimask is None:
        use_best_multimask = len(points) == 1 and labels[0] == 1
    multimask_output_ = multimask_output or use_best_multimask

    # predict the mask
    mask, scores, logits = predictor.predict(
        point_coords=points[:, ::-1],  # SAM has reversed XY conventions
        point_labels=labels,
        multimask_output=multimask_output_,
    )

    if use_best_multimask:
        best_mask_id = np.argmax(scores)
        mask = mask[best_mask_id][None]
        # keep scores and logits aligned with the single mask that is returned
        scores = scores[best_mask_id][None]
        logits = logits[best_mask_id][None]

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    else:
        return mask
317
318
def segment_from_mask(
    predictor: SamPredictor,
    mask: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    use_box: bool = True,
    use_mask: bool = True,
    use_points: bool = False,
    original_size: Optional[Tuple[int, ...]] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    return_logits: bool = False,
    box_extension: float = 0.0,
    box: Optional[np.ndarray] = None,
    points: Optional[np.ndarray] = None,
    labels: Optional[np.ndarray] = None,
    use_single_point: bool = False,
):
    """Segmentation from a mask prompt.

    Args:
        predictor: The segment anything predictor.
        mask: The mask used to derive prompts.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        use_box: Whether to derive the bounding box prompt from the mask.
        use_mask: Whether to use the mask itself as prompt.
        use_points: Whether to derive point prompts from the mask.
        original_size: Full image shape. Use this if the mask that is being passed
            is downsampled compared to the original image.
        multimask_output: Whether to return multiple or just a single mask.
        return_all: Whether to return the score and logits in addition to the mask.
        return_logits: Whether the predictor should return logits instead of the binary mask.
        box_extension: Relative factor used to enlarge the bounding box prompt.
        box: Precomputed bounding box, in (min_row, min_col, max_row, max_col) order.
        points: Precomputed point prompts, in (row, column) order as in `segment_from_points`.
        labels: Positive/negative labels corresponding to the point prompts.
        use_single_point: Whether to derive just a single point from the mask.
            In case use_points is true.

    Returns:
        The binary segmentation mask, or the tuple (mask, scores, logits) if `return_all` is true.

    Raises:
        ValueError: If `points` are passed without `labels`.
    """
    # validate before tiling, which would otherwise fail on labels=None with a less clear error
    if points is not None and labels is None:
        raise ValueError("If points are passed you also need to pass labels.")

    prompts = (mask, box, points, labels)

    def _to_tile(prompts, shape, tile_shape, halo):
        mask, box, points, labels = prompts
        tile_id, tile, mask = _mask_to_tile(mask, shape, tile_shape, halo)
        if points is not None:
            tile_id_points, tile, point_prompts = _points_to_tile((points, labels), shape, tile_shape, halo)
            if tile_id_points != tile_id:
                raise RuntimeError(f"Inconsistent tile ids for mask and point prompts: {tile_id_points} != {tile_id}.")
            points, labels = point_prompts
        if box is not None:
            tile_id_box, tile, box = _box_to_tile(box, shape, tile_shape, halo)
            if tile_id_box != tile_id:
                raise RuntimeError(f"Inconsistent tile ids for mask and box prompts: {tile_id_box} != {tile_id}.")
        return tile_id, tile, (mask, box, points, labels)

    predictor, tile, prompts, shape = _initialize_predictor(predictor, image_embeddings, i, prompts, _to_tile)
    mask, box, points, labels = prompts

    if points is not None:
        # Precomputed points are in (row, column) order (this is also what _points_to_tile assumes
        # in tiled mode); SAM expects (x, y), so flip them like segment_from_points does.
        point_coords, point_labels = np.asarray(points)[:, ::-1], labels

    elif use_points:
        # already returned in SAM's (x, y) order
        point_coords, point_labels = _compute_points_from_mask(
            mask, original_size=original_size, box_extension=box_extension,
            use_single_point=use_single_point,
        )

    else:
        point_coords, point_labels = None, None

    if box is None:
        box = _compute_box_from_mask(
            mask, original_size=original_size, box_extension=box_extension
        ) if use_box else None
    else:
        box = _process_box(box, mask.shape, original_size=original_size, box_extension=box_extension)

    logits = _compute_logits_from_mask(mask) if use_mask else None

    mask, scores, logits = predictor.predict(
        point_coords=point_coords, point_labels=point_labels,
        mask_input=logits, box=box,
        multimask_output=multimask_output, return_logits=return_logits
    )

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    else:
        return mask
418
419
def segment_from_box(
    predictor: SamPredictor,
    box: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    box_extension: float = 0.0,
):
    """Segmentation from a box prompt.

    Args:
        predictor: The segment anything predictor.
        box: The box prompt, in (min_row, min_col, max_row, max_col) order.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask.
        return_all: Whether to return the score and logits in addition to the mask.
        box_extension: Relative factor used to enlarge the bounding box prompt.

    Returns:
        The binary segmentation mask, or the tuple (mask, scores, logits) if `return_all` is true.
    """
    predictor, tile, box, shape = _initialize_predictor(
        predictor, image_embeddings, i, box, _box_to_tile
    )

    # In tiled mode the box is in tile coordinates, so the extended box must be
    # clipped against the tile and not against the full image.
    clip_shape = shape if tile is None else tile.shape
    mask, scores, logits = predictor.predict(
        box=_process_box(box, clip_shape, box_extension=box_extension), multimask_output=multimask_output
    )

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    else:
        return mask
459
460
def segment_from_box_and_points(
    predictor: SamPredictor,
    box: np.ndarray,
    points: np.ndarray,
    labels: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
):
    """Segmentation from a box prompt and point prompts.

    Args:
        predictor: The segment anything predictor.
        box: The box prompt, in (min_row, min_col, max_row, max_col) order.
        points: The point prompts, given in the image coordinates system, i.e. in (row, column) order.
        labels: The point labels, either positive or negative.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask.
        return_all: Whether to return the score and logits in addition to the mask.

    Returns:
        The binary segmentation mask, or the tuple (mask, scores, logits) if `return_all` is true.
    """
    def _prompts_to_tile(prompts, shape, tile_shape, halo):
        box, points, labels = prompts
        points_tile_id, tile, (points, labels) = _points_to_tile((points, labels), shape, tile_shape, halo)
        box_tile_id, tile, box = _box_to_tile(box, shape, tile_shape, halo)
        # both prompt types must end up in the same tile
        if box_tile_id != points_tile_id:
            raise RuntimeError(
                f"Inconsistent tile ids for box and point annotations: {box_tile_id} != {points_tile_id}."
            )
        return points_tile_id, tile, (box, points, labels)

    predictor, tile, (box, points, labels), shape = _initialize_predictor(
        predictor, image_embeddings, i, (box, points, labels), _prompts_to_tile
    )

    xy_points = points[:, ::-1]  # SAM expects (x, y) instead of (row, column)
    sam_box = _process_box(box, shape)
    mask, scores, logits = predictor.predict(
        point_coords=xy_points,
        point_labels=labels,
        box=sam_box,
        multimask_output=multimask_output,
    )

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    return (mask, scores, logits) if return_all else mask
def segment_from_points( predictor: segment_anything.predictor.SamPredictor, points: numpy.ndarray, labels: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, multimask_output: bool = False, return_all: bool = False, use_best_multimask: Optional[bool] = None):
def segment_from_points(
    predictor: SamPredictor,
    points: np.ndarray,
    labels: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    use_best_multimask: Optional[bool] = None,
):
    """Segmentation from point prompts.

    Args:
        predictor: The segment anything predictor.
        points: The point prompts given in the image coordinate system, i.e. in (row, column) order.
            They are flipped to SAM's (x, y) convention internally.
        labels: The labels (positive or negative) associated with the points.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask.
        return_all: Whether to return the score and logits in addition to the mask.
        use_best_multimask: Whether to use multimask output and then choose the best mask.
            By default this is used for a single positive point and not otherwise.
            When used, the returned scores and logits are restricted to the chosen mask as well.

    Returns:
        The binary segmentation mask, or the tuple (mask, scores, logits) if `return_all` is true.
    """
    predictor, tile, prompts, shape = _initialize_predictor(
        predictor, image_embeddings, i, (points, labels), _points_to_tile
    )
    points, labels = prompts

    if use_best_multimask is None:
        use_best_multimask = len(points) == 1 and labels[0] == 1
    multimask_output_ = multimask_output or use_best_multimask

    # predict the mask
    mask, scores, logits = predictor.predict(
        point_coords=points[:, ::-1],  # SAM has reversed XY conventions
        point_labels=labels,
        multimask_output=multimask_output_,
    )

    if use_best_multimask:
        best_mask_id = np.argmax(scores)
        mask = mask[best_mask_id][None]
        # keep scores and logits aligned with the single mask that is returned
        scores = scores[best_mask_id][None]
        logits = logits[best_mask_id][None]

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    else:
        return mask

Segmentation from point prompts.

Arguments:
  • predictor: The segment anything predictor.
  • points: The point prompts given in the image coordinate system.
  • labels: The labels (positive or negative) associated with the points.
  • image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
  • i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
  • multimask_output: Whether to return multiple or just a single mask.
  • return_all: Whether to return the score and logits in addition to the mask.
  • use_best_multimask: Whether to use multimask output and then choose the best mask. By default this is used for a single positive point and not otherwise.
Returns:

The binary segmentation mask.

def segment_from_mask( predictor: segment_anything.predictor.SamPredictor, mask: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, use_box: bool = True, use_mask: bool = True, use_points: bool = False, original_size: Optional[Tuple[int, ...]] = None, multimask_output: bool = False, return_all: bool = False, return_logits: bool = False, box_extension: float = 0.0, box: Optional[numpy.ndarray] = None, points: Optional[numpy.ndarray] = None, labels: Optional[numpy.ndarray] = None, use_single_point: bool = False):
def segment_from_mask(
    predictor: SamPredictor,
    mask: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    use_box: bool = True,
    use_mask: bool = True,
    use_points: bool = False,
    original_size: Optional[Tuple[int, ...]] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    return_logits: bool = False,
    box_extension: float = 0.0,
    box: Optional[np.ndarray] = None,
    points: Optional[np.ndarray] = None,
    labels: Optional[np.ndarray] = None,
    use_single_point: bool = False,
):
    """Segmentation from a mask prompt.

    Args:
        predictor: The segment anything predictor.
        mask: The mask used to derive prompts.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        use_box: Whether to derive the bounding box prompt from the mask.
        use_mask: Whether to use the mask itself as prompt.
        use_points: Whether to derive point prompts from the mask.
        original_size: Full image shape. Use this if the mask that is being passed
            is downsampled compared to the original image.
        multimask_output: Whether to return multiple or just a single mask.
        return_all: Whether to return the score and logits in addition to the mask.
        return_logits: Whether the predictor should return logits instead of the binary mask.
        box_extension: Relative factor used to enlarge the bounding box prompt.
        box: Precomputed bounding box, in (min_row, min_col, max_row, max_col) order.
        points: Precomputed point prompts, in (row, column) order as in `segment_from_points`.
        labels: Positive/negative labels corresponding to the point prompts.
        use_single_point: Whether to derive just a single point from the mask.
            In case use_points is true.

    Returns:
        The binary segmentation mask, or the tuple (mask, scores, logits) if `return_all` is true.

    Raises:
        ValueError: If `points` are passed without `labels`.
    """
    # validate before tiling, which would otherwise fail on labels=None with a less clear error
    if points is not None and labels is None:
        raise ValueError("If points are passed you also need to pass labels.")

    prompts = (mask, box, points, labels)

    def _to_tile(prompts, shape, tile_shape, halo):
        mask, box, points, labels = prompts
        tile_id, tile, mask = _mask_to_tile(mask, shape, tile_shape, halo)
        if points is not None:
            tile_id_points, tile, point_prompts = _points_to_tile((points, labels), shape, tile_shape, halo)
            if tile_id_points != tile_id:
                raise RuntimeError(f"Inconsistent tile ids for mask and point prompts: {tile_id_points} != {tile_id}.")
            points, labels = point_prompts
        if box is not None:
            tile_id_box, tile, box = _box_to_tile(box, shape, tile_shape, halo)
            if tile_id_box != tile_id:
                raise RuntimeError(f"Inconsistent tile ids for mask and box prompts: {tile_id_box} != {tile_id}.")
        return tile_id, tile, (mask, box, points, labels)

    predictor, tile, prompts, shape = _initialize_predictor(predictor, image_embeddings, i, prompts, _to_tile)
    mask, box, points, labels = prompts

    if points is not None:
        # Precomputed points are in (row, column) order (this is also what _points_to_tile assumes
        # in tiled mode); SAM expects (x, y), so flip them like segment_from_points does.
        point_coords, point_labels = np.asarray(points)[:, ::-1], labels

    elif use_points:
        # already returned in SAM's (x, y) order
        point_coords, point_labels = _compute_points_from_mask(
            mask, original_size=original_size, box_extension=box_extension,
            use_single_point=use_single_point,
        )

    else:
        point_coords, point_labels = None, None

    if box is None:
        box = _compute_box_from_mask(
            mask, original_size=original_size, box_extension=box_extension
        ) if use_box else None
    else:
        box = _process_box(box, mask.shape, original_size=original_size, box_extension=box_extension)

    logits = _compute_logits_from_mask(mask) if use_mask else None

    mask, scores, logits = predictor.predict(
        point_coords=point_coords, point_labels=point_labels,
        mask_input=logits, box=box,
        multimask_output=multimask_output, return_logits=return_logits
    )

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    else:
        return mask

Segmentation from a mask prompt.

Arguments:
  • predictor: The segment anything predictor.
  • mask: The mask used to derive prompts.
  • image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
  • i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
  • use_box: Whether to derive the bounding box prompt from the mask.
  • use_mask: Whether to use the mask itself as prompt.
  • use_points: Whether to derive point prompts from the mask.
  • original_size: Full image shape. Use this if the mask that is being passed is downsampled compared to the original image.
  • multimask_output: Whether to return multiple or just a single mask.
  • return_all: Whether to return the score and logits in addition to the mask.
  • box_extension: Relative factor used to enlarge the bounding box prompt.
  • box: Precomputed bounding box.
  • points: Precomputed point prompts.
  • labels: Positive/negative labels corresponding to the point prompts.
  • use_single_point: Whether to derive just a single point from the mask. In case use_points is true.
Returns:

The binary segmentation mask.

def segment_from_box( predictor: segment_anything.predictor.SamPredictor, box: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, multimask_output: bool = False, return_all: bool = False, box_extension: float = 0.0):
def segment_from_box(
    predictor: SamPredictor,
    box: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    box_extension: float = 0.0,
):
    """Segmentation from a box prompt.

    Args:
        predictor: The segment anything predictor.
        box: The box prompt, in (min_row, min_col, max_row, max_col) order.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask.
        return_all: Whether to return the score and logits in addition to the mask.
        box_extension: Relative factor used to enlarge the bounding box prompt.

    Returns:
        The binary segmentation mask, or the tuple (mask, scores, logits) if `return_all` is true.
    """
    predictor, tile, box, shape = _initialize_predictor(
        predictor, image_embeddings, i, box, _box_to_tile
    )

    # In tiled mode the box is in tile coordinates, so the extended box must be
    # clipped against the tile and not against the full image.
    clip_shape = shape if tile is None else tile.shape
    mask, scores, logits = predictor.predict(
        box=_process_box(box, clip_shape, box_extension=box_extension), multimask_output=multimask_output
    )

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    else:
        return mask

Segmentation from a box prompt.

Arguments:
  • predictor: The segment anything predictor.
  • box: The box prompt.
  • image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
  • i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
  • multimask_output: Whether to return multiple or just a single mask.
  • return_all: Whether to return the score and logits in addition to the mask.
  • box_extension: Relative factor used to enlarge the bounding box prompt.
Returns:

The binary segmentation mask.

def segment_from_box_and_points( predictor: segment_anything.predictor.SamPredictor, box: numpy.ndarray, points: numpy.ndarray, labels: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, multimask_output: bool = False, return_all: bool = False):
def segment_from_box_and_points(
    predictor: SamPredictor,
    box: np.ndarray,
    points: np.ndarray,
    labels: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    box_extension: float = 0.0,
):
    """Segmentation from a box prompt and point prompts.

    Args:
        predictor: The segment anything predictor.
        box: The box prompt.
        points: The point prompts, given in the image coordinates system.
        labels: The point labels, either positive or negative.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask.
        return_all: Whether to return the score and logits in addition to the mask.
        box_extension: Relative factor used to enlarge the bounding box prompt.
            Defaults to 0.0, i.e. the box is used as given.

    Returns:
        The binary segmentation mask. If `return_all` is set, the tuple
        (mask, scores, logits) as returned by the predictor (with the mask
        mapped back to the full image shape when tiled embeddings are used).
    """
    def box_and_points_to_tile(prompts, shape, tile_shape, halo):
        # Map the point prompts and the box prompt onto a tile separately;
        # both must resolve to the same tile id, otherwise the prompts are inconsistent.
        box, points, labels = prompts
        tile_id, tile, point_prompts = _points_to_tile((points, labels), shape, tile_shape, halo)
        points, labels = point_prompts
        tile_id_box, tile, box = _box_to_tile(box, shape, tile_shape, halo)
        if tile_id_box != tile_id:
            raise RuntimeError(f"Inconsistent tile ids for box and point annotations: {tile_id_box} != {tile_id}.")
        return tile_id, tile, (box, points, labels)

    predictor, tile, prompts, shape = _initialize_predictor(
        predictor, image_embeddings, i, (box, points, labels), box_and_points_to_tile
    )
    box, points, labels = prompts

    mask, scores, logits = predictor.predict(
        point_coords=points[:, ::-1],  # SAM has reversed XY conventions
        point_labels=labels,
        # Forward box_extension just like segment_from_box does, so both entry points
        # can enlarge the box prompt in the same way.
        box=_process_box(box, shape, box_extension=box_extension),
        multimask_output=multimask_output
    )

    # When tiled embeddings were used, the prediction covers only one tile;
    # map it back to a mask of the full image shape.
    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    else:
        return mask

Segmentation from a box prompt and point prompts.

Arguments:
  • predictor: The segment anything predictor.
  • box: The box prompt.
  • points: The point prompts, given in the image coordinates system.
  • labels: The point labels, either positive or negative.
  • image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
  • i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
  • multimask_output: Whether to return multiple or just a single mask.
  • return_all: Whether to return the score and logits in addition to the mask.
Returns:

The binary segmentation mask.