micro_sam.prompt_based_segmentation
Functions for prompt-based segmentation with Segment Anything.
"""Functions for prompt-based segmentation with Segment Anything.
"""

import warnings
from typing import Optional, Tuple

import numpy as np
from skimage.feature import peak_local_max
from skimage.segmentation import find_boundaries

import torch

from bioimage_cpp.utils import Blocking
from bioimage_cpp.distance import distance_transform
from bioimage_cpp.filters import gaussian_smoothing

from segment_anything.predictor import SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide

from . import util


#
# helper functions for translating mask inputs into other prompts
#


# compute the bounding box from a mask. SAM expects the following input:
# box (np.ndarray or None): A length 4 array given a box prompt to the model, in XYXY format.
def _compute_box_from_mask(mask, original_size=None, box_extension=0):
    """Derive a box prompt from the pixels of `mask` that are equal to 1.

    The mask must contain at least one pixel with value 1, otherwise the
    min / max reductions below fail on an empty array.

    Args:
        mask: The 2D mask.
        original_size: Forwarded to `_process_box`.
        box_extension: Forwarded to `_process_box`.

    Returns:
        The integer box in XYXY order, as returned by `_process_box`.
    """
    coords = np.where(mask == 1)
    min_y, min_x = coords[0].min(), coords[1].min()
    max_y, max_x = coords[0].max(), coords[1].max()
    # The upper bounds are exclusive, so the box can directly be used for slicing.
    box = np.array([min_y, min_x, max_y + 1, max_x + 1])
    return _process_box(box, mask.shape, original_size=original_size, box_extension=box_extension)


# sample points from a mask. SAM expects the following point inputs:
def _compute_points_from_mask(mask, original_size, box_extension, use_single_point=False):
    """Derive point prompts from a mask.

    Positive points are sampled inside of the mask, negative points outside of it
    but within the (extended) bounding box of the mask.

    Args:
        mask: The 2D mask.
        original_size: The shape of the full image. If given, the point coordinates
            are rescaled from the mask shape to this shape.
        box_extension: The extension of the bounding box that is used for cropping the mask.
        use_single_point: Whether to return only a single positive point,
            placed at the maximum of the smoothed distances inside of the mask.

    Returns:
        The point coordinates, with the last axis reversed compared to the numpy axis order.
        The point labels (1 for positive points, 0 for negative points).
    """
    box = _compute_box_from_mask(mask, box_extension=box_extension)

    # get slice and offset in python coordinate convention
    bb = (slice(box[1], box[3]), slice(box[0], box[2]))
    offset = np.array([box[1], box[0]])

    # Per-axis factor that maps coordinates in the mask to coordinates in the full image.
    if original_size is None:
        scale_factor = None
    else:
        scale_factor = np.array([
            original_size[0] / float(mask.shape[0]), original_size[1] / float(mask.shape[1])
        ])[None]

    # crop the mask and compute distances
    cropped_mask = mask[bb]
    object_boundaries = find_boundaries(cropped_mask, mode="outer")
    distances = gaussian_smoothing(distance_transform(object_boundaries == 0), sigma=1.0)
    inner_distances = distances.copy()
    cropped_mask = cropped_mask.astype("bool")
    inner_distances[~cropped_mask] = 0.0
    if use_single_point:
        center = inner_distances.argmax()
        center = np.unravel_index(center, inner_distances.shape)
        point_coords = (np.array(center) + offset)[None]
        # Rescale the single point in the same way as the multi-point prompts below,
        # so that it also refers to the full image if the mask is downsampled.
        if scale_factor is not None:
            point_coords = point_coords * scale_factor
        point_labels = np.ones(1, dtype="uint8")
        return point_coords[:, ::-1], point_labels

    outer_distances = distances.copy()
    outer_distances[cropped_mask] = 0.0

    # sample positives and negatives from the distance maxima
    inner_maxima = peak_local_max(inner_distances, exclude_border=False, min_distance=3)
    outer_maxima = peak_local_max(outer_distances, exclude_border=False, min_distance=5)

    # derive the positive (=inner maxima) and negative (=outer maxima) points
    point_coords = np.concatenate([inner_maxima, outer_maxima]).astype("float64")
    point_coords += offset

    if scale_factor is not None:
        point_coords *= scale_factor

    # get the point labels
    point_labels = np.concatenate(
        [np.ones(len(inner_maxima), dtype="uint8"), np.zeros(len(outer_maxima), dtype="uint8")]
    )
    return point_coords[:, ::-1], point_labels


def _compute_logits_from_mask(mask, eps=1e-3):
    """Convert a 2D mask into a mask prompt of shape (1, 256, 256).

    Args:
        mask: The 2D mask. Pixels equal to 1 are treated as foreground.
        eps: The probability assigned to background pixels; foreground pixels get 1 - eps.
            The logits are the inverse sigmoid of these probabilities.

    Returns:
        The float32 logits of shape (1, 256, 256).
    """

    def inv_sigmoid(x):
        return np.log(x / (1 - x))

    # resize to the expected mask shape of SAM (256x256)
    assert mask.ndim == 2
    expected_shape = (256, 256)

    # Resize the *binary* mask (instead of the inverse-sigmoid logits) to SAM's expected
    # mask shape and re-binarize afterwards. This keeps small objects from being washed out
    # by the antialiased downscaling that ResizeLongestSide applies, which otherwise makes
    # the mask prompt too weak for small objects in large (and non-square) images.
    binary_mask = (mask == 1).astype("float32")

    if binary_mask.shape != expected_shape:
        trafo = ResizeLongestSide(expected_shape[0])
        binary_mask = trafo.apply_image_torch(torch.from_numpy(binary_mask[None, None]))
        binary_mask = binary_mask.numpy().squeeze()

        if binary_mask.shape != expected_shape:  # shape is not square -> pad the other side
            h, w = binary_mask.shape
            padh = expected_shape[0] - h
            padw = expected_shape[1] - w
            # IMPORTANT: need to pad with zero, otherwise SAM doesn't understand the padding
            pad_width = ((0, padh), (0, padw))
            binary_mask = np.pad(binary_mask, pad_width, mode="constant", constant_values=0)

    logits = np.where(binary_mask > 0.5, inv_sigmoid(1 - eps), inv_sigmoid(eps)).astype("float32")
    logits = logits[None]
    assert logits.shape == (1, 256, 256), f"{logits.shape}"
    return logits


#
# other helper functions
#


def _process_box(box, shape, original_size=None, box_extension=0):
    """Extend a box, clip it to `shape` and bring it into the XYXY format expected by SAM.

    Args:
        box: The box, given as (min_y, min_x, max_y, max_x).
        shape: The shape (height, width) that the extended box is clipped to.
        original_size: If given, the box is rescaled with `ResizeLongestSide(max(original_size))`,
            treating the box as given in a (256, 256) coordinate system.
            NOTE(review): this presumably expects downsampled masks at SAM's mask resolution; confirm.
        box_extension: 0 for no extension, a value >= 1 to extend the box by this
            number of pixels on each side, and a value in between to extend
            it by this fraction of the box length along each axis.

    Returns:
        The integer box as (min_x, min_y, max_x, max_y).
    """
    if box_extension == 0:  # no extension
        extension_y, extension_x = 0, 0
    elif box_extension >= 1:  # extension by a fixed factor
        extension_y, extension_x = box_extension, box_extension
    else:  # extension by fraction of the box len
        len_y, len_x = box[2] - box[0], box[3] - box[1]
        extension_y, extension_x = box_extension * len_y, box_extension * len_x

    # Switch from YXYX to XYXY order while extending and clipping.
    box = np.array([
        max(box[1] - extension_x, 0), max(box[0] - extension_y, 0),
        min(box[3] + extension_x, shape[1]), min(box[2] + extension_y, shape[0]),
    ])

    if original_size is not None:
        trafo = ResizeLongestSide(max(original_size))
        box = trafo.apply_boxes(box[None], (256, 256)).squeeze()

    # round the bounding box values to the nearest integer
    box = np.round(box).astype(int)

    return box


# Select the correct tile based on average of points
# and bring the points to the coordinate system of the tile.
# Discard points that are not in the tile and warn if this happens.
def _points_to_tile(prompts, shape, tile_shape, halo):
    """Map point prompts to the tile that contains the mean of the points.

    Args:
        prompts: The tuple of point coordinates and point labels.
        shape: The shape of the full image.
        tile_shape: The shape of the tiles.
        halo: The halo that is added around each tile.

    Returns:
        The id of the selected tile.
        The tile (including the halo).
        The tuple of points and labels in the coordinate system of the tile.
    """
    points, labels = prompts

    tiling = Blocking([0, 0], shape, tile_shape)
    center = np.mean(points, axis=0).round().astype("int").tolist()
    tile_id = tiling.coordinates_to_block_id(center)

    tile = tiling.get_block_with_halo(tile_id, list(halo)).outer_block
    offset = tile.begin
    this_tile_shape = tile.shape

    points_in_tile = points - np.array(offset)
    labels_in_tile = labels

    # Only keep the points that fall inside of the tile.
    valid_point_mask = (points_in_tile >= 0).all(axis=1)
    valid_point_mask = np.logical_and(
        valid_point_mask,
        np.logical_and(
            points_in_tile[:, 0] < this_tile_shape[0], points_in_tile[:, 1] < this_tile_shape[1]
        )
    )
    if not valid_point_mask.all():
        points_in_tile = points_in_tile[valid_point_mask]
        labels_in_tile = labels_in_tile[valid_point_mask]
        warnings.warn(
            f"{(~valid_point_mask).sum()} points were not in the tile and are dropped"
        )

    return tile_id, tile, (points_in_tile, labels_in_tile)


def _box_to_tile(box, shape, tile_shape, halo):
    """Map a box prompt to the tile that contains the center of the box.

    Args:
        box: The box prompt.
        shape: The shape of the full image.
        tile_shape: The shape of the tiles.
        halo: The halo that is added around each tile.

    Returns:
        The id of the selected tile.
        The tile (including the halo).
        The box in the coordinate system of the tile, clipped to the tile.
    """
    tiling = Blocking([0, 0], shape, tile_shape)
    center = np.array([(box[0] + box[2]) / 2, (box[1] + box[3]) / 2]).round().astype("int").tolist()
    tile_id = tiling.coordinates_to_block_id(center)

    tile = tiling.get_block_with_halo(tile_id, list(halo)).outer_block
    offset = tile.begin
    this_tile_shape = tile.shape

    box_in_tile = np.array(
        [
            max(box[0] - offset[0], 0), max(box[1] - offset[1], 0),
            min(box[2] - offset[0], this_tile_shape[0]), min(box[3] - offset[1], this_tile_shape[1])
        ]
    )

    return tile_id, tile, box_in_tile


def _mask_to_tile(mask, shape, tile_shape, halo):
    """Map a mask prompt to the tile that contains the center of mass of the mask.

    Args:
        mask: The mask prompt, given for the full image.
        shape: The shape of the full image.
        tile_shape: The shape of the tiles.
        halo: The halo that is added around each tile.

    Returns:
        The id of the selected tile.
        The tile (including the halo).
        The mask cropped to the tile.
    """
    tiling = Blocking([0, 0], shape, tile_shape)

    # NOTE: for a mask without any foreground the means below are NaN.
    coords = np.where(mask)
    center = np.array([np.mean(coords[0]), np.mean(coords[1])]).round().astype("int").tolist()
    tile_id = tiling.coordinates_to_block_id(center)

    tile = tiling.get_block_with_halo(tile_id, list(halo)).outer_block
    bb = tuple(slice(beg, end) for beg, end in zip(tile.begin, tile.end))

    mask_in_tile = mask[bb]
    return tile_id, tile, mask_in_tile


def _initialize_predictor(predictor, image_embeddings, i, prompts, to_tile):
    """Set the precomputed predictor state and map the prompts to a tile if necessary.

    Args:
        predictor: The segment anything predictor.
        image_embeddings: The precomputed image embeddings, or None.
            Embeddings whose "input_size" is None are treated as tiled embeddings.
        i: The index that is passed on to `util.set_precomputed`.
        prompts: The prompts. They are only changed for tiled embeddings.
        to_tile: The function that maps the prompts to a tile. It is called with
            the prompts, the image shape, the tile shape and the halo.

    Returns:
        The predictor.
        The tile the prompts were mapped to, None if the embeddings are not tiled.
        The prompts, in tile coordinates for tiled embeddings.
        The shape of the full image.
    """
    tile = None

    # Set the precomputed state for tiled prediction.
    if image_embeddings is not None and image_embeddings["input_size"] is None:
        features = image_embeddings["features"]
        shape, tile_shape, halo = features.attrs["shape"], features.attrs["tile_shape"], features.attrs["halo"]
        tile_id, tile, prompts = to_tile(prompts, shape, tile_shape, halo)
        util.set_precomputed(predictor, image_embeddings, i, tile_id=tile_id)

    # Set the precomputed state for normal prediction.
    elif image_embeddings is not None:
        shape = image_embeddings["original_size"]
        util.set_precomputed(predictor, image_embeddings, i)

    else:
        shape = predictor.original_size

    return predictor, tile, prompts, shape


def _tile_to_full_mask(mask, shape, tile):
    """Write the mask predicted for a tile into an otherwise empty mask of the full image.

    Args:
        mask: The predicted mask(s), with a leading axis followed by the axes of the tile.
        shape: The shape of the full image.
        tile: The tile the mask was predicted for.

    Returns:
        The mask(s) for the full image, with the leading axis preserved.
    """
    full_mask = np.zeros(mask.shape[0:1] + tuple(shape), dtype=mask.dtype)
    bb = tuple(slice(beg, end) for beg, end in zip(tile.begin, tile.end))
    full_mask[(slice(None),) + bb] = mask
    return full_mask


#
# functions for prompted segmentation:
# - segment_from_points: use point prompts as input
# - segment_from_mask: use binary mask as input, support conversion to mask, box and point prompts
# - segment_from_box: use box prompt as input
# - segment_from_box_and_points: use box and point prompts as input
#


def segment_from_points(
    predictor: SamPredictor,
    points: np.ndarray,
    labels: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    use_best_multimask: Optional[bool] = None,
):
    """Segmentation from point prompts.

    Args:
        predictor: The segment anything predictor.
        points: The point prompts given in the image coordinate system.
        labels: The labels (positive or negative) associated with the points.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
        return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
        use_best_multimask: Whether to use multimask output and then choose the best mask.
            By default this is used for a single positive point and not otherwise.

    Returns:
        The binary segmentation mask. If `return_all` is set, the tuple of mask, scores and logits instead.
        The scores and logits are passed on from the predictor as they are,
        also if only the best mask is selected.
    """
    predictor, tile, prompts, shape = _initialize_predictor(
        predictor, image_embeddings, i, (points, labels), _points_to_tile
    )
    points, labels = prompts

    if use_best_multimask is None:
        use_best_multimask = len(points) == 1 and labels[0] == 1
    multimask_output_ = multimask_output or use_best_multimask

    # predict the mask
    mask, scores, logits = predictor.predict(
        point_coords=points[:, ::-1],  # SAM has reversed XY conventions
        point_labels=labels,
        multimask_output=multimask_output_,
    )

    if use_best_multimask:
        best_mask_id = np.argmax(scores)
        mask = mask[best_mask_id][None]

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    else:
        return mask


def segment_from_mask(
    predictor: SamPredictor,
    mask: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    use_box: bool = True,
    use_mask: bool = True,
    use_points: bool = False,
    original_size: Optional[Tuple[int, ...]] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    return_logits: bool = False,
    box_extension: float = 0.0,
    box: Optional[np.ndarray] = None,
    points: Optional[np.ndarray] = None,
    labels: Optional[np.ndarray] = None,
    use_single_point: bool = False,
):
    """Segmentation from a mask prompt.

    Args:
        predictor: The segment anything predictor.
        mask: The mask used to derive prompts.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        use_box: Whether to derive the bounding box prompt from the mask. By default, set to 'True'.
        use_mask: Whether to use the mask itself as prompt. By default, set to 'True'.
        use_points: Whether to derive point prompts from the mask. By default, set to 'False'.
        original_size: Full image shape. Use this if the mask that is being passed
            is downsampled compared to the original image.
        multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
        return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
        return_logits: Passed on to the predictor as `return_logits`. By default, set to 'False'.
        box_extension: Relative factor used to enlarge the bounding box prompt.
            By default, does not enlarge the bounding box.
        box: Precomputed bounding box.
        points: Precomputed point prompts.
        labels: Positive/negative labels corresponding to the point prompts.
        use_single_point: Whether to derive just a single point from the mask.
            In case use_points is true.

    Returns:
        The binary segmentation mask. If `return_all` is set, the tuple of mask, scores and logits instead.

    Raises:
        ValueError: If points are passed without labels.
        RuntimeError: If the prompts are mapped to different tiles for tiled embeddings.
    """
    prompts = (mask, box, points, labels)

    def _to_tile(prompts, shape, tile_shape, halo):
        # The tile is chosen based on the mask; the other prompts have to fall into the same tile.
        mask, box, points, labels = prompts
        tile_id, tile, mask = _mask_to_tile(mask, shape, tile_shape, halo)
        if points is not None:
            tile_id_points, tile, point_prompts = _points_to_tile((points, labels), shape, tile_shape, halo)
            if tile_id_points != tile_id:
                raise RuntimeError(f"Inconsistent tile ids for mask and point prompts: {tile_id_points} != {tile_id}.")
            points, labels = point_prompts
        if box is not None:
            tile_id_box, tile, box = _box_to_tile(box, shape, tile_shape, halo)
            if tile_id_box != tile_id:
                raise RuntimeError(f"Inconsistent tile ids for mask and box prompts: {tile_id_box} != {tile_id}.")
        return tile_id, tile, (mask, box, points, labels)

    predictor, tile, prompts, shape = _initialize_predictor(predictor, image_embeddings, i, prompts, _to_tile)
    mask, box, points, labels = prompts

    if points is not None:
        if labels is None:
            raise ValueError("If points are passed you also need to pass labels.")
        # NOTE(review): precomputed points are passed to the predictor as they are, without the
        # axis reversal that segment_from_points applies - confirm the convention expected here.
        point_coords, point_labels = points, labels

    elif use_points and mask.sum() != 0:
        point_coords, point_labels = _compute_points_from_mask(
            mask, original_size=original_size, box_extension=box_extension,
            use_single_point=use_single_point,
        )

    else:
        point_coords, point_labels = None, None

    if box is None:
        box = _compute_box_from_mask(
            mask, original_size=original_size, box_extension=box_extension
        ) if use_box and mask.sum() != 0 else None
    else:
        box = _process_box(box, mask.shape, original_size=original_size, box_extension=box_extension)

    logits = _compute_logits_from_mask(mask) if use_mask else None

    mask, scores, logits = predictor.predict(
        point_coords=point_coords, point_labels=point_labels,
        mask_input=logits, box=box,
        multimask_output=multimask_output, return_logits=return_logits
    )

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    else:
        return mask


def segment_from_box(
    predictor: SamPredictor,
    box: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    box_extension: float = 0.0,
):
    """Segmentation from a box prompt.

    Args:
        predictor: The segment anything predictor.
        box: The box prompt.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
        return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
        box_extension: Relative factor used to enlarge the bounding box prompt.
            By default, does not enlarge the bounding box.

    Returns:
        The binary segmentation mask. If `return_all` is set, the tuple of mask, scores and logits instead.
    """
    predictor, tile, box, shape = _initialize_predictor(
        predictor, image_embeddings, i, box, _box_to_tile
    )

    # For tiled prediction the box has been mapped to the coordinate system of the tile,
    # so the extended box has to be clipped to the tile instead of the full image.
    clip_shape = shape if tile is None else tile.shape
    mask, scores, logits = predictor.predict(
        box=_process_box(box, clip_shape, box_extension=box_extension), multimask_output=multimask_output
    )

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    else:
        return mask


def segment_from_box_and_points(
    predictor: SamPredictor,
    box: np.ndarray,
    points: np.ndarray,
    labels: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
):
    """Segmentation from a box prompt and point prompts.

    Args:
        predictor: The segment anything predictor.
        box: The box prompt.
        points: The point prompts, given in the image coordinates system.
        labels: The point labels, either positive or negative.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
        return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.

    Returns:
        The binary segmentation mask. If `return_all` is set, the tuple of mask, scores and logits instead.

    Raises:
        RuntimeError: If the box and the points are mapped to different tiles for tiled embeddings.
    """
    def box_and_points_to_tile(prompts, shape, tile_shape, halo):
        box, points, labels = prompts
        tile_id, tile, point_prompts = _points_to_tile((points, labels), shape, tile_shape, halo)
        points, labels = point_prompts
        tile_id_box, tile, box = _box_to_tile(box, shape, tile_shape, halo)
        if tile_id_box != tile_id:
            raise RuntimeError(f"Inconsistent tile ids for box and point annotations: {tile_id_box} != {tile_id}.")
        return tile_id, tile, (box, points, labels)

    predictor, tile, prompts, shape = _initialize_predictor(
        predictor, image_embeddings, i, (box, points, labels), box_and_points_to_tile
    )
    box, points, labels = prompts

    mask, scores, logits = predictor.predict(
        point_coords=points[:, ::-1],  # SAM has reversed XY conventions
        point_labels=labels,
        box=_process_box(box, shape),
        multimask_output=multimask_output
    )

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    else:
        return mask
def
segment_from_points( predictor: segment_anything.predictor.SamPredictor, points: numpy.ndarray, labels: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, multimask_output: bool = False, return_all: bool = False, use_best_multimask: Optional[bool] = None):
def segment_from_points(
    predictor: SamPredictor,
    points: np.ndarray,
    labels: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    use_best_multimask: Optional[bool] = None,
):
    """Segment an object based on point prompts.

    Args:
        predictor: The segment anything predictor.
        points: The point prompts, given in the coordinate system of the image.
        labels: The label (positive or negative) of each point.
        image_embeddings: Optional precomputed image embeddings.
            They are required if the predictor has not been initialized yet.
        i: Index for the image data. It is needed for data with three spatial dimensions
            or with a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple masks instead of a single one. 'False' by default.
        return_all: Whether to return the scores and logits together with the mask. 'False' by default.
        use_best_multimask: Whether to predict multiple masks and keep only the one with the highest score.
            If not given, this is done for a single positive point and not otherwise.

    Returns:
        The binary segmentation mask. If `return_all` is set, the tuple of mask, scores and logits instead.
        The scores and logits are passed on from the predictor as they are,
        also if only the best mask is selected.
    """
    predictor, tile, (points, labels), shape = _initialize_predictor(
        predictor, image_embeddings, i, (points, labels), _points_to_tile
    )

    # Without an explicit choice the best mask is only selected for a single positive point.
    if use_best_multimask is None:
        use_best_multimask = len(points) == 1 and labels[0] == 1

    # SAM expects the coordinates in reversed (XY) order.
    mask, scores, logits = predictor.predict(
        point_coords=points[:, ::-1],
        point_labels=labels,
        multimask_output=multimask_output or use_best_multimask,
    )

    # Keep the highest scoring mask, but preserve the leading axis.
    if use_best_multimask:
        mask = mask[np.argmax(scores)][None]

    # The prediction of a tile has to be written back into the full image.
    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    return (mask, scores, logits) if return_all else mask
Segmentation from point prompts.
Arguments:
- predictor: The segment anything predictor.
- points: The point prompts given in the image coordinate system.
- labels: The labels (positive or negative) associated with the points.
- image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
- i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
- multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
- return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
- use_best_multimask: Whether to use multimask output and then choose the best mask. By default this is used for a single positive point and not otherwise.
Returns:
The binary segmentation mask.
def
segment_from_mask( predictor: segment_anything.predictor.SamPredictor, mask: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, use_box: bool = True, use_mask: bool = True, use_points: bool = False, original_size: Optional[Tuple[int, ...]] = None, multimask_output: bool = False, return_all: bool = False, return_logits: bool = False, box_extension: float = 0.0, box: Optional[numpy.ndarray] = None, points: Optional[numpy.ndarray] = None, labels: Optional[numpy.ndarray] = None, use_single_point: bool = False):
def segment_from_mask(
    predictor: SamPredictor,
    mask: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    use_box: bool = True,
    use_mask: bool = True,
    use_points: bool = False,
    original_size: Optional[Tuple[int, ...]] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    return_logits: bool = False,
    box_extension: float = 0.0,
    box: Optional[np.ndarray] = None,
    points: Optional[np.ndarray] = None,
    labels: Optional[np.ndarray] = None,
    use_single_point: bool = False,
):
    """Segment an object based on a mask prompt and the prompts derived from it.

    Args:
        predictor: The segment anything predictor.
        mask: The mask that is used to derive the prompts.
        image_embeddings: Optional precomputed image embeddings.
            They are required if the predictor has not been initialized yet.
        i: Index for the image data. It is needed for data with three spatial dimensions
            or with a time dimension and two spatial dimensions.
        use_box: Whether to derive a bounding box prompt from the mask. 'True' by default.
        use_mask: Whether to use the mask itself as a prompt. 'True' by default.
        use_points: Whether to derive point prompts from the mask. 'False' by default.
        original_size: The shape of the full image. Pass this if the mask
            is downsampled compared to the original image.
        multimask_output: Whether to return multiple masks instead of a single one. 'False' by default.
        return_all: Whether to return the scores and logits together with the mask. 'False' by default.
        return_logits: Passed on to the predictor as `return_logits`. 'False' by default.
        box_extension: Relative factor for enlarging the bounding box prompt.
            The box is not enlarged by default.
        box: A precomputed bounding box. It is used instead of deriving the box from the mask.
        points: Precomputed point prompts. They are used instead of deriving points from the mask.
        labels: The positive / negative labels for the precomputed point prompts.
        use_single_point: Whether to derive only a single point from the mask.
            Only relevant if `use_points` is set.

    Returns:
        The binary segmentation mask. If `return_all` is set, the tuple of mask, scores and logits instead.

    Raises:
        ValueError: If points are passed without labels.
        RuntimeError: If the prompts are mapped to different tiles for tiled embeddings.
    """

    def _prompts_to_tile(prompts, shape, tile_shape, halo):
        # The tile is selected based on the mask; points and box have to map to the same tile.
        mask, box, points, labels = prompts
        tile_id, tile, mask = _mask_to_tile(mask, shape, tile_shape, halo)

        if points is not None:
            tile_id_points, tile, (points, labels) = _points_to_tile((points, labels), shape, tile_shape, halo)
            if tile_id_points != tile_id:
                raise RuntimeError(f"Inconsistent tile ids for mask and point prompts: {tile_id_points} != {tile_id}.")

        if box is not None:
            tile_id_box, tile, box = _box_to_tile(box, shape, tile_shape, halo)
            if tile_id_box != tile_id:
                raise RuntimeError(f"Inconsistent tile ids for mask and box prompts: {tile_id_box} != {tile_id}.")

        return tile_id, tile, (mask, box, points, labels)

    predictor, tile, prompts, shape = _initialize_predictor(
        predictor, image_embeddings, i, (mask, box, points, labels), _prompts_to_tile
    )
    mask, box, points, labels = prompts

    # Point prompts: precomputed points take precedence over points sampled from the mask.
    point_coords, point_labels = None, None
    if points is not None:
        if labels is None:
            raise ValueError("If points are passed you also need to pass labels.")
        # NOTE(review): precomputed points are passed to the predictor as they are, without the
        # axis reversal that segment_from_points applies - confirm the convention expected here.
        point_coords, point_labels = points, labels
    elif use_points and mask.sum() != 0:
        point_coords, point_labels = _compute_points_from_mask(
            mask,
            original_size=original_size,
            box_extension=box_extension,
            use_single_point=use_single_point,
        )

    # Box prompt: a precomputed box is only post-processed, otherwise the box
    # is derived from the mask if requested and if the mask is not empty.
    if box is not None:
        box = _process_box(box, mask.shape, original_size=original_size, box_extension=box_extension)
    elif use_box and mask.sum() != 0:
        box = _compute_box_from_mask(mask, original_size=original_size, box_extension=box_extension)

    # Mask prompt.
    mask_input = _compute_logits_from_mask(mask) if use_mask else None

    mask, scores, logits = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        mask_input=mask_input,
        box=box,
        multimask_output=multimask_output,
        return_logits=return_logits,
    )

    # The prediction of a tile has to be written back into the full image.
    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    return (mask, scores, logits) if return_all else mask
Segmentation from a mask prompt.
Arguments:
- predictor: The segment anything predictor.
- mask: The mask used to derive prompts.
- image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
- i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
- use_box: Whether to derive the bounding box prompt from the mask. By default, set to 'True'.
- use_mask: Whether to use the mask itself as prompt. By default, set to 'True'.
- use_points: Whether to derive point prompts from the mask. By default, set to 'False'.
- original_size: Full image shape. Use this if the mask that is being passed is downsampled compared to the original image.
- multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
- return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
- return_logits: Passed on to the predictor as `return_logits`. By default, set to 'False'.
- box_extension: Relative factor used to enlarge the bounding box prompt. By default, does not enlarge the bounding box.
- box: Precomputed bounding box.
- points: Precomputed point prompts.
- labels: Positive/negative labels corresponding to the point prompts.
- use_single_point: Whether to derive just a single point from the mask. In case use_points is true.
Returns:
The binary segmentation mask.
def
segment_from_box( predictor: segment_anything.predictor.SamPredictor, box: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, multimask_output: bool = False, return_all: bool = False, box_extension: float = 0.0):
def segment_from_box(
    predictor: SamPredictor,
    box: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    box_extension: float = 0.0,
):
    """Segmentation from a box prompt.

    Args:
        predictor: The segment anything predictor.
        box: The box prompt.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
        return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
        box_extension: Relative factor used to enlarge the bounding box prompt.
            By default, does not enlarge the bounding box.

    Returns:
        The binary segmentation mask. If `return_all` is set, the tuple of mask, scores and logits instead.
    """
    predictor, tile, box, shape = _initialize_predictor(
        predictor, image_embeddings, i, box, _box_to_tile
    )

    # For tiled prediction the box has been mapped to the coordinate system of the tile,
    # so the extended box has to be clipped to the tile instead of the full image.
    clip_shape = shape if tile is None else tile.shape
    mask, scores, logits = predictor.predict(
        box=_process_box(box, clip_shape, box_extension=box_extension), multimask_output=multimask_output
    )

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    else:
        return mask
Segmentation from a box prompt.
Arguments:
- predictor: The segment anything predictor.
- box: The box prompt.
- image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
- i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
- multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
- return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
- box_extension: Relative factor used to enlarge the bounding box prompt. By default, does not enlarge the bounding box.
Returns:
The binary segmentation mask.
def
segment_from_box_and_points( predictor: segment_anything.predictor.SamPredictor, box: numpy.ndarray, points: numpy.ndarray, labels: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, multimask_output: bool = False, return_all: bool = False):
def segment_from_box_and_points(
    predictor: SamPredictor,
    box: np.ndarray,
    points: np.ndarray,
    labels: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
):
    """Segment an object based on a box prompt combined with point prompts.

    Args:
        predictor: The segment anything predictor.
        box: The box prompt.
        points: The point prompts, given in the coordinate system of the image.
        labels: The label (positive or negative) of each point.
        image_embeddings: Optional precomputed image embeddings.
            They are required if the predictor has not been initialized yet.
        i: Index for the image data. It is needed for data with three spatial dimensions
            or with a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple masks instead of a single one. 'False' by default.
        return_all: Whether to return the scores and logits together with the mask. 'False' by default.

    Returns:
        The binary segmentation mask. If `return_all` is set, the tuple of mask, scores and logits instead.

    Raises:
        RuntimeError: If the box and the points are mapped to different tiles for tiled embeddings.
    """

    def box_and_points_to_tile(prompts, shape, tile_shape, halo):
        # The points are mapped first; the box has to end up in the same tile.
        box, points, labels = prompts
        tile_id, tile, (points, labels) = _points_to_tile((points, labels), shape, tile_shape, halo)
        tile_id_box, tile, box = _box_to_tile(box, shape, tile_shape, halo)
        if tile_id_box != tile_id:
            raise RuntimeError(f"Inconsistent tile ids for box and point annotations: {tile_id_box} != {tile_id}.")
        return tile_id, tile, (box, points, labels)

    predictor, tile, (box, points, labels), shape = _initialize_predictor(
        predictor, image_embeddings, i, (box, points, labels), box_and_points_to_tile
    )

    # SAM expects the point coordinates in reversed (XY) order.
    mask, scores, logits = predictor.predict(
        point_coords=points[:, ::-1],
        point_labels=labels,
        box=_process_box(box, shape),
        multimask_output=multimask_output,
    )

    # The prediction of a tile has to be written back into the full image.
    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    return (mask, scores, logits) if return_all else mask
Segmentation from a box prompt and point prompts.
Arguments:
- predictor: The segment anything predictor.
- box: The box prompt.
- points: The point prompts, given in the image coordinates system.
- labels: The point labels, either positive or negative.
- image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
- i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
- multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
- return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
Returns:
The binary segmentation mask.