micro_sam.prompt_based_segmentation

Functions for prompt-based segmentation with Segment Anything.

  1"""Functions for prompt-based segmentation with Segment Anything.
  2"""
  3
  4import warnings
  5from typing import Optional, Tuple
  6
  7import numpy as np
  8from skimage.feature import peak_local_max
  9from skimage.segmentation import find_boundaries
 10
 11import torch
 12
 13from bioimage_cpp.utils import Blocking
 14from bioimage_cpp.distance import distance_transform
 15from bioimage_cpp.filters import gaussian_smoothing
 16
 17from segment_anything.predictor import SamPredictor
 18from segment_anything.utils.transforms import ResizeLongestSide
 19
 20from . import util
 21
 22
 23#
 24# helper functions for translating mask inputs into other prompts
 25#
 26
 27
 28# compute the bounding box from a mask. SAM expects the following input:
 29# box (np.ndarray or None): A length 4 array given a box prompt to the model, in XYXY format.
 30def _compute_box_from_mask(mask, original_size=None, box_extension=0):
 31    coords = np.where(mask == 1)
 32    min_y, min_x = coords[0].min(), coords[1].min()
 33    max_y, max_x = coords[0].max(), coords[1].max()
 34    box = np.array([min_y, min_x, max_y + 1, max_x + 1])
 35    return _process_box(box, mask.shape, original_size=original_size, box_extension=box_extension)
 36
 37
 38# sample points from a mask. SAM expects the following point inputs:
 39def _compute_points_from_mask(mask, original_size, box_extension, use_single_point=False):
 40    box = _compute_box_from_mask(mask, box_extension=box_extension)
 41
 42    # get slice and offset in python coordinate convention
 43    bb = (slice(box[1], box[3]), slice(box[0], box[2]))
 44    offset = np.array([box[1], box[0]])
 45
 46    # crop the mask and compute distances
 47    cropped_mask = mask[bb]
 48    object_boundaries = find_boundaries(cropped_mask, mode="outer")
 49    distances = gaussian_smoothing(distance_transform(object_boundaries == 0), sigma=1.0)
 50    inner_distances = distances.copy()
 51    cropped_mask = cropped_mask.astype("bool")
 52    inner_distances[~cropped_mask] = 0.0
 53    if use_single_point:
 54        center = inner_distances.argmax()
 55        center = np.unravel_index(center, inner_distances.shape)
 56        point_coords = (center + offset)[None]
 57        point_labels = np.ones(1, dtype="uint8")
 58        return point_coords[:, ::-1], point_labels
 59
 60    outer_distances = distances.copy()
 61    outer_distances[cropped_mask] = 0.0
 62
 63    # sample positives and negatives from the distance maxima
 64    inner_maxima = peak_local_max(inner_distances, exclude_border=False, min_distance=3)
 65    outer_maxima = peak_local_max(outer_distances, exclude_border=False, min_distance=5)
 66
 67    # derive the positive (=inner maxima) and negative (=outer maxima) points
 68    point_coords = np.concatenate([inner_maxima, outer_maxima]).astype("float64")
 69    point_coords += offset
 70
 71    if original_size is not None:
 72        scale_factor = np.array([
 73            original_size[0] / float(mask.shape[0]), original_size[1] / float(mask.shape[1])
 74        ])[None]
 75        point_coords *= scale_factor
 76
 77    # get the point labels
 78    point_labels = np.concatenate(
 79        [np.ones(len(inner_maxima), dtype="uint8"), np.zeros(len(outer_maxima), dtype="uint8")]
 80    )
 81    return point_coords[:, ::-1], point_labels
 82
 83
 84def _compute_logits_from_mask(mask, eps=1e-3):
 85
 86    def inv_sigmoid(x):
 87        return np.log(x / (1 - x))
 88
 89    # resize to the expected mask shape of SAM (256x256)
 90    assert mask.ndim == 2
 91    expected_shape = (256, 256)
 92
 93    # Resize the *binary* mask (instead of the inverse-sigmoid logits) to SAM's expected
 94    # mask shape and re-binarize afterwards. This keeps small objects from being washed out
 95    # by the antialiased downscaling that ResizeLongestSide applies, which otherwise makes
 96    # the mask prompt too weak for small objects in large (and non-square) images.
 97    binary_mask = (mask == 1).astype("float32")
 98
 99    if binary_mask.shape != expected_shape:
100        trafo = ResizeLongestSide(expected_shape[0])
101        binary_mask = trafo.apply_image_torch(torch.from_numpy(binary_mask[None, None]))
102        binary_mask = binary_mask.numpy().squeeze()
103
104        if binary_mask.shape != expected_shape:  # shape is not square -> pad the other side
105            h, w = binary_mask.shape
106            padh = expected_shape[0] - h
107            padw = expected_shape[1] - w
108            # IMPORTANT: need to pad with zero, otherwise SAM doesn't understand the padding
109            pad_width = ((0, padh), (0, padw))
110            binary_mask = np.pad(binary_mask, pad_width, mode="constant", constant_values=0)
111
112    logits = np.where(binary_mask > 0.5, inv_sigmoid(1 - eps), inv_sigmoid(eps)).astype("float32")
113    logits = logits[None]
114    assert logits.shape == (1, 256, 256), f"{logits.shape}"
115    return logits
116
117
118#
119# other helper functions
120#
121
122
def _process_box(box, shape, original_size=None, box_extension=0):
    """Turn a [y_min, x_min, y_max, x_max] box into SAM's XYXY box prompt.

    Args:
        box: The box as [y_min, x_min, y_max, x_max].
        shape: The (height, width) the extended box is clipped to.
        original_size: If given, the box is rescaled with `ResizeLongestSide.apply_boxes`
            from a 256x256 frame to a frame whose longest side is max(original_size).
        box_extension: 0 for no extension. Values >= 1 extend every side by that many pixels.
            Other values extend each side by that fraction of the box height / width.

    Returns:
        The rounded integer box [x_min, y_min, x_max, y_max].
    """
    y_min, x_min, y_max, x_max = box[0], box[1], box[2], box[3]

    if box_extension == 0:
        pad_y = pad_x = 0
    elif box_extension >= 1:
        # Absolute extension in pixels.
        pad_y = pad_x = box_extension
    else:
        # Extension relative to the side lengths of the box.
        pad_y = box_extension * (y_max - y_min)
        pad_x = box_extension * (x_max - x_min)

    xyxy = np.array([
        max(x_min - pad_x, 0),
        max(y_min - pad_y, 0),
        min(x_max + pad_x, shape[1]),
        min(y_max + pad_y, shape[0]),
    ])

    if original_size is not None:
        resizer = ResizeLongestSide(max(original_size))
        xyxy = resizer.apply_boxes(xyxy[None], (256, 256)).squeeze()

    return np.round(xyxy).astype(int)
145
146
147# Select the correct tile based on average of points
148# and bring the points to the coordinate system of the tile.
149# Discard points that are not in the tile and warn if this happens.
def _points_to_tile(prompts, shape, tile_shape, halo):
    """Move point prompts into the coordinate system of a single tile.

    The tile is the one whose block contains the (rounded) mean of the points; its outer
    block, i.e. including the halo, is used. Points are in the same axis order as `shape`.
    Points falling outside that tile are dropped, and a warning is emitted.

    Returns:
        The tile id, the tile (outer block) and the (points, labels) in tile coordinates.
    """
    points, labels = prompts

    blocking = Blocking([0, 0], shape, tile_shape)
    mean_point = np.mean(points, axis=0).round().astype("int").tolist()
    tile_id = blocking.coordinates_to_block_id(mean_point)
    tile = blocking.get_block_with_halo(tile_id, list(halo)).outer_block

    local_points = points - np.array(tile.begin)
    extent = tile.shape
    inside = (
        (local_points >= 0).all(axis=1)
        & (local_points[:, 0] < extent[0])
        & (local_points[:, 1] < extent[1])
    )

    local_labels = labels
    if not inside.all():
        local_points = local_points[inside]
        local_labels = local_labels[inside]
        n_dropped = (~inside).sum()
        warnings.warn(f"{n_dropped} points were not in the tile and are dropped")

    return tile_id, tile, (local_points, local_labels)
179
180
def _box_to_tile(box, shape, tile_shape, halo):
    """Move a [y_min, x_min, y_max, x_max] box into the coordinate system of a single tile.

    The tile is the one whose block contains the (rounded) box center; its outer block,
    i.e. including the halo, is used. The shifted box is clipped to the extent of that tile.

    Returns:
        The tile id, the tile (outer block) and the box in tile coordinates.
    """
    blocking = Blocking([0, 0], shape, tile_shape)
    box_center = [(box[0] + box[2]) / 2, (box[1] + box[3]) / 2]
    center = np.array(box_center).round().astype("int").tolist()
    tile_id = blocking.coordinates_to_block_id(center)

    tile = blocking.get_block_with_halo(tile_id, list(halo)).outer_block
    begin, extent = tile.begin, tile.shape

    local_box = np.array([
        max(box[0] - begin[0], 0),
        max(box[1] - begin[1], 0),
        min(box[2] - begin[0], extent[0]),
        min(box[3] - begin[1], extent[1]),
    ])
    return tile_id, tile, local_box
198
199
def _mask_to_tile(mask, shape, tile_shape, halo):
    """Crop a full-image mask to the tile that contains its foreground centroid.

    The tile is the one whose block contains the (rounded) mean position of the nonzero
    mask pixels; its outer block, i.e. including the halo, is cut out of the mask.
    NOTE(review): assumes the mask has at least one nonzero pixel; otherwise the centroid is NaN.

    Returns:
        The tile id, the tile (outer block) and the cropped mask.
    """
    blocking = Blocking([0, 0], shape, tile_shape)

    nonzero = np.where(mask)
    centroid = [np.mean(nonzero[0]), np.mean(nonzero[1])]
    center = np.array(centroid).round().astype("int").tolist()
    tile_id = blocking.coordinates_to_block_id(center)

    tile = blocking.get_block_with_halo(tile_id, list(halo)).outer_block
    tile_region = tuple(slice(start, stop) for start, stop in zip(tile.begin, tile.end))
    return tile_id, tile, mask[tile_region]
212
213
def _initialize_predictor(predictor, image_embeddings, i, prompts, to_tile):
    """Prepare the predictor state and the prompts for a single prediction.

    Three cases are handled:
    - no embeddings: the predictor is used as is, and `shape` is `predictor.original_size`;
    - normal embeddings (`input_size` is set): the precomputed state is set, and `shape` is
      the embedding's `original_size`;
    - tiled embeddings (`input_size` is None): `to_tile` selects the tile and maps the prompts
      into it, the precomputed state of that tile is set, and `shape` is the full image shape
      stored in the feature attributes.

    Returns:
        The predictor, the tile (None unless tiled), the (possibly tile-local) prompts and the shape.
    """
    if image_embeddings is None:
        return predictor, None, prompts, predictor.original_size

    if image_embeddings["input_size"] is not None:
        shape = image_embeddings["original_size"]
        util.set_precomputed(predictor, image_embeddings, i)
        return predictor, None, prompts, shape

    attrs = image_embeddings["features"].attrs
    shape, tile_shape, halo = attrs["shape"], attrs["tile_shape"], attrs["halo"]
    tile_id, tile, tile_prompts = to_tile(prompts, shape, tile_shape, halo)
    util.set_precomputed(predictor, image_embeddings, i, tile_id=tile_id)
    return predictor, tile, tile_prompts, shape
233
234
def _tile_to_full_mask(mask, shape, tile):
    """Paste a stack of tile masks (leading channel axis) into zero-filled full-size masks.

    The result has shape (mask.shape[0], *shape) and the dtype of `mask`; the tile region is
    given by `tile.begin` / `tile.end`.
    """
    tile_region = tuple(slice(start, stop) for start, stop in zip(tile.begin, tile.end))
    full_mask = np.zeros((mask.shape[0],) + tuple(shape), dtype=mask.dtype)
    full_mask[(slice(None), *tile_region)] = mask
    return full_mask
240
241
242#
243# functions for prompted segmentation:
244# - segment_from_points: use point prompts as input
245# - segment_from_mask: use binary mask as input, support conversion to mask, box and point prompts
246# - segment_from_box: use box prompt as input
247# - segment_from_box_and_points: use box and point prompts as input
248#
249
250
def segment_from_points(
    predictor: SamPredictor,
    points: np.ndarray,
    labels: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    use_best_multimask: Optional[bool] = None,
):
    """Segmentation from point prompts.

    Args:
        predictor: The segment anything predictor.
        points: The point prompts given in the image coordinate system, in (row, column) order.
            They are flipped to SAM's XY convention before prediction.
        labels: The labels (positive or negative) associated with the points.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
        return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
        use_best_multimask: Whether to use multimask output and then keep only the mask with the
            highest score. By default this is used for a single positive point and not otherwise.

    Returns:
        The binary segmentation mask; or (mask, scores, logits) if `return_all` is set.
    """
    predictor, tile, (points, labels), shape = _initialize_predictor(
        predictor, image_embeddings, i, (points, labels), _points_to_tile
    )

    if use_best_multimask is None:
        # A single positive click is ambiguous: let SAM propose several masks and keep the best.
        use_best_multimask = len(points) == 1 and labels[0] == 1

    masks, scores, logits = predictor.predict(
        point_coords=points[:, ::-1],  # SAM has reversed XY conventions
        point_labels=labels,
        multimask_output=multimask_output or use_best_multimask,
    )

    if use_best_multimask:
        masks = masks[np.argmax(scores)][None]

    if tile is not None:
        masks = _tile_to_full_mask(masks, shape, tile)

    return (masks, scores, logits) if return_all else masks
306
307
def segment_from_mask(
    predictor: SamPredictor,
    mask: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    use_box: bool = True,
    use_mask: bool = True,
    use_points: bool = False,
    original_size: Optional[Tuple[int, ...]] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    return_logits: bool = False,
    box_extension: float = 0.0,
    box: Optional[np.ndarray] = None,
    points: Optional[np.ndarray] = None,
    labels: Optional[np.ndarray] = None,
    use_single_point: bool = False,
):
    """Segmentation from a mask prompt.

    Depending on the flags, the mask is turned into a box prompt, point prompts and/or a
    logit mask prompt. All of them are passed to the predictor together.

    Args:
        predictor: The segment anything predictor.
        mask: The mask used to derive prompts.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        use_box: Whether to derive the bounding box prompt from the mask. By default, set to 'True'.
            Not used if `box` is passed.
        use_mask: Whether to use the mask itself as prompt. By default, set to 'True'.
        use_points: Whether to derive point prompts from the mask. By default, set to 'False'.
            Not used if `points` is passed.
        original_size: Full image shape. Use this if the mask that is being passed
            is downsampled compared to the original image.
        multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
        return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
        return_logits: Passed on to `predictor.predict`; whether to return the un-thresholded
            mask logits instead of the binary mask.
        box_extension: Amount by which the bounding box prompt is enlarged. Values >= 1 are
            interpreted as pixels, smaller values as a fraction of the box side lengths.
            By default, does not enlarge the bounding box.
        box: Precomputed bounding box, given as [y_min, x_min, y_max, x_max].
        points: Precomputed point prompts. They are passed to the predictor unchanged
            (i.e. not flipped to SAM's XY convention).
        labels: Positive/negative labels corresponding to the point prompts.
        use_single_point: Whether to derive just a single point from the mask.
            Only used if `use_points` is true.

    Returns:
        The binary segmentation mask; or (mask, scores, logits) if `return_all` is set.

    Raises:
        ValueError: If `points` are passed without `labels`.
        RuntimeError: For tiled embeddings, if the prompts fall into different tiles.
    """

    def _to_tile(tile_inputs, shape, tile_shape, halo):
        # The mask selects the tile; point and box prompts must map to the same tile.
        tile_mask, tile_box, tile_points, tile_labels = tile_inputs
        tile_id, tile, tile_mask = _mask_to_tile(tile_mask, shape, tile_shape, halo)

        if tile_points is not None:
            tile_id_points, tile, (tile_points, tile_labels) = _points_to_tile(
                (tile_points, tile_labels), shape, tile_shape, halo
            )
            if tile_id_points != tile_id:
                raise RuntimeError(f"Inconsistent tile ids for mask and point prompts: {tile_id_points} != {tile_id}.")

        if tile_box is not None:
            tile_id_box, tile, tile_box = _box_to_tile(tile_box, shape, tile_shape, halo)
            if tile_id_box != tile_id:
                raise RuntimeError(f"Inconsistent tile ids for mask and box prompts: {tile_id_box} != {tile_id}.")

        return tile_id, tile, (tile_mask, tile_box, tile_points, tile_labels)

    predictor, tile, (mask, box, points, labels), shape = _initialize_predictor(
        predictor, image_embeddings, i, (mask, box, points, labels), _to_tile
    )

    # Point prompts: explicitly passed ones take precedence over points derived from the mask.
    point_coords, point_labels = None, None
    if points is not None:
        if labels is None:
            raise ValueError("If points are passed you also need to pass labels.")
        point_coords, point_labels = points, labels
    elif use_points and mask.sum() != 0:
        point_coords, point_labels = _compute_points_from_mask(
            mask, original_size=original_size, box_extension=box_extension,
            use_single_point=use_single_point,
        )

    # Box prompt: an explicitly passed box takes precedence over the box derived from the mask.
    if box is not None:
        box = _process_box(box, mask.shape, original_size=original_size, box_extension=box_extension)
    elif use_box and mask.sum() != 0:
        box = _compute_box_from_mask(mask, original_size=original_size, box_extension=box_extension)

    mask_prompt = _compute_logits_from_mask(mask) if use_mask else None

    masks, scores, logits = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        mask_input=mask_prompt,
        box=box,
        multimask_output=multimask_output,
        return_logits=return_logits,
    )

    if tile is not None:
        masks = _tile_to_full_mask(masks, shape, tile)

    if return_all:
        return masks, scores, logits
    return masks
408
409
def segment_from_box(
    predictor: SamPredictor,
    box: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    box_extension: float = 0.0,
):
    """Segmentation from a box prompt.

    Args:
        predictor: The segment anything predictor.
        box: The box prompt, given as [y_min, x_min, y_max, x_max] in the image coordinate system.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
        return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
        box_extension: Amount by which the bounding box prompt is enlarged. Values >= 1 are
            interpreted as pixels, smaller values as a fraction of the box side lengths.
            By default, does not enlarge the bounding box.

    Returns:
        The binary segmentation mask; or (mask, scores, logits) if `return_all` is set.
    """
    predictor, tile, box, shape = _initialize_predictor(
        predictor, image_embeddings, i, box, _box_to_tile
    )

    # With tiled embeddings the box is in tile coordinates (see `_box_to_tile`), while `shape`
    # is the full image shape. The (extended) box therefore has to be clipped to the tile
    # the predictor was set to, not to the full image.
    clip_shape = shape if tile is None else tile.shape
    mask, scores, logits = predictor.predict(
        box=_process_box(box, clip_shape, box_extension=box_extension),
        multimask_output=multimask_output,
    )

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    return mask
450
451
def segment_from_box_and_points(
    predictor: SamPredictor,
    box: np.ndarray,
    points: np.ndarray,
    labels: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
):
    """Segmentation from a box prompt and point prompts.

    Args:
        predictor: The segment anything predictor.
        box: The box prompt, given as [y_min, x_min, y_max, x_max].
        points: The point prompts, given in the image coordinates system, in (row, column) order.
            They are flipped to SAM's XY convention before prediction.
        labels: The point labels, either positive or negative.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
        return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.

    Returns:
        The binary segmentation mask; or (mask, scores, logits) if `return_all` is set.

    Raises:
        RuntimeError: For tiled embeddings, if box and points fall into different tiles.
    """

    def _box_and_points_to_tile(tile_inputs, shape, tile_shape, halo):
        # The points select the tile; the box must map to the same tile.
        tile_box, tile_points, tile_labels = tile_inputs
        tile_id, tile, (tile_points, tile_labels) = _points_to_tile(
            (tile_points, tile_labels), shape, tile_shape, halo
        )
        tile_id_box, tile, tile_box = _box_to_tile(tile_box, shape, tile_shape, halo)
        if tile_id_box != tile_id:
            raise RuntimeError(f"Inconsistent tile ids for box and point annotations: {tile_id_box} != {tile_id}.")
        return tile_id, tile, (tile_box, tile_points, tile_labels)

    predictor, tile, (box, points, labels), shape = _initialize_predictor(
        predictor, image_embeddings, i, (box, points, labels), _box_and_points_to_tile
    )

    xy_points = points[:, ::-1]  # SAM has reversed XY conventions
    box_prompt = _process_box(box, shape)
    masks, scores, logits = predictor.predict(
        point_coords=xy_points,
        point_labels=labels,
        box=box_prompt,
        multimask_output=multimask_output,
    )

    if tile is not None:
        masks = _tile_to_full_mask(masks, shape, tile)

    return (masks, scores, logits) if return_all else masks
def segment_from_points( predictor: segment_anything.predictor.SamPredictor, points: numpy.ndarray, labels: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, multimask_output: bool = False, return_all: bool = False, use_best_multimask: Optional[bool] = None):
def segment_from_points(
    predictor: SamPredictor,
    points: np.ndarray,
    labels: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    use_best_multimask: Optional[bool] = None,
):
    """Segmentation from point prompts.

    Args:
        predictor: The segment anything predictor.
        points: The point prompts given in the image coordinate system, in (row, column) order.
            They are flipped to SAM's XY convention before prediction.
        labels: The labels (positive or negative) associated with the points.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
        return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
        use_best_multimask: Whether to use multimask output and then keep only the mask with the
            highest score. By default this is used for a single positive point and not otherwise.

    Returns:
        The binary segmentation mask; or (mask, scores, logits) if `return_all` is set.
    """
    predictor, tile, (points, labels), shape = _initialize_predictor(
        predictor, image_embeddings, i, (points, labels), _points_to_tile
    )

    if use_best_multimask is None:
        # A single positive click is ambiguous: let SAM propose several masks and keep the best.
        use_best_multimask = len(points) == 1 and labels[0] == 1

    masks, scores, logits = predictor.predict(
        point_coords=points[:, ::-1],  # SAM has reversed XY conventions
        point_labels=labels,
        multimask_output=multimask_output or use_best_multimask,
    )

    if use_best_multimask:
        masks = masks[np.argmax(scores)][None]

    if tile is not None:
        masks = _tile_to_full_mask(masks, shape, tile)

    return (masks, scores, logits) if return_all else masks

Segmentation from point prompts.

Arguments:
  • predictor: The segment anything predictor.
  • points: The point prompts given in the image coordinate system.
  • labels: The labels (positive or negative) associated with the points.
  • image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
  • i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
  • multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
  • return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
  • use_best_multimask: Whether to use multimask output and then choose the best mask. By default this is used for a single positive point and not otherwise.
Returns:

The binary segmentation mask.

def segment_from_mask( predictor: segment_anything.predictor.SamPredictor, mask: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, use_box: bool = True, use_mask: bool = True, use_points: bool = False, original_size: Optional[Tuple[int, ...]] = None, multimask_output: bool = False, return_all: bool = False, return_logits: bool = False, box_extension: float = 0.0, box: Optional[numpy.ndarray] = None, points: Optional[numpy.ndarray] = None, labels: Optional[numpy.ndarray] = None, use_single_point: bool = False):
def segment_from_mask(
    predictor: SamPredictor,
    mask: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    use_box: bool = True,
    use_mask: bool = True,
    use_points: bool = False,
    original_size: Optional[Tuple[int, ...]] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    return_logits: bool = False,
    box_extension: float = 0.0,
    box: Optional[np.ndarray] = None,
    points: Optional[np.ndarray] = None,
    labels: Optional[np.ndarray] = None,
    use_single_point: bool = False,
):
    """Segmentation from a mask prompt.

    Depending on the flags, the mask is turned into a box prompt, point prompts and/or a
    logit mask prompt. All of them are passed to the predictor together.

    Args:
        predictor: The segment anything predictor.
        mask: The mask used to derive prompts.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        use_box: Whether to derive the bounding box prompt from the mask. By default, set to 'True'.
            Not used if `box` is passed.
        use_mask: Whether to use the mask itself as prompt. By default, set to 'True'.
        use_points: Whether to derive point prompts from the mask. By default, set to 'False'.
            Not used if `points` is passed.
        original_size: Full image shape. Use this if the mask that is being passed
            is downsampled compared to the original image.
        multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
        return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
        return_logits: Passed on to `predictor.predict`; whether to return the un-thresholded
            mask logits instead of the binary mask.
        box_extension: Amount by which the bounding box prompt is enlarged. Values >= 1 are
            interpreted as pixels, smaller values as a fraction of the box side lengths.
            By default, does not enlarge the bounding box.
        box: Precomputed bounding box, given as [y_min, x_min, y_max, x_max].
        points: Precomputed point prompts. They are passed to the predictor unchanged
            (i.e. not flipped to SAM's XY convention).
        labels: Positive/negative labels corresponding to the point prompts.
        use_single_point: Whether to derive just a single point from the mask.
            Only used if `use_points` is true.

    Returns:
        The binary segmentation mask; or (mask, scores, logits) if `return_all` is set.

    Raises:
        ValueError: If `points` are passed without `labels`.
        RuntimeError: For tiled embeddings, if the prompts fall into different tiles.
    """

    def _to_tile(tile_inputs, shape, tile_shape, halo):
        # The mask selects the tile; point and box prompts must map to the same tile.
        tile_mask, tile_box, tile_points, tile_labels = tile_inputs
        tile_id, tile, tile_mask = _mask_to_tile(tile_mask, shape, tile_shape, halo)

        if tile_points is not None:
            tile_id_points, tile, (tile_points, tile_labels) = _points_to_tile(
                (tile_points, tile_labels), shape, tile_shape, halo
            )
            if tile_id_points != tile_id:
                raise RuntimeError(f"Inconsistent tile ids for mask and point prompts: {tile_id_points} != {tile_id}.")

        if tile_box is not None:
            tile_id_box, tile, tile_box = _box_to_tile(tile_box, shape, tile_shape, halo)
            if tile_id_box != tile_id:
                raise RuntimeError(f"Inconsistent tile ids for mask and box prompts: {tile_id_box} != {tile_id}.")

        return tile_id, tile, (tile_mask, tile_box, tile_points, tile_labels)

    predictor, tile, (mask, box, points, labels), shape = _initialize_predictor(
        predictor, image_embeddings, i, (mask, box, points, labels), _to_tile
    )

    # Point prompts: explicitly passed ones take precedence over points derived from the mask.
    point_coords, point_labels = None, None
    if points is not None:
        if labels is None:
            raise ValueError("If points are passed you also need to pass labels.")
        point_coords, point_labels = points, labels
    elif use_points and mask.sum() != 0:
        point_coords, point_labels = _compute_points_from_mask(
            mask, original_size=original_size, box_extension=box_extension,
            use_single_point=use_single_point,
        )

    # Box prompt: an explicitly passed box takes precedence over the box derived from the mask.
    if box is not None:
        box = _process_box(box, mask.shape, original_size=original_size, box_extension=box_extension)
    elif use_box and mask.sum() != 0:
        box = _compute_box_from_mask(mask, original_size=original_size, box_extension=box_extension)

    mask_prompt = _compute_logits_from_mask(mask) if use_mask else None

    masks, scores, logits = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        mask_input=mask_prompt,
        box=box,
        multimask_output=multimask_output,
        return_logits=return_logits,
    )

    if tile is not None:
        masks = _tile_to_full_mask(masks, shape, tile)

    if return_all:
        return masks, scores, logits
    return masks

Segmentation from a mask prompt.

Arguments:
  • predictor: The segment anything predictor.
  • mask: The mask used to derive prompts.
  • image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
  • i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
  • use_box: Whether to derive the bounding box prompt from the mask. By default, set to 'True'.
  • use_mask: Whether to use the mask itself as prompt. By default, set to 'True'.
  • use_points: Whether to derive point prompts from the mask. By default, set to 'False'.
  • original_size: Full image shape. Use this if the mask that is being passed is downsampled compared to the original image.
  • multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
  • return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
  • box_extension: Relative factor used to enlarge the bounding box prompt. By default, does not enlarge the bounding box.
  • box: Precomputed bounding box.
  • points: Precomputed point prompts.
  • labels: Positive/negative labels corresponding to the point prompts.
  • use_single_point: Whether to derive just a single point from the mask. Only used if use_points is true.
Returns:

The binary segmentation mask.

def segment_from_box( predictor: segment_anything.predictor.SamPredictor, box: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, multimask_output: bool = False, return_all: bool = False, box_extension: float = 0.0):
def segment_from_box(
    predictor: SamPredictor,
    box: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    box_extension: float = 0.0,
):
    """Segment an object with Segment Anything, using a bounding box as the only prompt.

    Args:
        predictor: The segment anything predictor.
        box: The box prompt.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
        return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
        box_extension: Relative factor used to enlarge the bounding box prompt.
            By default, does not enlarge the bounding box.

    Returns:
        The binary segmentation mask. If `return_all` is set, the tuple (mask, scores, logits) instead.
    """
    # Set up the predictor; for tiled embeddings the box is mapped by `_box_to_tile`.
    predictor, tile, prompt_box, data_shape = _initialize_predictor(
        predictor, image_embeddings, i, box, _box_to_tile
    )

    # Bring the box into the format passed to the predictor, optionally enlarged.
    sam_box = _process_box(prompt_box, data_shape, box_extension=box_extension)
    prediction, scores, logits = predictor.predict(box=sam_box, multimask_output=multimask_output)

    # Map a tile-local prediction back onto the full data shape.
    if tile is not None:
        prediction = _tile_to_full_mask(prediction, data_shape, tile)

    return (prediction, scores, logits) if return_all else prediction

Segmentation from a box prompt.

Arguments:
  • predictor: The segment anything predictor.
  • box: The box prompt.
  • image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
  • i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
  • multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
  • return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
  • box_extension: Relative factor used to enlarge the bounding box prompt. By default, does not enlarge the bounding box.
Returns:

The binary segmentation mask.

def segment_from_box_and_points( predictor: segment_anything.predictor.SamPredictor, box: numpy.ndarray, points: numpy.ndarray, labels: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, multimask_output: bool = False, return_all: bool = False):
def segment_from_box_and_points(
    predictor: SamPredictor,
    box: np.ndarray,
    points: np.ndarray,
    labels: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    box_extension: float = 0.0,
):
    """Segmentation from a box prompt and point prompts.

    Args:
        predictor: The segment anything predictor.
        box: The box prompt.
        points: The point prompts, given in the image coordinates system.
        labels: The point labels, either positive or negative.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
        return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
        box_extension: Relative factor used to enlarge the bounding box prompt.
            By default, does not enlarge the bounding box.

    Returns:
        The binary segmentation mask. If `return_all` is set, the tuple (mask, scores, logits) instead.

    Raises:
        RuntimeError: If the box and the point prompts are mapped to different tiles.
    """
    def box_and_points_to_tile(prompts, shape, tile_shape, halo):
        # Map points and box to their tiles; both prompts must land in the same tile
        # because a single prediction is run on one tile.
        box, points, labels = prompts
        tile_id, tile, point_prompts = _points_to_tile((points, labels), shape, tile_shape, halo)
        points, labels = point_prompts
        tile_id_box, tile, box = _box_to_tile(box, shape, tile_shape, halo)
        if tile_id_box != tile_id:
            raise RuntimeError(f"Inconsistent tile ids for box and point annotations: {tile_id_box} != {tile_id}.")
        return tile_id, tile, (box, points, labels)

    predictor, tile, prompts, shape = _initialize_predictor(
        predictor, image_embeddings, i, (box, points, labels), box_and_points_to_tile
    )
    box, points, labels = prompts

    mask, scores, logits = predictor.predict(
        point_coords=points[:, ::-1],  # SAM has reversed XY conventions
        point_labels=labels,
        # Forward box_extension like segment_from_box / segment_from_mask do.
        box=_process_box(box, shape, box_extension=box_extension),
        multimask_output=multimask_output
    )

    # Map a tile-local prediction back onto the full data shape.
    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    else:
        return mask

Segmentation from a box prompt and point prompts.

Arguments:
  • predictor: The segment anything predictor.
  • box: The box prompt.
  • points: The point prompts, given in the image coordinates system.
  • labels: The point labels, either positive or negative.
  • image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
  • i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
  • multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
  • return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
Returns:

The binary segmentation mask.