micro_sam.prompt_based_segmentation

Functions for prompt-based segmentation with Segment Anything.

  1"""Functions for prompt-based segmentation with Segment Anything.
  2"""
  3
  4import warnings
  5from typing import Optional, Tuple
  6
  7import numpy as np
  8from skimage.filters import gaussian
  9from skimage.feature import peak_local_max
 10from skimage.segmentation import find_boundaries
 11from scipy.ndimage import distance_transform_edt
 12
 13import torch
 14
 15from nifty.tools import blocking
 16
 17from segment_anything.predictor import SamPredictor
 18from segment_anything.utils.transforms import ResizeLongestSide
 19
 20from . import util
 21
 22
 23#
 24# helper functions for translating mask inputs into other prompts
 25#
 26
 27
 28# compute the bounding box from a mask. SAM expects the following input:
 29# box (np.ndarray or None): A length 4 array given a box prompt to the model, in XYXY format.
 30def _compute_box_from_mask(mask, original_size=None, box_extension=0):
 31    coords = np.where(mask == 1)
 32    min_y, min_x = coords[0].min(), coords[1].min()
 33    max_y, max_x = coords[0].max(), coords[1].max()
 34    box = np.array([min_y, min_x, max_y + 1, max_x + 1])
 35    return _process_box(box, mask.shape, original_size=original_size, box_extension=box_extension)
 36
 37
 38# sample points from a mask. SAM expects the following point inputs:
 39def _compute_points_from_mask(mask, original_size, box_extension, use_single_point=False):
 40    box = _compute_box_from_mask(mask, box_extension=box_extension)
 41
 42    # get slice and offset in python coordinate convention
 43    bb = (slice(box[1], box[3]), slice(box[0], box[2]))
 44    offset = np.array([box[1], box[0]])
 45
 46    # crop the mask and compute distances
 47    cropped_mask = mask[bb]
 48    object_boundaries = find_boundaries(cropped_mask, mode="outer")
 49    distances = gaussian(distance_transform_edt(object_boundaries == 0))
 50    inner_distances = distances.copy()
 51    cropped_mask = cropped_mask.astype("bool")
 52    inner_distances[~cropped_mask] = 0.0
 53    if use_single_point:
 54        center = inner_distances.argmax()
 55        center = np.unravel_index(center, inner_distances.shape)
 56        point_coords = (center + offset)[None]
 57        point_labels = np.ones(1, dtype="uint8")
 58        return point_coords[:, ::-1], point_labels
 59
 60    outer_distances = distances.copy()
 61    outer_distances[cropped_mask] = 0.0
 62
 63    # sample positives and negatives from the distance maxima
 64    inner_maxima = peak_local_max(inner_distances, exclude_border=False, min_distance=3)
 65    outer_maxima = peak_local_max(outer_distances, exclude_border=False, min_distance=5)
 66
 67    # derive the positive (=inner maxima) and negative (=outer maxima) points
 68    point_coords = np.concatenate([inner_maxima, outer_maxima]).astype("float64")
 69    point_coords += offset
 70
 71    if original_size is not None:
 72        scale_factor = np.array([
 73            original_size[0] / float(mask.shape[0]), original_size[1] / float(mask.shape[1])
 74        ])[None]
 75        point_coords *= scale_factor
 76
 77    # get the point labels
 78    point_labels = np.concatenate(
 79        [np.ones(len(inner_maxima), dtype="uint8"), np.zeros(len(outer_maxima), dtype="uint8")]
 80    )
 81    return point_coords[:, ::-1], point_labels
 82
 83
 84def _compute_logits_from_mask(mask, eps=1e-3):
 85
 86    def inv_sigmoid(x):
 87        return np.log(x / (1 - x))
 88
 89    logits = np.zeros(mask.shape, dtype="float32")
 90    logits[mask == 1] = 1 - eps
 91    logits[mask == 0] = eps
 92    logits = inv_sigmoid(logits)
 93
 94    # resize to the expected mask shape of SAM (256x256)
 95    assert logits.ndim == 2
 96    expected_shape = (256, 256)
 97
 98    if logits.shape == expected_shape:  # shape matches, do nothing
 99        pass
100
101    elif logits.shape[0] == logits.shape[1]:  # shape is square
102        trafo = ResizeLongestSide(expected_shape[0])
103        logits = trafo.apply_image_torch(torch.from_numpy(logits[None, None]))
104        logits = logits.numpy().squeeze()
105
106    else:  # shape is not square
107        # resize the longest side to expected shape
108        trafo = ResizeLongestSide(expected_shape[0])
109        logits = trafo.apply_image_torch(torch.from_numpy(logits[None, None]))
110        logits = logits.numpy().squeeze()
111
112        # pad the other side
113        h, w = logits.shape
114        padh = expected_shape[0] - h
115        padw = expected_shape[1] - w
116        # IMPORTANT: need to pad with zero, otherwise SAM doesn't understand the padding
117        pad_width = ((0, padh), (0, padw))
118        logits = np.pad(logits, pad_width, mode="constant", constant_values=0)
119
120    logits = logits[None]
121    assert logits.shape == (1, 256, 256), f"{logits.shape}"
122    return logits
123
124
125#
126# other helper functions
127#
128
129
130def _process_box(box, shape, original_size=None, box_extension=0):
131    if box_extension == 0:  # no extension
132        extension_y, extension_x = 0, 0
133    elif box_extension >= 1:  # extension by a fixed factor
134        extension_y, extension_x = box_extension, box_extension
135    else:  # extension by fraction of the box len
136        len_y, len_x = box[2] - box[0], box[3] - box[1]
137        extension_y, extension_x = box_extension * len_y, box_extension * len_x
138
139    box = np.array([
140        max(box[1] - extension_x, 0), max(box[0] - extension_y, 0),
141        min(box[3] + extension_x, shape[1]), min(box[2] + extension_y, shape[0]),
142    ])
143
144    if original_size is not None:
145        trafo = ResizeLongestSide(max(original_size))
146        box = trafo.apply_boxes(box[None], (256, 256)).squeeze()
147
148    # round up the bounding box values
149    box = np.round(box).astype(int)
150
151    return box
152
153
# Pick the tile that contains the mean position of the points and express the points
# in the coordinate system of this tile (including its halo).
# Points that fall outside of the tile are discarded, with a warning.
def _points_to_tile(prompts, shape, tile_shape, halo):
    """Map point prompts (YX order) to the tile that contains their mean position.

    Args:
        prompts: Tuple of the point coordinates and the point labels.
        shape: The shape of the full image.
        tile_shape: The shape of the tiles.
        halo: The halo around each tile.

    Returns:
        The tile id, the tile (outer block including the halo) and the tuple of
        points and labels in tile coordinates.
    """
    points, labels = prompts

    tiling = blocking([0, 0], shape, tile_shape)
    mean_position = np.mean(points, axis=0).round().astype("int").tolist()
    tile_id = tiling.coordinatesToBlockId(mean_position)
    tile = tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock

    local_points = points - np.array(tile.begin)
    tile_extent = tile.shape
    inside = (
        (local_points >= 0).all(axis=1)
        & (local_points[:, 0] < tile_extent[0])
        & (local_points[:, 1] < tile_extent[1])
    )

    if inside.all():
        return tile_id, tile, (local_points, labels)

    local_points, local_labels = local_points[inside], labels[inside]
    warnings.warn(
        f"{(~inside).sum()} points were not in the tile and are dropped"
    )
    return tile_id, tile, (local_points, local_labels)
186
187
def _box_to_tile(box, shape, tile_shape, halo):
    """Map a box prompt (min_y, min_x, max_y, max_x) to the tile containing its center.

    The box is translated to the coordinate system of the tile (including its halo)
    and clipped to the tile.

    Returns:
        The tile id, the tile (outer block including the halo) and the box in tile coordinates.
    """
    tiling = blocking([0, 0], shape, tile_shape)
    center_y = (box[0] + box[2]) / 2
    center_x = (box[1] + box[3]) / 2
    center = np.array([center_y, center_x]).round().astype("int").tolist()
    tile_id = tiling.coordinatesToBlockId(center)
    tile = tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock

    begin_y, begin_x = tile.begin[0], tile.begin[1]
    size_y, size_x = tile.shape[0], tile.shape[1]
    box_in_tile = np.array([
        max(box[0] - begin_y, 0),
        max(box[1] - begin_x, 0),
        min(box[2] - begin_y, size_y),
        min(box[3] - begin_x, size_x),
    ])
    return tile_id, tile, box_in_tile
205
206
def _mask_to_tile(mask, shape, tile_shape, halo):
    """Crop a mask to the tile that contains the mean position of its foreground.

    Returns:
        The tile id, the tile (outer block including the halo) and the cropped mask.
    """
    tiling = blocking([0, 0], shape, tile_shape)

    foreground = np.nonzero(mask)
    center = np.array([foreground[0].mean(), foreground[1].mean()]).round().astype("int").tolist()
    tile_id = tiling.coordinatesToBlockId(center)
    tile = tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock

    tile_roi = tuple(slice(start, stop) for start, stop in zip(tile.begin, tile.end))
    return tile_id, tile, mask[tile_roi]
219
220
221def _initialize_predictor(predictor, image_embeddings, i, prompts, to_tile):
222    tile = None
223
224    # Set the precomputed state for tiled prediction.
225    if image_embeddings is not None and image_embeddings["input_size"] is None:
226        features = image_embeddings["features"]
227        shape, tile_shape, halo = features.attrs["shape"], features.attrs["tile_shape"], features.attrs["halo"]
228        tile_id, tile, prompts = to_tile(prompts, shape, tile_shape, halo)
229        util.set_precomputed(predictor, image_embeddings, i, tile_id=tile_id)
230
231    # Set the precomputed state for normal prediction.
232    elif image_embeddings is not None:
233        shape = image_embeddings["original_size"]
234        util.set_precomputed(predictor, image_embeddings, i)
235
236    else:
237        shape = predictor.original_size
238
239    return predictor, tile, prompts, shape
240
241
242def _tile_to_full_mask(mask, shape, tile):
243    full_mask = np.zeros(mask.shape[0:1] + tuple(shape), dtype=mask.dtype)
244    bb = tuple(slice(beg, end) for beg, end in zip(tile.begin, tile.end))
245    full_mask[(slice(None),) + bb] = mask
246    return full_mask
247
248
249#
250# functions for prompted segmentation:
251# - segment_from_points: use point prompts as input
252# - segment_from_mask: use binary mask as input, support conversion to mask, box and point prompts
253# - segment_from_box: use box prompt as input
254# - segment_from_box_and_points: use box and point prompts as input
255#
256
257
def segment_from_points(
    predictor: SamPredictor,
    points: np.ndarray,
    labels: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    use_best_multimask: Optional[bool] = None,
):
    """Segmentation from point prompts.

    Args:
        predictor: The segment anything predictor.
        points: The point prompts given in the image coordinate system (YX order).
        labels: The labels (positive or negative) associated with the points.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask.
        return_all: Whether to return the score and logits in addition to the mask.
        use_best_multimask: Whether to use multimask output and then choose the best mask.
            By default this is used for a single positive point and not otherwise.

    Returns:
        The binary segmentation mask. If `return_all` is set, also the scores and logits.
        When the best of multiple masks is selected, scores and logits are restricted to
        the selected mask as well.
    """
    predictor, tile, prompts, shape = _initialize_predictor(
        predictor, image_embeddings, i, (points, labels), _points_to_tile
    )
    points, labels = prompts

    if use_best_multimask is None:
        use_best_multimask = len(points) == 1 and labels[0] == 1
    multimask_output_ = multimask_output or use_best_multimask

    # predict the mask
    mask, scores, logits = predictor.predict(
        point_coords=points[:, ::-1],  # SAM has reversed XY conventions
        point_labels=labels,
        multimask_output=multimask_output_,
    )

    if use_best_multimask:
        # Select score and logits of the best mask too, so that the outputs stay consistent.
        best_mask_id = np.argmax(scores)
        mask = mask[best_mask_id][None]
        scores = scores[best_mask_id][None]
        logits = logits[best_mask_id][None]

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    else:
        return mask
313
314
def segment_from_mask(
    predictor: SamPredictor,
    mask: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    use_box: bool = True,
    use_mask: bool = True,
    use_points: bool = False,
    original_size: Optional[Tuple[int, ...]] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    return_logits: bool = False,
    box_extension: float = 0.0,
    box: Optional[np.ndarray] = None,
    points: Optional[np.ndarray] = None,
    labels: Optional[np.ndarray] = None,
    use_single_point: bool = False,
):
    """Segmentation from a mask prompt.

    Args:
        predictor: The segment anything predictor.
        mask: The mask used to derive prompts.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        use_box: Whether to derive the bounding box prompt from the mask.
        use_mask: Whether to use the mask itself as prompt.
        use_points: Whether to derive point prompts from the mask.
        original_size: Full image shape. Use this if the mask that is being passed
            is downsampled compared to the original image.
        multimask_output: Whether to return multiple or just a single mask.
        return_all: Whether to return the score and logits in addition to the mask.
        return_logits: Forwarded to `predictor.predict`; whether to return un-thresholded
            mask logits instead of the binary mask.
        box_extension: Enlargement of the bounding box prompt: values >= 1 in pixels,
            other non-zero values relative to the box size.
        box: Precomputed bounding box, given as (min_y, min_x, max_y, max_x).
        points: Precomputed point prompts. NOTE(review): without tiling they are passed to
            SAM unchanged (XY order, like the points derived from the mask), whereas tiled
            prediction maps them to the tile assuming YX order; confirm the convention.
        labels: Positive/negative labels corresponding to the point prompts.
            Required if `points` is passed.
        use_single_point: Whether to derive just a single point from the mask.
            Only used if `use_points` is true.

    Returns:
        The binary segmentation mask.

    Raises:
        ValueError: If `points` is passed without `labels`.
    """
    # Validate before tiling, which would otherwise fail with a less helpful error.
    if points is not None and labels is None:
        raise ValueError("If points are passed you also need to pass labels.")

    prompts = (mask, box, points, labels)

    def _to_tile(prompts, shape, tile_shape, halo):
        # Map all prompts to the tile selected by the mask; all prompts must agree on the tile.
        mask, box, points, labels = prompts
        tile_id, tile, mask = _mask_to_tile(mask, shape, tile_shape, halo)
        if points is not None:
            tile_id_points, tile, point_prompts = _points_to_tile((points, labels), shape, tile_shape, halo)
            if tile_id_points != tile_id:
                raise RuntimeError(f"Inconsistent tile ids for mask and point prompts: {tile_id_points} != {tile_id}.")
            points, labels = point_prompts
        if box is not None:
            tile_id_box, tile, box = _box_to_tile(box, shape, tile_shape, halo)
            if tile_id_box != tile_id:
                raise RuntimeError(f"Inconsistent tile ids for mask and box prompts: {tile_id_box} != {tile_id}.")
        return tile_id, tile, (mask, box, points, labels)

    predictor, tile, prompts, shape = _initialize_predictor(predictor, image_embeddings, i, prompts, _to_tile)
    mask, box, points, labels = prompts
    has_foreground = mask.sum() != 0

    if points is not None:
        point_coords, point_labels = points, labels
    elif use_points and has_foreground:
        point_coords, point_labels = _compute_points_from_mask(
            mask, original_size=original_size, box_extension=box_extension,
            use_single_point=use_single_point,
        )
    else:
        point_coords, point_labels = None, None

    if box is not None:
        box = _process_box(box, mask.shape, original_size=original_size, box_extension=box_extension)
    elif use_box and has_foreground:
        box = _compute_box_from_mask(mask, original_size=original_size, box_extension=box_extension)

    logits = _compute_logits_from_mask(mask) if use_mask else None

    mask, scores, logits = predictor.predict(
        point_coords=point_coords, point_labels=point_labels,
        mask_input=logits, box=box,
        multimask_output=multimask_output, return_logits=return_logits
    )

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    else:
        return mask
414
415
def segment_from_box(
    predictor: SamPredictor,
    box: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    box_extension: float = 0.0,
):
    """Segmentation from a box prompt.

    Args:
        predictor: The segment anything predictor.
        box: The box prompt, given as (min_y, min_x, max_y, max_x).
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask.
        return_all: Whether to return the score and logits in addition to the mask.
        box_extension: Enlargement of the bounding box prompt: values >= 1 in pixels,
            other non-zero values relative to the box size.

    Returns:
        The binary segmentation mask.
    """
    predictor, tile, box, shape = _initialize_predictor(
        predictor, image_embeddings, i, box, _box_to_tile
    )
    # For tiled prediction the box is in tile coordinates, so the (extended) box
    # has to be clipped to the tile rather than to the full image.
    box_shape = shape if tile is None else tile.shape
    mask, scores, logits = predictor.predict(
        box=_process_box(box, box_shape, box_extension=box_extension), multimask_output=multimask_output
    )

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    else:
        return mask
455
456
def segment_from_box_and_points(
    predictor: SamPredictor,
    box: np.ndarray,
    points: np.ndarray,
    labels: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
):
    """Segmentation from a box prompt combined with point prompts.

    Args:
        predictor: The segment anything predictor.
        box: The box prompt, given as (min_y, min_x, max_y, max_x).
        points: The point prompts, given in the image coordinate system (YX order).
        labels: The point labels, either positive or negative.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask.
        return_all: Whether to return the score and logits in addition to the mask.

    Returns:
        The binary segmentation mask.
    """
    def _prompts_to_tile(prompts, shape, tile_shape, halo):
        # Points and box are mapped independently; both must end up in the same tile.
        box, points, labels = prompts
        point_tile_id, _, (points, labels) = _points_to_tile((points, labels), shape, tile_shape, halo)
        box_tile_id, tile, box = _box_to_tile(box, shape, tile_shape, halo)
        if box_tile_id != point_tile_id:
            raise RuntimeError(
                f"Inconsistent tile ids for box and point annotations: {box_tile_id} != {point_tile_id}."
            )
        return point_tile_id, tile, (box, points, labels)

    predictor, tile, (box, points, labels), shape = _initialize_predictor(
        predictor, image_embeddings, i, (box, points, labels), _prompts_to_tile
    )

    sam_points = points[:, ::-1]  # SAM has reversed XY conventions
    sam_box = _process_box(box, shape)
    mask, scores, logits = predictor.predict(
        point_coords=sam_points,
        point_labels=labels,
        box=sam_box,
        multimask_output=multimask_output,
    )

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    return (mask, scores, logits) if return_all else mask
def segment_from_points( predictor: segment_anything.predictor.SamPredictor, points: numpy.ndarray, labels: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, multimask_output: bool = False, return_all: bool = False, use_best_multimask: Optional[bool] = None):
def segment_from_points(
    predictor: SamPredictor,
    points: np.ndarray,
    labels: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    use_best_multimask: Optional[bool] = None,
):
    """Segmentation from point prompts.

    Args:
        predictor: The segment anything predictor.
        points: The point prompts given in the image coordinate system (YX order).
        labels: The labels (positive or negative) associated with the points.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask.
        return_all: Whether to return the score and logits in addition to the mask.
        use_best_multimask: Whether to use multimask output and then choose the best mask.
            By default this is used for a single positive point and not otherwise.

    Returns:
        The binary segmentation mask. If `return_all` is set, also the scores and logits.
        When the best of multiple masks is selected, scores and logits are restricted to
        the selected mask as well.
    """
    predictor, tile, prompts, shape = _initialize_predictor(
        predictor, image_embeddings, i, (points, labels), _points_to_tile
    )
    points, labels = prompts

    if use_best_multimask is None:
        use_best_multimask = len(points) == 1 and labels[0] == 1
    multimask_output_ = multimask_output or use_best_multimask

    # predict the mask
    mask, scores, logits = predictor.predict(
        point_coords=points[:, ::-1],  # SAM has reversed XY conventions
        point_labels=labels,
        multimask_output=multimask_output_,
    )

    if use_best_multimask:
        # Select score and logits of the best mask too, so that the outputs stay consistent.
        best_mask_id = np.argmax(scores)
        mask = mask[best_mask_id][None]
        scores = scores[best_mask_id][None]
        logits = logits[best_mask_id][None]

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    else:
        return mask

Segmentation from point prompts.

Arguments:
  • predictor: The segment anything predictor.
  • points: The point prompts given in the image coordinate system.
  • labels: The labels (positive or negative) associated with the points.
  • image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
  • i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
  • multimask_output: Whether to return multiple or just a single mask.
  • return_all: Whether to return the score and logits in addition to the mask.
  • use_best_multimask: Whether to use multimask output and then choose the best mask. By default this is used for a single positive point and not otherwise.
Returns:

The binary segmentation mask.

def segment_from_mask( predictor: segment_anything.predictor.SamPredictor, mask: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, use_box: bool = True, use_mask: bool = True, use_points: bool = False, original_size: Optional[Tuple[int, ...]] = None, multimask_output: bool = False, return_all: bool = False, return_logits: bool = False, box_extension: float = 0.0, box: Optional[numpy.ndarray] = None, points: Optional[numpy.ndarray] = None, labels: Optional[numpy.ndarray] = None, use_single_point: bool = False):
def segment_from_mask(
    predictor: SamPredictor,
    mask: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    use_box: bool = True,
    use_mask: bool = True,
    use_points: bool = False,
    original_size: Optional[Tuple[int, ...]] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    return_logits: bool = False,
    box_extension: float = 0.0,
    box: Optional[np.ndarray] = None,
    points: Optional[np.ndarray] = None,
    labels: Optional[np.ndarray] = None,
    use_single_point: bool = False,
):
    """Segmentation from a mask prompt.

    Args:
        predictor: The segment anything predictor.
        mask: The mask used to derive prompts.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        use_box: Whether to derive the bounding box prompt from the mask.
        use_mask: Whether to use the mask itself as prompt.
        use_points: Whether to derive point prompts from the mask.
        original_size: Full image shape. Use this if the mask that is being passed
            is downsampled compared to the original image.
        multimask_output: Whether to return multiple or just a single mask.
        return_all: Whether to return the score and logits in addition to the mask.
        return_logits: Forwarded to `predictor.predict`; whether to return un-thresholded
            mask logits instead of the binary mask.
        box_extension: Enlargement of the bounding box prompt: values >= 1 in pixels,
            other non-zero values relative to the box size.
        box: Precomputed bounding box, given as (min_y, min_x, max_y, max_x).
        points: Precomputed point prompts. NOTE(review): without tiling they are passed to
            SAM unchanged (XY order, like the points derived from the mask), whereas tiled
            prediction maps them to the tile assuming YX order; confirm the convention.
        labels: Positive/negative labels corresponding to the point prompts.
            Required if `points` is passed.
        use_single_point: Whether to derive just a single point from the mask.
            Only used if `use_points` is true.

    Returns:
        The binary segmentation mask.

    Raises:
        ValueError: If `points` is passed without `labels`.
    """
    # Validate before tiling, which would otherwise fail with a less helpful error.
    if points is not None and labels is None:
        raise ValueError("If points are passed you also need to pass labels.")

    prompts = (mask, box, points, labels)

    def _to_tile(prompts, shape, tile_shape, halo):
        # Map all prompts to the tile selected by the mask; all prompts must agree on the tile.
        mask, box, points, labels = prompts
        tile_id, tile, mask = _mask_to_tile(mask, shape, tile_shape, halo)
        if points is not None:
            tile_id_points, tile, point_prompts = _points_to_tile((points, labels), shape, tile_shape, halo)
            if tile_id_points != tile_id:
                raise RuntimeError(f"Inconsistent tile ids for mask and point prompts: {tile_id_points} != {tile_id}.")
            points, labels = point_prompts
        if box is not None:
            tile_id_box, tile, box = _box_to_tile(box, shape, tile_shape, halo)
            if tile_id_box != tile_id:
                raise RuntimeError(f"Inconsistent tile ids for mask and box prompts: {tile_id_box} != {tile_id}.")
        return tile_id, tile, (mask, box, points, labels)

    predictor, tile, prompts, shape = _initialize_predictor(predictor, image_embeddings, i, prompts, _to_tile)
    mask, box, points, labels = prompts
    has_foreground = mask.sum() != 0

    if points is not None:
        point_coords, point_labels = points, labels
    elif use_points and has_foreground:
        point_coords, point_labels = _compute_points_from_mask(
            mask, original_size=original_size, box_extension=box_extension,
            use_single_point=use_single_point,
        )
    else:
        point_coords, point_labels = None, None

    if box is not None:
        box = _process_box(box, mask.shape, original_size=original_size, box_extension=box_extension)
    elif use_box and has_foreground:
        box = _compute_box_from_mask(mask, original_size=original_size, box_extension=box_extension)

    logits = _compute_logits_from_mask(mask) if use_mask else None

    mask, scores, logits = predictor.predict(
        point_coords=point_coords, point_labels=point_labels,
        mask_input=logits, box=box,
        multimask_output=multimask_output, return_logits=return_logits
    )

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    else:
        return mask

Segmentation from a mask prompt.

Arguments:
  • predictor: The segment anything predictor.
  • mask: The mask used to derive prompts.
  • image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
  • i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
  • use_box: Whether to derive the bounding box prompt from the mask.
  • use_mask: Whether to use the mask itself as prompt.
  • use_points: Whether to derive point prompts from the mask.
  • original_size: Full image shape. Use this if the mask that is being passed is downsampled compared to the original image.
  • multimask_output: Whether to return multiple or just a single mask.
  • return_all: Whether to return the score and logits in addition to the mask.
  • box_extension: Relative factor used to enlarge the bounding box prompt.
  • box: Precomputed bounding box.
  • points: Precomputed point prompts.
  • labels: Positive/negative labels corresponding to the point prompts.
  • use_single_point: Whether to derive just a single point from the mask. Only used if use_points is true.
Returns:

The binary segmentation mask.

def segment_from_box( predictor: segment_anything.predictor.SamPredictor, box: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, multimask_output: bool = False, return_all: bool = False, box_extension: float = 0.0):
def segment_from_box(
    predictor: SamPredictor,
    box: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    box_extension: float = 0.0,
):
    """Segmentation from a box prompt.

    Args:
        predictor: The segment anything predictor.
        box: The box prompt, given as (min_y, min_x, max_y, max_x).
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask.
        return_all: Whether to return the score and logits in addition to the mask.
        box_extension: Enlargement of the bounding box prompt: values >= 1 in pixels,
            other non-zero values relative to the box size.

    Returns:
        The binary segmentation mask.
    """
    predictor, tile, box, shape = _initialize_predictor(
        predictor, image_embeddings, i, box, _box_to_tile
    )
    # For tiled prediction the box is in tile coordinates, so the (extended) box
    # has to be clipped to the tile rather than to the full image.
    box_shape = shape if tile is None else tile.shape
    mask, scores, logits = predictor.predict(
        box=_process_box(box, box_shape, box_extension=box_extension), multimask_output=multimask_output
    )

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    else:
        return mask

Segmentation from a box prompt.

Arguments:
  • predictor: The segment anything predictor.
  • box: The box prompt.
  • image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
  • i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
  • multimask_output: Whether to return multiple or just a single mask.
  • return_all: Whether to return the score and logits in addition to the mask.
  • box_extension: Relative factor used to enlarge the bounding box prompt.
Returns:

The binary segmentation mask.

def segment_from_box_and_points( predictor: segment_anything.predictor.SamPredictor, box: numpy.ndarray, points: numpy.ndarray, labels: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, multimask_output: bool = False, return_all: bool = False):
def segment_from_box_and_points(
    predictor: SamPredictor,
    box: np.ndarray,
    points: np.ndarray,
    labels: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    box_extension: float = 0.0,
):
    """Segmentation from a box prompt and point prompts.

    Args:
        predictor: The segment anything predictor.
        box: The box prompt.
        points: The point prompts, given in the image coordinates system.
        labels: The point labels, either positive or negative.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask.
        return_all: Whether to return the score and logits in addition to the mask.
        box_extension: Relative factor used to enlarge the bounding box prompt.

    Returns:
        The binary segmentation mask. If `return_all` is True, the tuple
        (mask, scores, logits) produced by the predictor instead.

    Raises:
        RuntimeError: From the tiling helper, if the box prompt and the point prompts
            are mapped to different tiles.
    """
    def box_and_points_to_tile(prompts, shape, tile_shape, halo):
        # Map box and points into tile coordinates; both must fall into the same tile,
        # otherwise a single prediction cannot cover them.
        box, points, labels = prompts
        tile_id, _, (points, labels) = _points_to_tile((points, labels), shape, tile_shape, halo)
        tile_id_box, tile, box = _box_to_tile(box, shape, tile_shape, halo)
        if tile_id_box != tile_id:
            raise RuntimeError(f"Inconsistent tile ids for box and point annotations: {tile_id_box} != {tile_id}.")
        return tile_id, tile, (box, points, labels)

    predictor, tile, prompts, shape = _initialize_predictor(
        predictor, image_embeddings, i, (box, points, labels), box_and_points_to_tile
    )
    box, points, labels = prompts

    mask, scores, logits = predictor.predict(
        point_coords=points[:, ::-1],  # SAM has reversed XY conventions
        point_labels=labels,
        # Forward box_extension like segment_from_box does, so both entry points
        # treat the box prompt the same way.
        box=_process_box(box, shape, box_extension=box_extension),
        multimask_output=multimask_output,
    )

    # For tiled embeddings the prediction covers only one tile; embed it into the full shape.
    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    return mask

Segmentation from a box prompt and point prompts.

Arguments:
  • predictor: The segment anything predictor.
  • box: The box prompt.
  • points: The point prompts, given in the image coordinates system.
  • labels: The point labels, either positive or negative.
  • image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
  • i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
  • multimask_output: Whether to return multiple or just a single mask.
  • return_all: Whether to return the score and logits in addition to the mask.
Returns:

The binary segmentation mask.