micro_sam.prompt_based_segmentation

Functions for prompt-based segmentation with Segment Anything.
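A minimal usage sketch (illustrative, not part of the module itself): it assumes a Segment Anything model loaded via micro_sam.util.get_sam_model and embeddings precomputed with micro_sam.util.precompute_image_embeddings; the image below is a random placeholder and point prompts are given as (y, x) coordinates in the image coordinate system.

import numpy as np

from micro_sam import util
from micro_sam.prompt_based_segmentation import segment_from_points

# Placeholder image; replace with real image data.
image = np.random.randint(0, 255, size=(512, 512)).astype("uint8")

# Load a Segment Anything model and precompute the image embeddings.
predictor = util.get_sam_model(model_type="vit_b")
image_embeddings = util.precompute_image_embeddings(predictor, image)

# A single positive point prompt, given as (y, x) in the image coordinate system.
points = np.array([[256, 256]])
labels = np.array([1])

mask = segment_from_points(predictor, points, labels, image_embeddings=image_embeddings)
print(mask.shape)  # (1, 512, 512): one binary mask for the prompted object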

  1"""Functions for prompt-based segmentation with Segment Anything.
  2"""
  3
  4import warnings
  5from typing import Optional, Tuple
  6
  7import numpy as np
  8from skimage.filters import gaussian
  9from skimage.feature import peak_local_max
 10from skimage.segmentation import find_boundaries
 11from scipy.ndimage import distance_transform_edt
 12
 13import torch
 14
 15from nifty.tools import blocking
 16
 17from segment_anything.predictor import SamPredictor
 18from segment_anything.utils.transforms import ResizeLongestSide
 19
 20from . import util
 21
 22
 23#
 24# helper functions for translating mask inputs into other prompts
 25#
 26
 27
 28# compute the bounding box from a mask. SAM expects the following input:
 29# box (np.ndarray or None): A length 4 array giving a box prompt to the model, in XYXY format.
 30def _compute_box_from_mask(mask, original_size=None, box_extension=0):
 31    coords = np.where(mask == 1)
 32    min_y, min_x = coords[0].min(), coords[1].min()
 33    max_y, max_x = coords[0].max(), coords[1].max()
 34    box = np.array([min_y, min_x, max_y + 1, max_x + 1])
 35    return _process_box(box, mask.shape, original_size=original_size, box_extension=box_extension)
 36
 37
 38# sample point prompts from a mask. SAM expects point coordinates in XY format together with point labels.
 39def _compute_points_from_mask(mask, original_size, box_extension, use_single_point=False):
 40    box = _compute_box_from_mask(mask, box_extension=box_extension)
 41
 42    # get slice and offset in python coordinate convention
 43    bb = (slice(box[1], box[3]), slice(box[0], box[2]))
 44    offset = np.array([box[1], box[0]])
 45
 46    # crop the mask and compute distances
 47    cropped_mask = mask[bb]
 48    object_boundaries = find_boundaries(cropped_mask, mode="outer")
 49    distances = gaussian(distance_transform_edt(object_boundaries == 0))
 50    inner_distances = distances.copy()
 51    cropped_mask = cropped_mask.astype("bool")
 52    inner_distances[~cropped_mask] = 0.0
 53    if use_single_point:
 54        center = inner_distances.argmax()
 55        center = np.unravel_index(center, inner_distances.shape)
 56        point_coords = (center + offset)[None]
 57        point_labels = np.ones(1, dtype="uint8")
 58        return point_coords[:, ::-1], point_labels
 59
 60    outer_distances = distances.copy()
 61    outer_distances[cropped_mask] = 0.0
 62
 63    # sample positives and negatives from the distance maxima
 64    inner_maxima = peak_local_max(inner_distances, exclude_border=False, min_distance=3)
 65    outer_maxima = peak_local_max(outer_distances, exclude_border=False, min_distance=5)
 66
 67    # derive the positive (=inner maxima) and negative (=outer maxima) points
 68    point_coords = np.concatenate([inner_maxima, outer_maxima]).astype("float64")
 69    point_coords += offset
 70
 71    if original_size is not None:
 72        scale_factor = np.array([
 73            original_size[0] / float(mask.shape[0]), original_size[1] / float(mask.shape[1])
 74        ])[None]
 75        point_coords *= scale_factor
 76
 77    # get the point labels
 78    point_labels = np.concatenate(
 79        [np.ones(len(inner_maxima), dtype="uint8"), np.zeros(len(outer_maxima), dtype="uint8")]
 80    )
 81    return point_coords[:, ::-1], point_labels
 82
 83
 84def _compute_logits_from_mask(mask, eps=1e-3):
 85
 86    def inv_sigmoid(x):
 87        return np.log(x / (1 - x))
 88
 89    logits = np.zeros(mask.shape, dtype="float32")
 90    logits[mask == 1] = 1 - eps
 91    logits[mask == 0] = eps
 92    logits = inv_sigmoid(logits)
 93
 94    # resize to the expected mask shape of SAM (256x256)
 95    assert logits.ndim == 2
 96    expected_shape = (256, 256)
 97
 98    if logits.shape == expected_shape:  # shape matches, do nothing
 99        pass
100
101    elif logits.shape[0] == logits.shape[1]:  # shape is square
102        trafo = ResizeLongestSide(expected_shape[0])
103        logits = trafo.apply_image_torch(torch.from_numpy(logits[None, None]))
104        logits = logits.numpy().squeeze()
105
106    else:  # shape is not square
107        # resize the longest side to expected shape
108        trafo = ResizeLongestSide(expected_shape[0])
109        logits = trafo.apply_image_torch(torch.from_numpy(logits[None, None]))
110        logits = logits.numpy().squeeze()
111
112        # pad the other side
113        h, w = logits.shape
114        padh = expected_shape[0] - h
115        padw = expected_shape[1] - w
116        # IMPORTANT: need to pad with zero, otherwise SAM doesn't understand the padding
117        pad_width = ((0, padh), (0, padw))
118        logits = np.pad(logits, pad_width, mode="constant", constant_values=0)
119
120    logits = logits[None]
121    assert logits.shape == (1, 256, 256), f"{logits.shape}"
122    return logits
123
124
125#
126# other helper functions
127#
128
129
130def _process_box(box, shape, original_size=None, box_extension=0):
131    if box_extension == 0:  # no extension
132        extension_y, extension_x = 0, 0
133    elif box_extension >= 1:  # extension by a fixed factor
134        extension_y, extension_x = box_extension, box_extension
135    else:  # extension by fraction of the box len
136        len_y, len_x = box[2] - box[0], box[3] - box[1]
137        extension_y, extension_x = box_extension * len_y, box_extension * len_x
138
139    box = np.array([
140        max(box[1] - extension_x, 0), max(box[0] - extension_y, 0),
141        min(box[3] + extension_x, shape[1]), min(box[2] + extension_y, shape[0]),
142    ])
143
144    if original_size is not None:
145        trafo = ResizeLongestSide(max(original_size))
146        box = trafo.apply_boxes(box[None], (256, 256)).squeeze()
147
148    # round up the bounding box values
149    box = np.round(box).astype(int)
150
151    return box
152
153
154# Select the correct tile based on the average position of the points
155# and bring the points to the coordinate system of the tile.
156# Discard points that are not in the tile and warn if this happens.
157def _points_to_tile(prompts, shape, tile_shape, halo):
158    points, labels = prompts
159
160    tiling = blocking([0, 0], shape, tile_shape)
161    center = np.mean(points, axis=0).round().astype("int").tolist()
162    tile_id = tiling.coordinatesToBlockId(center)
163
164    tile = tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock
165    offset = tile.begin
166    this_tile_shape = tile.shape
167
168    points_in_tile = points - np.array(offset)
169    labels_in_tile = labels
170
171    valid_point_mask = (points_in_tile >= 0).all(axis=1)
172    valid_point_mask = np.logical_and(
173        valid_point_mask,
174        np.logical_and(
175            points_in_tile[:, 0] < this_tile_shape[0], points_in_tile[:, 1] < this_tile_shape[1]
176        )
177    )
178    if not valid_point_mask.all():
179        points_in_tile = points_in_tile[valid_point_mask]
180        labels_in_tile = labels_in_tile[valid_point_mask]
181        warnings.warn(
182            f"{(~valid_point_mask).sum()} points were not in the tile and are dropped"
183        )
184
185    return tile_id, tile, (points_in_tile, labels_in_tile)
186
187
188def _box_to_tile(box, shape, tile_shape, halo):
189    tiling = blocking([0, 0], shape, tile_shape)
190    center = np.array([(box[0] + box[2]) / 2, (box[1] + box[3]) / 2]).round().astype("int").tolist()
191    tile_id = tiling.coordinatesToBlockId(center)
192
193    tile = tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock
194    offset = tile.begin
195    this_tile_shape = tile.shape
196
197    box_in_tile = np.array(
198        [
199            max(box[0] - offset[0], 0), max(box[1] - offset[1], 0),
200            min(box[2] - offset[0], this_tile_shape[0]), min(box[3] - offset[1], this_tile_shape[1])
201        ]
202    )
203
204    return tile_id, tile, box_in_tile
205
206
207def _mask_to_tile(mask, shape, tile_shape, halo):
208    tiling = blocking([0, 0], shape, tile_shape)
209
210    coords = np.where(mask)
211    center = np.array([np.mean(coords[0]), np.mean(coords[1])]).round().astype("int").tolist()
212    tile_id = tiling.coordinatesToBlockId(center)
213
214    tile = tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock
215    bb = tuple(slice(beg, end) for beg, end in zip(tile.begin, tile.end))
216
217    mask_in_tile = mask[bb]
218    return tile_id, tile, mask_in_tile
219
220
221def _initialize_predictor(predictor, image_embeddings, i, prompts, to_tile):
222    tile = None
223
224    # Set the precomputed state for tiled prediction.
225    if image_embeddings is not None and image_embeddings["input_size"] is None:
226        features = image_embeddings["features"]
227        shape, tile_shape, halo = features.attrs["shape"], features.attrs["tile_shape"], features.attrs["halo"]
228        tile_id, tile, prompts = to_tile(prompts, shape, tile_shape, halo)
229        util.set_precomputed(predictor, image_embeddings, i, tile_id=tile_id)
230
231    # Set the precomputed state for normal prediction.
232    elif image_embeddings is not None:
233        shape = image_embeddings["original_size"]
234        util.set_precomputed(predictor, image_embeddings, i)
235
236    else:
237        shape = predictor.original_size
238
239    return predictor, tile, prompts, shape
240
241
242def _tile_to_full_mask(mask, shape, tile):
243    full_mask = np.zeros(mask.shape[0:1] + tuple(shape), dtype=mask.dtype)
244    bb = tuple(slice(beg, end) for beg, end in zip(tile.begin, tile.end))
245    full_mask[(slice(None),) + bb] = mask
246    return full_mask
247
248
249#
250# functions for prompted segmentation:
251# - segment_from_points: use point prompts as input
252# - segment_from_mask: use binary mask as input, supports conversion to mask, box and point prompts
253# - segment_from_box: use box prompt as input
254# - segment_from_box_and_points: use box and point prompts as input
255#
256
257
258def segment_from_points(
259    predictor: SamPredictor,
260    points: np.ndarray,
261    labels: np.ndarray,
262    image_embeddings: Optional[util.ImageEmbeddings] = None,
263    i: Optional[int] = None,
264    multimask_output: bool = False,
265    return_all: bool = False,
266    use_best_multimask: Optional[bool] = None,
267):
268    """Segmentation from point prompts.
269
270    Args:
271        predictor: The segment anything predictor.
272        points: The point prompts given in the image coordinate system.
273        labels: The labels (positive or negative) associated with the points.
274        image_embeddings: Optional precomputed image embeddings.
275            Has to be passed if the predictor is not yet initialized.
276        i: Index for the image data. Required if the input data has three spatial dimensions
277            or a time dimension and two spatial dimensions.
278        multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
279        return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
280        use_best_multimask: Whether to use multimask output and then choose the best mask.
281            By default this is used for a single positive point and not otherwise.
282
283    Returns:
284        The binary segmentation mask.
285    """
286    predictor, tile, prompts, shape = _initialize_predictor(
287        predictor, image_embeddings, i, (points, labels), _points_to_tile
288    )
289    points, labels = prompts
290
291    if use_best_multimask is None:
292        use_best_multimask = len(points) == 1 and labels[0] == 1
293    multimask_output_ = multimask_output or use_best_multimask
294
295    # predict the mask
296    mask, scores, logits = predictor.predict(
297        point_coords=points[:, ::-1],  # SAM has reversed XY conventions
298        point_labels=labels,
299        multimask_output=multimask_output_,
300    )
301
302    if use_best_multimask:
303        best_mask_id = np.argmax(scores)
304        mask = mask[best_mask_id][None]
305
306    if tile is not None:
307        mask = _tile_to_full_mask(mask, shape, tile)
308
309    if return_all:
310        return mask, scores, logits
311    else:
312        return mask
313
314
315def segment_from_mask(
316    predictor: SamPredictor,
317    mask: np.ndarray,
318    image_embeddings: Optional[util.ImageEmbeddings] = None,
319    i: Optional[int] = None,
320    use_box: bool = True,
321    use_mask: bool = True,
322    use_points: bool = False,
323    original_size: Optional[Tuple[int, ...]] = None,
324    multimask_output: bool = False,
325    return_all: bool = False,
326    return_logits: bool = False,
327    box_extension: float = 0.0,
328    box: Optional[np.ndarray] = None,
329    points: Optional[np.ndarray] = None,
330    labels: Optional[np.ndarray] = None,
331    use_single_point: bool = False,
332):
333    """Segmentation from a mask prompt.
334
335    Args:
336        predictor: The segment anything predictor.
337        mask: The mask used to derive prompts.
338        image_embeddings: Optional precomputed image embeddings.
339            Has to be passed if the predictor is not yet initialized.
340        i: Index for the image data. Required if the input data has three spatial dimensions
341            or a time dimension and two spatial dimensions.
342        use_box: Whether to derive the bounding box prompt from the mask. By default, set to 'True'.
343        use_mask: Whether to use the mask itself as prompt. By default, set to 'True'.
344        use_points: Whether to derive point prompts from the mask. By default, set to 'False'.
345        original_size: Full image shape. Use this if the mask that is being passed
346            is downsampled compared to the original image.
347        multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
348        return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
349        box_extension: Relative factor used to enlarge the bounding box prompt.
350            By default, does not enlarge the bounding box.
351        box: Precomputed bounding box.
352        points: Precomputed point prompts.
353        labels: Positive/negative labels corresponding to the point prompts.
354        use_single_point: Whether to derive just a single point from the mask.
355            Only used if use_points is true.
356
357    Returns:
358        The binary segmentation mask.
359    """
360    prompts = (mask, box, points, labels)
361
362    def _to_tile(prompts, shape, tile_shape, halo):
363        mask, box, points, labels = prompts
364        tile_id, tile, mask = _mask_to_tile(mask, shape, tile_shape, halo)
365        if points is not None:
366            tile_id_points, tile, point_prompts = _points_to_tile((points, labels), shape, tile_shape, halo)
367            if tile_id_points != tile_id:
368                raise RuntimeError(f"Inconsistent tile ids for mask and point prompts: {tile_id_points} != {tile_id}.")
369            points, labels = point_prompts
370        if box is not None:
371            tile_id_box, tile, box = _box_to_tile(box, shape, tile_shape, halo)
372            if tile_id_box != tile_id:
373                raise RuntimeError(f"Inconsistent tile ids for mask and box prompts: {tile_id_box} != {tile_id}.")
374        return tile_id, tile, (mask, box, points, labels)
375
376    predictor, tile, prompts, shape = _initialize_predictor(predictor, image_embeddings, i, prompts, _to_tile)
377    mask, box, points, labels = prompts
378
379    if points is not None:
380        if labels is None:
381            raise ValueError("If points are passed you also need to pass labels.")
382        point_coords, point_labels = points, labels
383
384    elif use_points and mask.sum() != 0:
385        point_coords, point_labels = _compute_points_from_mask(
386            mask, original_size=original_size, box_extension=box_extension,
387            use_single_point=use_single_point,
388        )
389
390    else:
391        point_coords, point_labels = None, None
392
393    if box is None:
394        box = _compute_box_from_mask(
395            mask, original_size=original_size, box_extension=box_extension
396        ) if use_box and mask.sum() != 0 else None
397    else:
398        box = _process_box(box, mask.shape, original_size=original_size, box_extension=box_extension)
399
400    logits = _compute_logits_from_mask(mask) if use_mask else None
401
402    mask, scores, logits = predictor.predict(
403        point_coords=point_coords, point_labels=point_labels,
404        mask_input=logits, box=box,
405        multimask_output=multimask_output, return_logits=return_logits
406    )
407
408    if tile is not None:
409        mask = _tile_to_full_mask(mask, shape, tile)
410
411    if return_all:
412        return mask, scores, logits
413    else:
414        return mask
415
416
417def segment_from_box(
418    predictor: SamPredictor,
419    box: np.ndarray,
420    image_embeddings: Optional[util.ImageEmbeddings] = None,
421    i: Optional[int] = None,
422    multimask_output: bool = False,
423    return_all: bool = False,
424    box_extension: float = 0.0,
425):
426    """Segmentation from a box prompt.
427
428    Args:
429        predictor: The segment anything predictor.
430        box: The box prompt.
431        image_embeddings: Optional precomputed image embeddings.
432            Has to be passed if the predictor is not yet initialized.
433        i: Index for the image data. Required if the input data has three spatial dimensions
434            or a time dimension and two spatial dimensions.
435        multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
436        return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
437        box_extension: Relative factor used to enlarge the bounding box prompt.
438            By default, does not enlarge the bounding box.
439
440    Returns:
441        The binary segmentation mask.
442    """
443    predictor, tile, box, shape = _initialize_predictor(
444        predictor, image_embeddings, i, box, _box_to_tile
445    )
446    mask, scores, logits = predictor.predict(
447        box=_process_box(box, shape, box_extension=box_extension), multimask_output=multimask_output
448    )
449
450    if tile is not None:
451        mask = _tile_to_full_mask(mask, shape, tile)
452
453    if return_all:
454        return mask, scores, logits
455    else:
456        return mask
457
458
459def segment_from_box_and_points(
460    predictor: SamPredictor,
461    box: np.ndarray,
462    points: np.ndarray,
463    labels: np.ndarray,
464    image_embeddings: Optional[util.ImageEmbeddings] = None,
465    i: Optional[int] = None,
466    multimask_output: bool = False,
467    return_all: bool = False,
468):
469    """Segmentation from a box prompt and point prompts.
470
471    Args:
472        predictor: The segment anything predictor.
473        box: The box prompt.
474        points: The point prompts, given in the image coordinate system.
475        labels: The point labels, either positive or negative.
476        image_embeddings: Optional precomputed image embeddings.
477            Has to be passed if the predictor is not yet initialized.
478        i: Index for the image data. Required if the input data has three spatial dimensions
479            or a time dimension and two spatial dimensions.
480        multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
481        return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
482
483    Returns:
484        The binary segmentation mask.
485    """
486    def box_and_points_to_tile(prompts, shape, tile_shape, halo):
487        box, points, labels = prompts
488        tile_id, tile, point_prompts = _points_to_tile((points, labels), shape, tile_shape, halo)
489        points, labels = point_prompts
490        tile_id_box, tile, box = _box_to_tile(box, shape, tile_shape, halo)
491        if tile_id_box != tile_id:
492            raise RuntimeError(f"Inconsistent tile ids for box and point annotations: {tile_id_box} != {tile_id}.")
493        return tile_id, tile, (box, points, labels)
494
495    predictor, tile, prompts, shape = _initialize_predictor(
496        predictor, image_embeddings, i, (box, points, labels), box_and_points_to_tile
497    )
498    box, points, labels = prompts
499
500    mask, scores, logits = predictor.predict(
501        point_coords=points[:, ::-1],  # SAM has reversed XY conventions
502        point_labels=labels,
503        box=_process_box(box, shape),
504        multimask_output=multimask_output
505    )
506
507    if tile is not None:
508        mask = _tile_to_full_mask(mask, shape, tile)
509
510    if return_all:
511        return mask, scores, logits
512    else:
513        return mask
def segment_from_points( predictor: segment_anything.predictor.SamPredictor, points: numpy.ndarray, labels: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, multimask_output: bool = False, return_all: bool = False, use_best_multimask: Optional[bool] = None):
259def segment_from_points(
260    predictor: SamPredictor,
261    points: np.ndarray,
262    labels: np.ndarray,
263    image_embeddings: Optional[util.ImageEmbeddings] = None,
264    i: Optional[int] = None,
265    multimask_output: bool = False,
266    return_all: bool = False,
267    use_best_multimask: Optional[bool] = None,
268):
269    """Segmentation from point prompts.
270
271    Args:
272        predictor: The segment anything predictor.
273        points: The point prompts given in the image coordinate system.
274        labels: The labels (positive or negative) associated with the points.
275        image_embeddings: Optional precomputed image embeddings.
276            Has to be passed if the predictor is not yet initialized.
277        i: Index for the image data. Required if the input data has three spatial dimensions
278            or a time dimension and two spatial dimensions.
279        multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
280        return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
281        use_best_multimask: Whether to use multimask output and then choose the best mask.
282            By default this is used for a single positive point and not otherwise.
283
284    Returns:
285        The binary segmentation mask.
286    """
287    predictor, tile, prompts, shape = _initialize_predictor(
288        predictor, image_embeddings, i, (points, labels), _points_to_tile
289    )
290    points, labels = prompts
291
292    if use_best_multimask is None:
293        use_best_multimask = len(points) == 1 and labels[0] == 1
294    multimask_output_ = multimask_output or use_best_multimask
295
296    # predict the mask
297    mask, scores, logits = predictor.predict(
298        point_coords=points[:, ::-1],  # SAM has reversed XY conventions
299        point_labels=labels,
300        multimask_output=multimask_output_,
301    )
302
303    if use_best_multimask:
304        best_mask_id = np.argmax(scores)
305        mask = mask[best_mask_id][None]
306
307    if tile is not None:
308        mask = _tile_to_full_mask(mask, shape, tile)
309
310    if return_all:
311        return mask, scores, logits
312    else:
313        return mask

Segmentation from point prompts.

Arguments:
  • predictor: The segment anything predictor.
  • points: The point prompts given in the image coordinate system.
  • labels: The labels (positive or negative) associated with the points.
  • image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
  • i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
  • multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
  • return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
  • use_best_multimask: Whether to use multimask output and then choose the best mask. By default this is used for a single positive point and not otherwise.
Returns:

The binary segmentation mask.
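
An illustrative call with one positive and one negative point; predictor and image_embeddings are assumed to come from micro_sam.util as in the sketch at the top of this page, and the coordinates are made up.

import numpy as np

from micro_sam.prompt_based_segmentation import segment_from_points

# `predictor` and `image_embeddings` as created in the sketch at the top of this page.
# One positive point inside the object and one negative point outside it,
# both given as (y, x) image coordinates.
points = np.array([[120, 200], [40, 40]])
labels = np.array([1, 0])

mask, scores, logits = segment_from_points(
    predictor, points, labels,
    image_embeddings=image_embeddings,
    return_all=True,  # also return the mask scores and low-resolution logits
)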

def segment_from_mask( predictor: segment_anything.predictor.SamPredictor, mask: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, use_box: bool = True, use_mask: bool = True, use_points: bool = False, original_size: Optional[Tuple[int, ...]] = None, multimask_output: bool = False, return_all: bool = False, return_logits: bool = False, box_extension: float = 0.0, box: Optional[numpy.ndarray] = None, points: Optional[numpy.ndarray] = None, labels: Optional[numpy.ndarray] = None, use_single_point: bool = False):
316def segment_from_mask(
317    predictor: SamPredictor,
318    mask: np.ndarray,
319    image_embeddings: Optional[util.ImageEmbeddings] = None,
320    i: Optional[int] = None,
321    use_box: bool = True,
322    use_mask: bool = True,
323    use_points: bool = False,
324    original_size: Optional[Tuple[int, ...]] = None,
325    multimask_output: bool = False,
326    return_all: bool = False,
327    return_logits: bool = False,
328    box_extension: float = 0.0,
329    box: Optional[np.ndarray] = None,
330    points: Optional[np.ndarray] = None,
331    labels: Optional[np.ndarray] = None,
332    use_single_point: bool = False,
333):
334    """Segmentation from a mask prompt.
335
336    Args:
337        predictor: The segment anything predictor.
338        mask: The mask used to derive prompts.
339        image_embeddings: Optional precomputed image embeddings.
340            Has to be passed if the predictor is not yet initialized.
341        i: Index for the image data. Required if the input data has three spatial dimensions
342            or a time dimension and two spatial dimensions.
343        use_box: Whether to derive the bounding box prompt from the mask. By default, set to 'True'.
344        use_mask: Whether to use the mask itself as prompt. By default, set to 'True'.
345        use_points: Whether to derive point prompts from the mask. By default, set to 'False'.
346        original_size: Full image shape. Use this if the mask that is being passed
347            is downsampled compared to the original image.
348        multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
349        return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
350        box_extension: Relative factor used to enlarge the bounding box prompt.
351            By default, does not enlarge the bounding box.
352        box: Precomputed bounding box.
353        points: Precomputed point prompts.
354        labels: Positive/negative labels corresponding to the point prompts.
355        use_single_point: Whether to derive just a single point from the mask.
356            Only used if use_points is true.
357
358    Returns:
359        The binary segmentation mask.
360    """
361    prompts = (mask, box, points, labels)
362
363    def _to_tile(prompts, shape, tile_shape, halo):
364        mask, box, points, labels = prompts
365        tile_id, tile, mask = _mask_to_tile(mask, shape, tile_shape, halo)
366        if points is not None:
367            tile_id_points, tile, point_prompts = _points_to_tile((points, labels), shape, tile_shape, halo)
368            if tile_id_points != tile_id:
369                raise RuntimeError(f"Inconsistent tile ids for mask and point prompts: {tile_id_points} != {tile_id}.")
370            points, labels = point_prompts
371        if box is not None:
372            tile_id_box, tile, box = _box_to_tile(box, shape, tile_shape, halo)
373            if tile_id_box != tile_id:
374                raise RuntimeError(f"Inconsistent tile ids for mask and box prompts: {tile_id_box} != {tile_id}.")
375        return tile_id, tile, (mask, box, points, labels)
376
377    predictor, tile, prompts, shape = _initialize_predictor(predictor, image_embeddings, i, prompts, _to_tile)
378    mask, box, points, labels = prompts
379
380    if points is not None:
381        if labels is None:
382            raise ValueError("If points are passed you also need to pass labels.")
383        point_coords, point_labels = points, labels
384
385    elif use_points and mask.sum() != 0:
386        point_coords, point_labels = _compute_points_from_mask(
387            mask, original_size=original_size, box_extension=box_extension,
388            use_single_point=use_single_point,
389        )
390
391    else:
392        point_coords, point_labels = None, None
393
394    if box is None:
395        box = _compute_box_from_mask(
396            mask, original_size=original_size, box_extension=box_extension
397        ) if use_box and mask.sum() != 0 else None
398    else:
399        box = _process_box(box, mask.shape, original_size=original_size, box_extension=box_extension)
400
401    logits = _compute_logits_from_mask(mask) if use_mask else None
402
403    mask, scores, logits = predictor.predict(
404        point_coords=point_coords, point_labels=point_labels,
405        mask_input=logits, box=box,
406        multimask_output=multimask_output, return_logits=return_logits
407    )
408
409    if tile is not None:
410        mask = _tile_to_full_mask(mask, shape, tile)
411
412    if return_all:
413        return mask, scores, logits
414    else:
415        return mask

Segmentation from a mask prompt.

Arguments:
  • predictor: The segment anything predictor.
  • mask: The mask used to derive prompts.
  • image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
  • i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
  • use_box: Whether to derive the bounding box prompt from the mask. By default, set to 'True'.
  • use_mask: Whether to use the mask itself as prompt. By default, set to 'True'.
  • use_points: Whether to derive point prompts from the mask. By default, set to 'False'.
  • original_size: Full image shape. Use this if the mask that is being passed is downsampled compared to the original image.
  • multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
  • return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
  • box_extension: Relative factor used to enlarge the bounding box prompt. By default, does not enlarge the bounding box.
  • box: Precomputed bounding box.
  • points: Precomputed point prompts.
  • labels: Positive/negative labels corresponding to the point prompts.
  • use_single_point: Whether to derive just a single point from the mask. Only used if use_points is true.
Returns:

The binary segmentation mask.
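
An illustrative call that refines a coarse mask, e.g. one projected from a neighboring slice of a volume; predictor and image_embeddings are assumed as in the sketch at the top of this page, and the rectangle below is just a stand-in for a real mask.

import numpy as np

from micro_sam.prompt_based_segmentation import segment_from_mask

# `predictor` and `image_embeddings` as created in the sketch at the top of this page.
# Coarse binary mask for the object; here simply a filled rectangle.
coarse_mask = np.zeros((512, 512), dtype="uint8")
coarse_mask[100:180, 200:300] = 1

# Derive a box prompt and a mask prompt from it (the defaults) and re-segment.
refined_mask = segment_from_mask(
    predictor, coarse_mask,
    image_embeddings=image_embeddings,
    use_box=True, use_mask=True, use_points=False,
    box_extension=0.1,  # enlarge the derived box by 10% of its side lengths
)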

def segment_from_box( predictor: segment_anything.predictor.SamPredictor, box: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, multimask_output: bool = False, return_all: bool = False, box_extension: float = 0.0):
418def segment_from_box(
419    predictor: SamPredictor,
420    box: np.ndarray,
421    image_embeddings: Optional[util.ImageEmbeddings] = None,
422    i: Optional[int] = None,
423    multimask_output: bool = False,
424    return_all: bool = False,
425    box_extension: float = 0.0,
426):
427    """Segmentation from a box prompt.
428
429    Args:
430        predictor: The segment anything predictor.
431        box: The box prompt.
432        image_embeddings: Optional precomputed image embeddings.
433            Has to be passed if the predictor is not yet initialized.
434        i: Index for the image data. Required if the input data has three spatial dimensions
435            or a time dimension and two spatial dimensions.
436        multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
437        return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
438        box_extension: Relative factor used to enlarge the bounding box prompt.
439            By default, does not enlarge the bounding box.
440
441    Returns:
442        The binary segmentation mask.
443    """
444    predictor, tile, box, shape = _initialize_predictor(
445        predictor, image_embeddings, i, box, _box_to_tile
446    )
447    mask, scores, logits = predictor.predict(
448        box=_process_box(box, shape, box_extension=box_extension), multimask_output=multimask_output
449    )
450
451    if tile is not None:
452        mask = _tile_to_full_mask(mask, shape, tile)
453
454    if return_all:
455        return mask, scores, logits
456    else:
457        return mask

Segmentation from a box prompt.

Arguments:
  • predictor: The segment anything predictor.
  • box: The box prompt.
  • image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
  • i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
  • multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
  • return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
  • box_extension: Relative factor used to enlarge the bounding box prompt. By default, does not enlarge the bounding box.
Returns:

The binary segmentation mask.
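
An illustrative call; the box is given as (min_y, min_x, max_y, max_x) in the image coordinate system (the layout that _process_box converts to SAM's XYXY format), predictor and image_embeddings are assumed as in the sketch at the top of this page, and the coordinates are made up.

import numpy as np

from micro_sam.prompt_based_segmentation import segment_from_box

# `predictor` and `image_embeddings` as created in the sketch at the top of this page.
# Box prompt as (min_y, min_x, max_y, max_x) in image coordinates.
box = np.array([100, 200, 180, 300])

mask = segment_from_box(
    predictor, box,
    image_embeddings=image_embeddings,
    box_extension=0.0,  # use the box as given, without enlarging it
)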

def segment_from_box_and_points( predictor: segment_anything.predictor.SamPredictor, box: numpy.ndarray, points: numpy.ndarray, labels: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, multimask_output: bool = False, return_all: bool = False):
460def segment_from_box_and_points(
461    predictor: SamPredictor,
462    box: np.ndarray,
463    points: np.ndarray,
464    labels: np.ndarray,
465    image_embeddings: Optional[util.ImageEmbeddings] = None,
466    i: Optional[int] = None,
467    multimask_output: bool = False,
468    return_all: bool = False,
469):
470    """Segmentation from a box prompt and point prompts.
471
472    Args:
473        predictor: The segment anything predictor.
474        box: The box prompt.
475        points: The point prompts, given in the image coordinate system.
476        labels: The point labels, either positive or negative.
477        image_embeddings: Optional precomputed image embeddings.
478            Has to be passed if the predictor is not yet initialized.
479        i: Index for the image data. Required if the input data has three spatial dimensions
480            or a time dimension and two spatial dimensions.
481        multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
482        return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
483
484    Returns:
485        The binary segmentation mask.
486    """
487    def box_and_points_to_tile(prompts, shape, tile_shape, halo):
488        box, points, labels = prompts
489        tile_id, tile, point_prompts = _points_to_tile((points, labels), shape, tile_shape, halo)
490        points, labels = point_prompts
491        tile_id_box, tile, box = _box_to_tile(box, shape, tile_shape, halo)
492        if tile_id_box != tile_id:
493            raise RuntimeError(f"Inconsistent tile ids for box and point annotations: {tile_id_box} != {tile_id}.")
494        return tile_id, tile, (box, points, labels)
495
496    predictor, tile, prompts, shape = _initialize_predictor(
497        predictor, image_embeddings, i, (box, points, labels), box_and_points_to_tile
498    )
499    box, points, labels = prompts
500
501    mask, scores, logits = predictor.predict(
502        point_coords=points[:, ::-1],  # SAM has reversed XY conventions
503        point_labels=labels,
504        box=_process_box(box, shape),
505        multimask_output=multimask_output
506    )
507
508    if tile is not None:
509        mask = _tile_to_full_mask(mask, shape, tile)
510
511    if return_all:
512        return mask, scores, logits
513    else:
514        return mask

Segmentation from a box prompt and point prompts.

Arguments:
  • predictor: The segment anything predictor.
  • box: The box prompt.
  • points: The point prompts, given in the image coordinate system.
  • labels: The point labels, either positive or negative.
  • image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
  • i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
  • multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
  • return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
Returns:

The binary segmentation mask.
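
An illustrative call combining a box with point prompts; predictor and image_embeddings are assumed as in the sketch at the top of this page, the box uses the same (min_y, min_x, max_y, max_x) layout as above, and all values are made up.

import numpy as np

from micro_sam.prompt_based_segmentation import segment_from_box_and_points

# `predictor` and `image_embeddings` as created in the sketch at the top of this page.
# Box around the object plus one positive and one negative point in (y, x) coordinates.
box = np.array([100, 200, 180, 300])
points = np.array([[140, 250], [105, 205]])
labels = np.array([1, 0])

mask = segment_from_box_and_points(
    predictor, box, points, labels, image_embeddings=image_embeddings
)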