micro_sam.prompt_based_segmentation
Functions for prompt-based segmentation with Segment Anything.
"""
Functions for prompt-based segmentation with Segment Anything.
"""

import warnings
from typing import Optional, Tuple

import numpy as np
from skimage.filters import gaussian
from skimage.feature import peak_local_max
from skimage.segmentation import find_boundaries
from scipy.ndimage import distance_transform_edt

import torch

from nifty.tools import blocking

from segment_anything.predictor import SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide

from . import util


#
# helper functions for translating mask inputs into other prompts
#


def _compute_box_from_mask(mask, original_size=None, box_extension=0):
    """Compute the bounding box prompt of the foreground (== 1) pixels of a mask.

    SAM expects a box prompt as a length 4 array in XYXY format. The conversion to that
    format, the optional extension and the optional rescaling are done by `_process_box`.

    Raises:
        ValueError: If the mask has no foreground pixels.
    """
    coords = np.where(mask == 1)
    if coords[0].size == 0:
        raise ValueError("Cannot compute a bounding box from a mask without foreground pixels.")
    min_y, min_x = coords[0].min(), coords[1].min()
    max_y, max_x = coords[0].max(), coords[1].max()
    # Intermediate box in (y0, x0, y1, x1) order with exclusive upper bounds.
    box = np.array([min_y, min_x, max_y + 1, max_x + 1])
    return _process_box(box, mask.shape, original_size=original_size, box_extension=box_extension)


def _rescale_points(point_coords, mask_shape, original_size):
    """Rescale (y, x) point coordinates from the mask frame to `original_size` (no-op if it is None)."""
    if original_size is None:
        return point_coords
    scale_factor = np.array([
        original_size[0] / float(mask_shape[0]), original_size[1] / float(mask_shape[1])
    ])[None]
    return point_coords * scale_factor


def _compute_points_from_mask(mask, original_size, box_extension, use_single_point=False):
    """Sample point prompts from a mask.

    Positive points are placed at local maxima of the smoothed distance to the object
    boundary inside of the object, negative points at such maxima outside of it. Only the
    (optionally extended) bounding box of the mask is considered.

    Args:
        mask: The mask; foreground pixels have the value 1.
        original_size: Shape of the full image. If given, the points are rescaled from the
            mask shape to this shape (for both the single-point and multi-point case).
        box_extension: Extension of the bounding box the points are sampled in, see `_process_box`.
        use_single_point: Whether to return only a single positive point, placed at the
            maximum of the inner distance.

    Returns:
        The point coordinates, reversed to SAM's (x, y) axis order, and the point labels
        (1 for positive, 0 for negative points).
    """
    box = _compute_box_from_mask(mask, box_extension=box_extension)

    # The box is XYXY; the crop slices and the offset are in (y, x) python index order.
    bb = (slice(box[1], box[3]), slice(box[0], box[2]))
    offset = np.array([box[1], box[0]])

    # crop the mask and compute distances
    cropped_mask = mask[bb]
    object_boundaries = find_boundaries(cropped_mask, mode="outer")
    distances = gaussian(distance_transform_edt(object_boundaries == 0))
    inner_distances = distances.copy()
    cropped_mask = cropped_mask.astype("bool")
    inner_distances[~cropped_mask] = 0.0

    if use_single_point:
        center = np.unravel_index(inner_distances.argmax(), inner_distances.shape)
        point_coords = (np.array(center) + offset)[None]
        # Rescale like the multi-point case below; otherwise the point would be in the
        # mask frame while the box prompt is rescaled to original_size.
        point_coords = _rescale_points(point_coords, mask.shape, original_size)
        point_labels = np.ones(1, dtype="uint8")
        return point_coords[:, ::-1], point_labels

    outer_distances = distances.copy()
    outer_distances[cropped_mask] = 0.0

    # sample positives and negatives from the distance maxima
    inner_maxima = peak_local_max(inner_distances, exclude_border=False, min_distance=3)
    outer_maxima = peak_local_max(outer_distances, exclude_border=False, min_distance=5)

    # derive the positive (=inner maxima) and negative (=outer maxima) points
    point_coords = np.concatenate([inner_maxima, outer_maxima]).astype("float64")
    point_coords += offset
    point_coords = _rescale_points(point_coords, mask.shape, original_size)

    point_labels = np.concatenate([
        np.ones(len(inner_maxima), dtype="uint8"),
        np.zeros(len(outer_maxima), dtype="uint8"),
    ])
    return point_coords[:, ::-1], point_labels


def _compute_logits_from_mask(mask, eps=1e-3):
    """Convert a 2d binary mask into a SAM mask prompt, i.e. logits of shape (1, 256, 256).

    Pixels equal to 1 get the logit of `1 - eps`, pixels equal to 0 the logit of `eps`.
    The logits are resized so that the longest side is 256 and then zero-padded at the
    bottom / right to 256 x 256.
    """

    def inv_sigmoid(x):
        return np.log(x / (1 - x))

    logits = np.zeros(mask.shape, dtype="float32")
    logits[mask == 1] = 1 - eps
    logits[mask == 0] = eps
    logits = inv_sigmoid(logits)

    # resize to the expected mask shape of SAM (256x256)
    assert logits.ndim == 2
    expected_shape = (256, 256)

    if logits.shape != expected_shape:
        # Resize the longest side to the expected size. For square inputs this already
        # yields the expected shape, and the padding below is a no-op.
        trafo = ResizeLongestSide(expected_shape[0])
        logits = trafo.apply_image_torch(torch.from_numpy(logits[None, None]))
        logits = logits.numpy().squeeze()

        # pad the other side
        h, w = logits.shape
        pad_width = ((0, expected_shape[0] - h), (0, expected_shape[1] - w))
        # IMPORTANT: need to pad with zero, otherwise SAM doesn't understand the padding
        logits = np.pad(logits, pad_width, mode="constant", constant_values=0)

    logits = logits[None]
    assert logits.shape == (1, 256, 256), f"{logits.shape}"
    return logits


#
# other helper functions
#


def _process_box(box, shape, original_size=None, box_extension=0):
    """Extend, clip and convert a (y0, x0, y1, x1) box into SAM's XYXY format.

    Args:
        box: The box, given as (y0, x0, y1, x1).
        shape: The (y, x) shape the extended box is clipped to.
        original_size: If given, the box is rescaled with `ResizeLongestSide(max(original_size))`
            from a 256 x 256 frame (presumably SAM's low-resolution mask frame - confirm).
        box_extension: 0 for no extension. A value >= 1 extends each side by that many pixels;
            any other value extends each side by that fraction of the box length.

    Returns:
        The box as an integer array in XYXY format.
    """
    if box_extension == 0:  # no extension
        extension_y, extension_x = 0, 0
    elif box_extension >= 1:  # extension by a fixed number of pixels
        extension_y, extension_x = box_extension, box_extension
    else:  # extension by fraction of the box len
        len_y, len_x = box[2] - box[0], box[3] - box[1]
        extension_y, extension_x = box_extension * len_y, box_extension * len_x

    # Convert to XYXY and clip the extended box to the shape.
    box = np.array([
        max(box[1] - extension_x, 0), max(box[0] - extension_y, 0),
        min(box[3] + extension_x, shape[1]), min(box[2] + extension_y, shape[0]),
    ])

    if original_size is not None:
        trafo = ResizeLongestSide(max(original_size))
        box = trafo.apply_boxes(box[None], (256, 256)).squeeze()

    # round the box coordinates to the nearest integer
    box = np.round(box).astype(int)

    return box


def _points_to_tile(prompts, shape, tile_shape, halo):
    """Map point prompts into the tile selected by the mean of the points.

    The tile (including its halo) is chosen based on the average point position, and the
    points are shifted into the tile's coordinate system. Points that fall outside of
    the tile are dropped with a warning.

    Returns:
        The tile id, the tile (outer block including halo) and the (points, labels) in the tile.
    """
    points, labels = prompts

    tiling = blocking([0, 0], shape, tile_shape)
    center = np.mean(points, axis=0).round().astype("int").tolist()
    tile_id = tiling.coordinatesToBlockId(center)

    tile = tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock
    offset = tile.begin
    this_tile_shape = tile.shape

    points_in_tile = points - np.array(offset)
    labels_in_tile = labels

    valid_point_mask = (points_in_tile >= 0).all(axis=1)
    valid_point_mask = np.logical_and(
        valid_point_mask,
        np.logical_and(
            points_in_tile[:, 0] < this_tile_shape[0], points_in_tile[:, 1] < this_tile_shape[1]
        )
    )
    if not valid_point_mask.all():
        points_in_tile = points_in_tile[valid_point_mask]
        # asarray so that labels passed as a plain list can be masked as well
        labels_in_tile = np.asarray(labels)[valid_point_mask]
        warnings.warn(
            f"{(~valid_point_mask).sum()} points were not in the tile and are dropped"
        )

    return tile_id, tile, (points_in_tile, labels_in_tile)


def _box_to_tile(box, shape, tile_shape, halo):
    """Map a (y0, x0, y1, x1) box into the tile selected by the box center.

    Returns:
        The tile id, the tile (outer block including halo) and the box in tile
        coordinates, clipped to the tile.
    """
    tiling = blocking([0, 0], shape, tile_shape)
    center = np.array([(box[0] + box[2]) / 2, (box[1] + box[3]) / 2]).round().astype("int").tolist()
    tile_id = tiling.coordinatesToBlockId(center)

    tile = tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock
    offset = tile.begin
    this_tile_shape = tile.shape

    box_in_tile = np.array(
        [
            max(box[0] - offset[0], 0), max(box[1] - offset[1], 0),
            min(box[2] - offset[0], this_tile_shape[0]), min(box[3] - offset[1], this_tile_shape[1])
        ]
    )

    return tile_id, tile, box_in_tile
def _mask_to_tile(mask, shape, tile_shape, halo):
    """Select the tile for the centroid of the mask's foreground and crop the mask to it.

    Returns:
        The tile id, the tile (outer block including halo) and the mask cropped to the tile.

    Raises:
        ValueError: If the mask has no foreground pixels (its centroid would be undefined).
    """
    tiling = blocking([0, 0], shape, tile_shape)

    coords = np.where(mask)
    if coords[0].size == 0:
        raise ValueError("Cannot determine the tile for a mask without foreground pixels.")
    center = np.array([np.mean(coords[0]), np.mean(coords[1])]).round().astype("int").tolist()
    tile_id = tiling.coordinatesToBlockId(center)

    tile = tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock
    bb = tuple(slice(beg, end) for beg, end in zip(tile.begin, tile.end))

    mask_in_tile = mask[bb]
    return tile_id, tile, mask_in_tile


def _initialize_predictor(predictor, image_embeddings, i, prompts, to_tile):
    """Set the precomputed embeddings on the predictor and map the prompts to a tile if needed.

    Tiled embeddings are recognized by `image_embeddings["input_size"]` being None. In that
    case `to_tile` selects the tile and maps the prompts into its coordinate system.

    Returns:
        The predictor, the selected tile (None if not tiled), the (possibly tile-local)
        prompts and the shape of the full image.
    """
    tile = None

    # Set the precomputed state for tiled prediction.
    if image_embeddings is not None and image_embeddings["input_size"] is None:
        features = image_embeddings["features"]
        shape, tile_shape, halo = features.attrs["shape"], features.attrs["tile_shape"], features.attrs["halo"]
        tile_id, tile, prompts = to_tile(prompts, shape, tile_shape, halo)
        util.set_precomputed(predictor, image_embeddings, i, tile_id=tile_id)

    # Set the precomputed state for normal prediction.
    elif image_embeddings is not None:
        shape = image_embeddings["original_size"]
        util.set_precomputed(predictor, image_embeddings, i)

    # The predictor is expected to be initialized already.
    else:
        shape = predictor.original_size

    return predictor, tile, prompts, shape


def _tile_to_full_mask(mask, shape, tile):
    """Paste the per-tile masks (leading axis = mask index) into zero masks of the full image shape."""
    full_mask = np.zeros(mask.shape[0:1] + tuple(shape), dtype=mask.dtype)
    bb = tuple(slice(beg, end) for beg, end in zip(tile.begin, tile.end))
    full_mask[(slice(None),) + bb] = mask
    return full_mask


#
# functions for prompted segmentation:
# - segment_from_points: use point prompts as input
# - segment_from_mask: use binary mask as input, support conversion to mask, box and point prompts
# - segment_from_box: use box prompt as input
# - segment_from_box_and_points: use box and point prompts as input
#


def segment_from_points(
    predictor: SamPredictor,
    points: np.ndarray,
    labels: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    use_best_multimask: Optional[bool] = None,
):
    """Segmentation from point prompts.

    Args:
        predictor: The segment anything predictor.
        points: The point prompts given in the image coordinate system.
        labels: The labels (positive or negative) associated with the points.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask.
        return_all: Whether to return the score and logits in addition to the mask.
            Note that scores and logits are returned for all predicted masks, also when
            only the best mask is kept due to `use_best_multimask`.
        use_best_multimask: Whether to use multimask output and then choose the best mask.
            By default this is used for a single positive point and not otherwise.

    Returns:
        The binary segmentation mask, or a tuple of mask, scores and logits if `return_all` is set.
    """
    predictor, tile, (points, labels), shape = _initialize_predictor(
        predictor, image_embeddings, i, (points, labels), _points_to_tile
    )

    if use_best_multimask is None:
        use_best_multimask = len(points) == 1 and labels[0] == 1

    # The point axes are reversed because SAM uses the opposite (x, y) convention.
    mask, scores, logits = predictor.predict(
        point_coords=points[:, ::-1],
        point_labels=labels,
        multimask_output=multimask_output or use_best_multimask,
    )

    if use_best_multimask:
        mask = mask[np.argmax(scores)][None]

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    return (mask, scores, logits) if return_all else mask


def segment_from_mask(
    predictor: SamPredictor,
    mask: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    use_box: bool = True,
    use_mask: bool = True,
    use_points: bool = False,
    original_size: Optional[Tuple[int, ...]] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    return_logits: bool = False,
    box_extension: float = 0.0,
    box: Optional[np.ndarray] = None,
    points: Optional[np.ndarray] = None,
    labels: Optional[np.ndarray] = None,
    use_single_point: bool = False,
):
    """Segmentation from a mask prompt.

    Args:
        predictor: The segment anything predictor.
        mask: The mask used to derive prompts.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        use_box: Whether to derive the bounding box prompt from the mask. Ignored if `box` is passed.
        use_mask: Whether to use the mask itself as prompt.
        use_points: Whether to derive point prompts from the mask. Ignored if `points` is passed.
        original_size: Full image shape. Use this if the mask that is being passed
            is downsampled compared to the original image.
        multimask_output: Whether to return multiple or just a single mask.
        return_all: Whether to return the score and logits in addition to the mask.
        return_logits: Forwarded to `predictor.predict` (in segment_anything this presumably
            returns un-thresholded logits instead of a binary mask - confirm).
        box_extension: Factor used to enlarge the bounding box prompt. A value >= 1 extends each
            side by that many pixels, any other non-zero value by that fraction of the box length.
        box: Precomputed bounding box, given as (y0, x0, y1, x1).
        points: Precomputed point prompts. Unlike in `segment_from_points`, they are passed
            to the predictor without reversing the axis order.
        labels: Positive/negative labels corresponding to the point prompts.
            Required if `points` is passed.
        use_single_point: Whether to derive just a single point from the mask.
            Only used if `use_points` is true and no `points` are passed.

    Returns:
        The binary segmentation mask, or a tuple of mask, scores and logits if `return_all` is set.

    Raises:
        ValueError: If `points` is passed without `labels`.
        RuntimeError: If the prompts map to different tiles (tiled embeddings only).
    """
    # Validate before tiling, so that missing labels are reported the same way with tiled embeddings.
    if points is not None and labels is None:
        raise ValueError("If points are passed you also need to pass labels.")

    def _to_tile(prompts, shape, tile_shape, halo):
        mask, box, points, labels = prompts
        tile_id, tile, mask = _mask_to_tile(mask, shape, tile_shape, halo)
        if points is not None:
            # NOTE(review): here the points are shifted by the tile offset in (row, col) order,
            # but they are later passed to SAM without axis reversal - confirm the convention.
            tile_id_points, tile, point_prompts = _points_to_tile((points, labels), shape, tile_shape, halo)
            if tile_id_points != tile_id:
                raise RuntimeError(f"Inconsistent tile ids for mask and point prompts: {tile_id_points} != {tile_id}.")
            points, labels = point_prompts
        if box is not None:
            tile_id_box, tile, box = _box_to_tile(box, shape, tile_shape, halo)
            if tile_id_box != tile_id:
                raise RuntimeError(f"Inconsistent tile ids for mask and box prompts: {tile_id_box} != {tile_id}.")
        return tile_id, tile, (mask, box, points, labels)

    predictor, tile, prompts, shape = _initialize_predictor(
        predictor, image_embeddings, i, (mask, box, points, labels), _to_tile
    )
    mask, box, points, labels = prompts

    if points is not None:
        point_coords, point_labels = points, labels
    elif use_points:
        point_coords, point_labels = _compute_points_from_mask(
            mask, original_size=original_size, box_extension=box_extension,
            use_single_point=use_single_point,
        )
    else:
        point_coords, point_labels = None, None

    if box is not None:
        box = _process_box(box, mask.shape, original_size=original_size, box_extension=box_extension)
    elif use_box:
        box = _compute_box_from_mask(mask, original_size=original_size, box_extension=box_extension)

    logits = _compute_logits_from_mask(mask) if use_mask else None

    mask, scores, logits = predictor.predict(
        point_coords=point_coords, point_labels=point_labels,
        mask_input=logits, box=box,
        multimask_output=multimask_output, return_logits=return_logits
    )

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    else:
        return mask


def segment_from_box(
    predictor: SamPredictor,
    box: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    box_extension: float = 0.0,
):
    """Segmentation from a box prompt.

    Args:
        predictor: The segment anything predictor.
        box: The box prompt, given as (y0, x0, y1, x1).
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask.
        return_all: Whether to return the score and logits in addition to the mask.
        box_extension: Factor used to enlarge the bounding box prompt. A value >= 1 extends each
            side by that many pixels, any other non-zero value by that fraction of the box length.

    Returns:
        The binary segmentation mask, or a tuple of mask, scores and logits if `return_all` is set.
    """
    predictor, tile, box, shape = _initialize_predictor(
        predictor, image_embeddings, i, box, _box_to_tile
    )

    # With tiled embeddings the box is in tile coordinates, so the (extended) box
    # has to be clipped to the tile and not to the full image.
    clip_shape = shape if tile is None else tile.shape
    mask, scores, logits = predictor.predict(
        box=_process_box(box, clip_shape, box_extension=box_extension), multimask_output=multimask_output
    )

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    else:
        return mask


def segment_from_box_and_points(
    predictor: SamPredictor,
    box: np.ndarray,
    points: np.ndarray,
    labels: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
):
    """Segmentation from a box prompt and point prompts.

    Args:
        predictor: The segment anything predictor.
        box: The box prompt, given as (y0, x0, y1, x1).
        points: The point prompts, given in the image coordinates system.
        labels: The point labels, either positive or negative.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask.
        return_all: Whether to return the score and logits in addition to the mask.

    Returns:
        The binary segmentation mask, or a tuple of mask, scores and logits if `return_all` is set.

    Raises:
        RuntimeError: If box and points map to different tiles (tiled embeddings only).
    """
    def box_and_points_to_tile(prompts, shape, tile_shape, halo):
        box, points, labels = prompts
        tile_id, tile, (points, labels) = _points_to_tile((points, labels), shape, tile_shape, halo)
        tile_id_box, tile, box = _box_to_tile(box, shape, tile_shape, halo)
        if tile_id_box != tile_id:
            raise RuntimeError(f"Inconsistent tile ids for box and point annotations: {tile_id_box} != {tile_id}.")
        return tile_id, tile, (box, points, labels)

    predictor, tile, (box, points, labels), shape = _initialize_predictor(
        predictor, image_embeddings, i, (box, points, labels), box_and_points_to_tile
    )

    sam_box = _process_box(box, shape)
    mask, scores, logits = predictor.predict(
        point_coords=points[:, ::-1],  # reversed, SAM uses the opposite (x, y) axis order
        point_labels=labels,
        box=sam_box,
        multimask_output=multimask_output,
    )

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    return (mask, scores, logits) if return_all else mask
def
segment_from_points( predictor: segment_anything.predictor.SamPredictor, points: numpy.ndarray, labels: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, multimask_output: bool = False, return_all: bool = False, use_best_multimask: Optional[bool] = None):
def segment_from_points(
    predictor: SamPredictor,
    points: np.ndarray,
    labels: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    use_best_multimask: Optional[bool] = None,
):
    """Segmentation from point prompts.

    Args:
        predictor: The segment anything predictor.
        points: The point prompts given in the image coordinate system.
        labels: The labels (positive or negative) associated with the points.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask.
        return_all: Whether to return the score and logits in addition to the mask.
            Note that scores and logits are returned for all predicted masks, also when
            only the best mask is kept due to `use_best_multimask`.
        use_best_multimask: Whether to use multimask output and then choose the best mask.
            By default this is used for a single positive point and not otherwise.

    Returns:
        The binary segmentation mask, or a tuple of mask, scores and logits if `return_all` is set.
    """
    predictor, tile, (points, labels), shape = _initialize_predictor(
        predictor, image_embeddings, i, (points, labels), _points_to_tile
    )

    if use_best_multimask is None:
        use_best_multimask = len(points) == 1 and labels[0] == 1

    # The point axes are reversed because SAM uses the opposite (x, y) convention.
    mask, scores, logits = predictor.predict(
        point_coords=points[:, ::-1],
        point_labels=labels,
        multimask_output=multimask_output or use_best_multimask,
    )

    if use_best_multimask:
        mask = mask[np.argmax(scores)][None]

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    return (mask, scores, logits) if return_all else mask
Segmentation from point prompts.
Arguments:
- predictor: The segment anything predictor.
- points: The point prompts given in the image coordinate system.
- labels: The labels (positive or negative) associated with the points.
- image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
- i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
- multimask_output: Whether to return multiple or just a single mask.
- return_all: Whether to return the score and logits in addition to the mask.
- use_best_multimask: Whether to use multimask output and then choose the best mask. By default this is used for a single positive point and not otherwise.
Returns:
The binary segmentation mask.
def
segment_from_mask( predictor: segment_anything.predictor.SamPredictor, mask: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, use_box: bool = True, use_mask: bool = True, use_points: bool = False, original_size: Optional[Tuple[int, ...]] = None, multimask_output: bool = False, return_all: bool = False, return_logits: bool = False, box_extension: float = 0.0, box: Optional[numpy.ndarray] = None, points: Optional[numpy.ndarray] = None, labels: Optional[numpy.ndarray] = None, use_single_point: bool = False):
def segment_from_mask(
    predictor: SamPredictor,
    mask: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    use_box: bool = True,
    use_mask: bool = True,
    use_points: bool = False,
    original_size: Optional[Tuple[int, ...]] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    return_logits: bool = False,
    box_extension: float = 0.0,
    box: Optional[np.ndarray] = None,
    points: Optional[np.ndarray] = None,
    labels: Optional[np.ndarray] = None,
    use_single_point: bool = False,
):
    """Segmentation from a mask prompt.

    Args:
        predictor: The segment anything predictor.
        mask: The mask used to derive prompts.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        use_box: Whether to derive the bounding box prompt from the mask. Ignored if `box` is passed.
        use_mask: Whether to use the mask itself as prompt.
        use_points: Whether to derive point prompts from the mask. Ignored if `points` is passed.
        original_size: Full image shape. Use this if the mask that is being passed
            is downsampled compared to the original image.
        multimask_output: Whether to return multiple or just a single mask.
        return_all: Whether to return the score and logits in addition to the mask.
        return_logits: Forwarded to `predictor.predict` (in segment_anything this presumably
            returns un-thresholded logits instead of a binary mask - confirm).
        box_extension: Factor used to enlarge the bounding box prompt. A value >= 1 extends each
            side by that many pixels, any other non-zero value by that fraction of the box length.
        box: Precomputed bounding box, given as (y0, x0, y1, x1).
        points: Precomputed point prompts. Unlike in `segment_from_points`, they are passed
            to the predictor without reversing the axis order.
        labels: Positive/negative labels corresponding to the point prompts.
            Required if `points` is passed.
        use_single_point: Whether to derive just a single point from the mask.
            Only used if `use_points` is true and no `points` are passed.

    Returns:
        The binary segmentation mask, or a tuple of mask, scores and logits if `return_all` is set.

    Raises:
        ValueError: If `points` is passed without `labels`.
        RuntimeError: If the prompts map to different tiles (tiled embeddings only).
    """
    # Validate before tiling, so that missing labels are reported the same way with tiled embeddings.
    if points is not None and labels is None:
        raise ValueError("If points are passed you also need to pass labels.")

    def _to_tile(prompts, shape, tile_shape, halo):
        mask, box, points, labels = prompts
        tile_id, tile, mask = _mask_to_tile(mask, shape, tile_shape, halo)
        if points is not None:
            # NOTE(review): here the points are shifted by the tile offset in (row, col) order,
            # but they are later passed to SAM without axis reversal - confirm the convention.
            tile_id_points, tile, point_prompts = _points_to_tile((points, labels), shape, tile_shape, halo)
            if tile_id_points != tile_id:
                raise RuntimeError(f"Inconsistent tile ids for mask and point prompts: {tile_id_points} != {tile_id}.")
            points, labels = point_prompts
        if box is not None:
            tile_id_box, tile, box = _box_to_tile(box, shape, tile_shape, halo)
            if tile_id_box != tile_id:
                raise RuntimeError(f"Inconsistent tile ids for mask and box prompts: {tile_id_box} != {tile_id}.")
        return tile_id, tile, (mask, box, points, labels)

    predictor, tile, prompts, shape = _initialize_predictor(
        predictor, image_embeddings, i, (mask, box, points, labels), _to_tile
    )
    mask, box, points, labels = prompts

    if points is not None:
        point_coords, point_labels = points, labels
    elif use_points:
        point_coords, point_labels = _compute_points_from_mask(
            mask, original_size=original_size, box_extension=box_extension,
            use_single_point=use_single_point,
        )
    else:
        point_coords, point_labels = None, None

    if box is not None:
        box = _process_box(box, mask.shape, original_size=original_size, box_extension=box_extension)
    elif use_box:
        box = _compute_box_from_mask(mask, original_size=original_size, box_extension=box_extension)

    logits = _compute_logits_from_mask(mask) if use_mask else None

    mask, scores, logits = predictor.predict(
        point_coords=point_coords, point_labels=point_labels,
        mask_input=logits, box=box,
        multimask_output=multimask_output, return_logits=return_logits
    )

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    else:
        return mask
Segmentation from a mask prompt.
Arguments:
- predictor: The segment anything predictor.
- mask: The mask used to derive prompts.
- image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
- i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
- use_box: Whether to derive the bounding box prompt from the mask.
- use_mask: Whether to use the mask itself as prompt.
- use_points: Whether to derive point prompts from the mask.
- original_size: Full image shape. Use this if the mask that is being passed is downsampled compared to the original image.
- multimask_output: Whether to return multiple or just a single mask.
- return_all: Whether to return the score and logits in addition to the mask.
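- box_extension: Factor used to enlarge the bounding box prompt. A value >= 1 extends each side by that many pixels; any other non-zero value extends each side by that fraction of the box length.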
- box: Precomputed bounding box.
- points: Precomputed point prompts.
- labels: Positive/negative labels corresponding to the point prompts.
- use_single_point: Whether to derive just a single point from the mask. Only used if use_points is true and no points are passed.
Returns:
The binary segmentation mask.
def
segment_from_box( predictor: segment_anything.predictor.SamPredictor, box: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, multimask_output: bool = False, return_all: bool = False, box_extension: float = 0.0):
def segment_from_box(
    predictor: SamPredictor,
    box: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    box_extension: float = 0.0,
):
    """Segmentation from a box prompt.

    Args:
        predictor: The segment anything predictor.
        box: The box prompt, given as (y0, x0, y1, x1).
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask.
        return_all: Whether to return the score and logits in addition to the mask.
        box_extension: Factor used to enlarge the bounding box prompt. A value >= 1 extends each
            side by that many pixels, any other non-zero value by that fraction of the box length.

    Returns:
        The binary segmentation mask, or a tuple of mask, scores and logits if `return_all` is set.
    """
    predictor, tile, box, shape = _initialize_predictor(
        predictor, image_embeddings, i, box, _box_to_tile
    )

    # With tiled embeddings the box is in tile coordinates, so the (extended) box
    # has to be clipped to the tile and not to the full image.
    clip_shape = shape if tile is None else tile.shape
    mask, scores, logits = predictor.predict(
        box=_process_box(box, clip_shape, box_extension=box_extension), multimask_output=multimask_output
    )

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    else:
        return mask
Segmentation from a box prompt.
Arguments:
- predictor: The segment anything predictor.
- box: The box prompt.
- image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
- i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
- multimask_output: Whether to return multiple or just a single mask.
- return_all: Whether to return the score and logits in addition to the mask.
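- box_extension: Factor used to enlarge the bounding box prompt. A value >= 1 extends each side by that many pixels; any other non-zero value extends each side by that fraction of the box length.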
Returns:
The binary segmentation mask.
def
segment_from_box_and_points( predictor: segment_anything.predictor.SamPredictor, box: numpy.ndarray, points: numpy.ndarray, labels: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, multimask_output: bool = False, return_all: bool = False):
def segment_from_box_and_points(
    predictor: SamPredictor,
    box: np.ndarray,
    points: np.ndarray,
    labels: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
):
    """Segmentation from a box prompt and point prompts.

    Args:
        predictor: The segment anything predictor.
        box: The box prompt, given as (y0, x0, y1, x1).
        points: The point prompts, given in the image coordinates system.
        labels: The point labels, either positive or negative.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask.
        return_all: Whether to return the score and logits in addition to the mask.

    Returns:
        The binary segmentation mask, or a tuple of mask, scores and logits if `return_all` is set.

    Raises:
        RuntimeError: If box and points map to different tiles (tiled embeddings only).
    """
    def box_and_points_to_tile(prompts, shape, tile_shape, halo):
        box, points, labels = prompts
        tile_id, tile, (points, labels) = _points_to_tile((points, labels), shape, tile_shape, halo)
        tile_id_box, tile, box = _box_to_tile(box, shape, tile_shape, halo)
        if tile_id_box != tile_id:
            raise RuntimeError(f"Inconsistent tile ids for box and point annotations: {tile_id_box} != {tile_id}.")
        return tile_id, tile, (box, points, labels)

    predictor, tile, (box, points, labels), shape = _initialize_predictor(
        predictor, image_embeddings, i, (box, points, labels), box_and_points_to_tile
    )

    sam_box = _process_box(box, shape)
    mask, scores, logits = predictor.predict(
        point_coords=points[:, ::-1],  # reversed, SAM uses the opposite (x, y) axis order
        point_labels=labels,
        box=sam_box,
        multimask_output=multimask_output,
    )

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    return (mask, scores, logits) if return_all else mask
Segmentation from a box prompt and point prompts.
Arguments:
- predictor: The segment anything predictor.
- box: The box prompt.
- points: The point prompts, given in the image coordinates system.
- labels: The point labels, either positive or negative.
- image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
- i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
- multimask_output: Whether to return multiple or just a single mask.
- return_all: Whether to return the score and logits in addition to the mask.
Returns:
The binary segmentation mask.