micro_sam.prompt_based_segmentation
Functions for prompt-based segmentation with Segment Anything.
"""Functions for prompt-based segmentation with Segment Anything.
"""

import warnings
from typing import Optional, Tuple

import numpy as np
from skimage.filters import gaussian
from skimage.feature import peak_local_max
from skimage.segmentation import find_boundaries
from scipy.ndimage import distance_transform_edt

import torch

from nifty.tools import blocking

from segment_anything.predictor import SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide

from . import util


#
# helper functions for translating mask inputs into other prompts
#


# compute the bounding box from a mask. SAM expects the following input:
# box (np.ndarray or None): A length 4 array given a box prompt to the model, in XYXY format.
def _compute_box_from_mask(mask, original_size=None, box_extension=0):
    """Compute the box prompt enclosing the foreground (`mask == 1`) of a mask.

    Args:
        mask: The 2D mask. Only pixels equal to 1 are treated as foreground.
            It must contain at least one such pixel, otherwise the min / max reductions fail.
        original_size: Optional full image shape, passed on to `_process_box`.
        box_extension: Extension of the box, passed on to `_process_box`.

    Returns:
        The integer box in XYXY format, as returned by `_process_box`.
    """
    coords = np.where(mask == 1)
    min_y, min_x = coords[0].min(), coords[1].min()
    max_y, max_x = coords[0].max(), coords[1].max()
    # The upper bounds are exclusive, hence the '+ 1'.
    box = np.array([min_y, min_x, max_y + 1, max_x + 1])
    return _process_box(box, mask.shape, original_size=original_size, box_extension=box_extension)


# sample points from a mask. SAM expects the following point inputs:
def _compute_points_from_mask(mask, original_size, box_extension, use_single_point=False):
    """Derive point prompts from a mask.

    Positive points are sampled from the maxima of the (smoothed) boundary distances inside of the mask,
    negative points from the maxima outside of the mask, both within the (extended) bounding box.

    Args:
        mask: The 2D mask. Pixels equal to 1 are used to compute the bounding box.
        original_size: Optional full image shape. If given, the points are rescaled
            from the mask shape to this shape.
        box_extension: Extension of the bounding box from which the points are sampled.
        use_single_point: Whether to return only a single positive point,
            placed at the maximum of the distances inside of the mask.

    Returns:
        The point coordinates in XY order.
        The point labels (1 for positive points, 0 for negative points).
    """
    box = _compute_box_from_mask(mask, box_extension=box_extension)

    # get slice and offset in python coordinate convention
    bb = (slice(box[1], box[3]), slice(box[0], box[2]))
    offset = np.array([box[1], box[0]])

    def _to_original_scale(coords):
        # Map (y, x) coordinates from the frame of the mask to the frame of the full image.
        if original_size is None:
            return coords
        scale_factor = np.array([
            original_size[0] / float(mask.shape[0]), original_size[1] / float(mask.shape[1])
        ])[None]
        return coords * scale_factor

    # crop the mask and compute distances
    cropped_mask = mask[bb]
    object_boundaries = find_boundaries(cropped_mask, mode="outer")
    distances = gaussian(distance_transform_edt(object_boundaries == 0))
    inner_distances = distances.copy()
    cropped_mask = cropped_mask.astype("bool")
    inner_distances[~cropped_mask] = 0.0
    if use_single_point:
        center = inner_distances.argmax()
        center = np.unravel_index(center, inner_distances.shape)
        # The single point has to be rescaled as well, so that it is given in the
        # same coordinate frame as the box and as the points of the multi-point case.
        point_coords = _to_original_scale((center + offset)[None])
        point_labels = np.ones(1, dtype="uint8")
        return point_coords[:, ::-1], point_labels

    outer_distances = distances.copy()
    outer_distances[cropped_mask] = 0.0

    # sample positives and negatives from the distance maxima
    inner_maxima = peak_local_max(inner_distances, exclude_border=False, min_distance=3)
    outer_maxima = peak_local_max(outer_distances, exclude_border=False, min_distance=5)

    # derive the positive (=inner maxima) and negative (=outer maxima) points
    point_coords = np.concatenate([inner_maxima, outer_maxima]).astype("float64")
    point_coords += offset
    point_coords = _to_original_scale(point_coords)

    # get the point labels
    point_labels = np.concatenate(
        [np.ones(len(inner_maxima), dtype="uint8"), np.zeros(len(outer_maxima), dtype="uint8")]
    )
    return point_coords[:, ::-1], point_labels


def _compute_logits_from_mask(mask, eps=1e-3):
    """Translate a mask into the low-resolution logits that SAM accepts as mask prompt.

    Args:
        mask: The 2D mask. Pixels equal to 1 are foreground, all other pixels are background.
        eps: Offset from 0 and 1 for the probabilities, which keeps the logits finite.

    Returns:
        The logits as float array of shape (1, 256, 256).
    """

    def inv_sigmoid(x):
        return np.log(x / (1 - x))

    # Everything that is not foreground is mapped to 'eps'. Assigning only the pixels
    # that are exactly 0 or 1 would leave other values at probability 0, i.e. at -inf logits.
    logits = np.where(mask == 1, 1 - eps, eps).astype("float32")
    logits = inv_sigmoid(logits)

    # resize to the expected mask shape of SAM (256x256)
    assert logits.ndim == 2
    expected_shape = (256, 256)

    if logits.shape == expected_shape:  # shape matches, do nothing
        pass

    elif logits.shape[0] == logits.shape[1]:  # shape is square
        trafo = ResizeLongestSide(expected_shape[0])
        logits = trafo.apply_image_torch(torch.from_numpy(logits[None, None]))
        logits = logits.numpy().squeeze()

    else:  # shape is not square
        # resize the longest side to expected shape
        trafo = ResizeLongestSide(expected_shape[0])
        logits = trafo.apply_image_torch(torch.from_numpy(logits[None, None]))
        logits = logits.numpy().squeeze()

        # pad the other side
        h, w = logits.shape
        padh = expected_shape[0] - h
        padw = expected_shape[1] - w
        # IMPORTANT: need to pad with zero, otherwise SAM doesn't understand the padding
        pad_width = ((0, padh), (0, padw))
        logits = np.pad(logits, pad_width, mode="constant", constant_values=0)

    logits = logits[None]
    assert logits.shape == (1, 256, 256), f"{logits.shape}"
    return logits


#
# other helper functions
#


def _process_box(box, shape, original_size=None, box_extension=0):
    """Extend a box, clip it to the image and bring it into the XYXY format expected by SAM.

    Args:
        box: The box, given as (min_y, min_x, max_y, max_x).
        shape: The shape (height, width) that the extended box is clipped to.
        original_size: Optional full image shape. If given, the box is rescaled with `ResizeLongestSide`.
        box_extension: The extension of the box. A value of 0 does not extend the box, a value >= 1 extends
            it by this fixed amount on each side, any other value by this fraction of the box length per axis.

    Returns:
        The box as integer array in XYXY format.
    """
    if box_extension == 0:  # no extension
        extension_y, extension_x = 0, 0
    elif box_extension >= 1:  # extension by a fixed factor
        extension_y, extension_x = box_extension, box_extension
    else:  # extension by fraction of the box len
        len_y, len_x = box[2] - box[0], box[3] - box[1]
        extension_y, extension_x = box_extension * len_y, box_extension * len_x

    # Extend and clip the box; this also switches from the YX to the XY axis order.
    box = np.array([
        max(box[1] - extension_x, 0), max(box[0] - extension_y, 0),
        min(box[3] + extension_x, shape[1]), min(box[2] + extension_y, shape[0]),
    ])

    if original_size is not None:
        # NOTE(review): the source frame for the rescaling is hard-coded to (256, 256)
        # instead of using 'shape' - confirm that this is intended.
        trafo = ResizeLongestSide(max(original_size))
        box = trafo.apply_boxes(box[None], (256, 256)).squeeze()

    # round up the bounding box values
    box = np.round(box).astype(int)

    return box


# Select the correct tile based on average of points
# and bring the points to the coordinate system of the tile.
# Discard points that are not in the tile and warn if this happens.
def _points_to_tile(prompts, shape, tile_shape, halo):
    """Map point prompts to the tile that contains their mean coordinate.

    Args:
        prompts: Tuple of the point coordinates and the point labels.
        shape: The shape of the full image.
        tile_shape: The shape of the tiles.
        halo: The halo (overlap) of the tiles.

    Returns:
        The id of the selected tile.
        The selected tile (outer block, i.e. including the halo).
        Tuple of the point coordinates relative to the tile and the corresponding labels.
        Points outside of the tile are dropped (with a warning).
    """
    points, labels = prompts

    tiling = blocking([0, 0], shape, tile_shape)
    center = np.mean(points, axis=0).round().astype("int").tolist()
    tile_id = tiling.coordinatesToBlockId(center)

    tile = tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock
    offset = tile.begin
    this_tile_shape = tile.shape

    points_in_tile = points - np.array(offset)
    labels_in_tile = labels

    valid_point_mask = (points_in_tile >= 0).all(axis=1)
    valid_point_mask = np.logical_and(
        valid_point_mask,
        np.logical_and(
            points_in_tile[:, 0] < this_tile_shape[0], points_in_tile[:, 1] < this_tile_shape[1]
        )
    )
    if not valid_point_mask.all():
        points_in_tile = points_in_tile[valid_point_mask]
        labels_in_tile = labels_in_tile[valid_point_mask]
        warnings.warn(
            f"{(~valid_point_mask).sum()} points were not in the tile and are dropped"
        )

    return tile_id, tile, (points_in_tile, labels_in_tile)


def _box_to_tile(box, shape, tile_shape, halo):
    """Map a box prompt to the tile that contains the center of the box.

    Args:
        box: The box; the entries 0 and 2 are the bounds along the first axis,
            the entries 1 and 3 the bounds along the second axis.
        shape: The shape of the full image.
        tile_shape: The shape of the tiles.
        halo: The halo (overlap) of the tiles.

    Returns:
        The id of the selected tile.
        The selected tile (outer block, i.e. including the halo).
        The box relative to the tile, clipped to the extent of the tile.
    """
    tiling = blocking([0, 0], shape, tile_shape)
    center = np.array([(box[0] + box[2]) / 2, (box[1] + box[3]) / 2]).round().astype("int").tolist()
    tile_id = tiling.coordinatesToBlockId(center)

    tile = tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock
    offset = tile.begin
    this_tile_shape = tile.shape

    box_in_tile = np.array(
        [
            max(box[0] - offset[0], 0), max(box[1] - offset[1], 0),
            min(box[2] - offset[0], this_tile_shape[0]), min(box[3] - offset[1], this_tile_shape[1])
        ]
    )

    return tile_id, tile, box_in_tile


def _mask_to_tile(mask, shape, tile_shape, halo):
    """Map a mask prompt to the tile that contains the center of its non-zero pixels.

    Args:
        mask: The mask, given for the full image.
        shape: The shape of the full image.
        tile_shape: The shape of the tiles.
        halo: The halo (overlap) of the tiles.

    Returns:
        The id of the selected tile.
        The selected tile (outer block, i.e. including the halo).
        The mask cropped to the tile.
    """
    tiling = blocking([0, 0], shape, tile_shape)

    coords = np.where(mask)
    center = np.array([np.mean(coords[0]), np.mean(coords[1])]).round().astype("int").tolist()
    tile_id = tiling.coordinatesToBlockId(center)

    tile = tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock
    bb = tuple(slice(beg, end) for beg, end in zip(tile.begin, tile.end))

    mask_in_tile = mask[bb]
    return tile_id, tile, mask_in_tile


def _initialize_predictor(predictor, image_embeddings, i, prompts, to_tile):
    """Set the precomputed embeddings in the predictor and map the prompts to a tile if needed.

    Args:
        predictor: The segment anything predictor.
        image_embeddings: Optional precomputed image embeddings. Embeddings with
            an 'input_size' of None are treated as tiled embeddings.
        i: Index for the image data, passed on to `util.set_precomputed`.
        prompts: The prompts. They are only changed for tiled embeddings.
        to_tile: Function that maps the prompts to the matching tile,
            called as `to_tile(prompts, shape, tile_shape, halo)`.

    Returns:
        The predictor.
        The selected tile, or None if the embeddings are not tiled.
        The prompts, in the coordinate system of the tile for tiled embeddings.
        The shape of the full image.
    """
    tile = None

    # Set the precomputed state for tiled prediction.
    if image_embeddings is not None and image_embeddings["input_size"] is None:
        features = image_embeddings["features"]
        shape, tile_shape, halo = features.attrs["shape"], features.attrs["tile_shape"], features.attrs["halo"]
        tile_id, tile, prompts = to_tile(prompts, shape, tile_shape, halo)
        util.set_precomputed(predictor, image_embeddings, i, tile_id=tile_id)

    # Set the precomputed state for normal prediction.
    elif image_embeddings is not None:
        shape = image_embeddings["original_size"]
        util.set_precomputed(predictor, image_embeddings, i)

    else:
        shape = predictor.original_size

    return predictor, tile, prompts, shape


def _tile_to_full_mask(mask, shape, tile):
    """Write the mask predicted for a tile into an otherwise empty mask of the full image shape.

    Args:
        mask: The predicted mask(s) for the tile, with a leading mask axis.
        shape: The shape of the full image.
        tile: The tile for which the mask was predicted.

    Returns:
        The mask(s) with the shape of the full image (and the same leading axis).
    """
    full_mask = np.zeros(mask.shape[0:1] + tuple(shape), dtype=mask.dtype)
    bb = tuple(slice(beg, end) for beg, end in zip(tile.begin, tile.end))
    full_mask[(slice(None),) + bb] = mask
    return full_mask


#
# functions for prompted segmentation:
# - segment_from_points: use point prompts as input
# - segment_from_mask: use binary mask as input, support conversion to mask, box and point prompts
# - segment_from_box: use box prompt as input
# - segment_from_box_and_points: use box and point prompts as input
#


def segment_from_points(
    predictor: SamPredictor,
    points: np.ndarray,
    labels: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    use_best_multimask: Optional[bool] = None,
):
    """Segmentation from point prompts.

    Args:
        predictor: The segment anything predictor.
        points: The point prompts given in the image coordinate system.
        labels: The labels (positive or negative) associated with the points.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
        return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
        use_best_multimask: Whether to use multimask output and then choose the best mask.
            By default this is used for a single positive point and not otherwise.

    Returns:
        The binary segmentation mask.
        If `return_all` is set, the scores and logits of the prediction are returned in addition.
    """
    predictor, tile, prompts, shape = _initialize_predictor(
        predictor, image_embeddings, i, (points, labels), _points_to_tile
    )
    points, labels = prompts

    if use_best_multimask is None:
        use_best_multimask = len(points) == 1 and labels[0] == 1
    multimask_output_ = multimask_output or use_best_multimask

    # predict the mask
    mask, scores, logits = predictor.predict(
        point_coords=points[:, ::-1],  # SAM has reversed XY conventions
        point_labels=labels,
        multimask_output=multimask_output_,
    )

    # Only the mask is reduced to the best prediction, scores and logits are kept for all predictions.
    if use_best_multimask:
        best_mask_id = np.argmax(scores)
        mask = mask[best_mask_id][None]

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    else:
        return mask


def segment_from_mask(
    predictor: SamPredictor,
    mask: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    use_box: bool = True,
    use_mask: bool = True,
    use_points: bool = False,
    original_size: Optional[Tuple[int, ...]] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    return_logits: bool = False,
    box_extension: float = 0.0,
    box: Optional[np.ndarray] = None,
    points: Optional[np.ndarray] = None,
    labels: Optional[np.ndarray] = None,
    use_single_point: bool = False,
):
    """Segmentation from a mask prompt.

    Args:
        predictor: The segment anything predictor.
        mask: The mask used to derive prompts.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        use_box: Whether to derive the bounding box prompt from the mask. By default, set to 'True'.
        use_mask: Whether to use the mask itself as prompt. By default, set to 'True'.
        use_points: Whether to derive point prompts from the mask. By default, set to 'False'.
        original_size: Full image shape. Use this if the mask that is being passed
            is downsampled compared to the original image.
        multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
        return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
        return_logits: Whether to return the un-thresholded mask logits instead of a binary mask.
            This is passed on to the predictor. By default, set to 'False'.
        box_extension: Relative factor used to enlarge the bounding box prompt.
            By default, does not enlarge the bounding box.
        box: Precomputed bounding box.
        points: Precomputed point prompts.
        labels: Positive/negative labels corresponding to the point prompts.
        use_single_point: Whether to derive just a single point from the mask.
            In case use_points is true.

    Returns:
        The binary segmentation mask.
        If `return_all` is set, the scores and logits of the prediction are returned in addition.

    Raises:
        ValueError: If points are passed without labels.
        RuntimeError: If the prompts map to different tiles for tiled image embeddings.
    """
    prompts = (mask, box, points, labels)

    def _to_tile(prompts, shape, tile_shape, halo):
        mask, box, points, labels = prompts
        tile_id, tile, mask = _mask_to_tile(mask, shape, tile_shape, halo)
        if points is not None:
            tile_id_points, tile, point_prompts = _points_to_tile((points, labels), shape, tile_shape, halo)
            if tile_id_points != tile_id:
                raise RuntimeError(f"Inconsistent tile ids for mask and point prompts: {tile_id_points} != {tile_id}.")
            points, labels = point_prompts
        if box is not None:
            tile_id_box, tile, box = _box_to_tile(box, shape, tile_shape, halo)
            if tile_id_box != tile_id:
                raise RuntimeError(f"Inconsistent tile ids for mask and box prompts: {tile_id_box} != {tile_id}.")
        return tile_id, tile, (mask, box, points, labels)

    predictor, tile, prompts, shape = _initialize_predictor(predictor, image_embeddings, i, prompts, _to_tile)
    mask, box, points, labels = prompts

    if points is not None:
        if labels is None:
            raise ValueError("If points are passed you also need to pass labels.")
        # NOTE(review): precomputed points are passed to the predictor without reversing the axis order,
        # whereas the points derived from the mask are returned in XY order - confirm the expected convention.
        point_coords, point_labels = points, labels

    elif use_points and mask.sum() != 0:
        point_coords, point_labels = _compute_points_from_mask(
            mask, original_size=original_size, box_extension=box_extension,
            use_single_point=use_single_point,
        )

    else:
        point_coords, point_labels = None, None

    if box is None:
        box = _compute_box_from_mask(
            mask, original_size=original_size, box_extension=box_extension
        ) if use_box and mask.sum() != 0 else None
    else:
        box = _process_box(box, mask.shape, original_size=original_size, box_extension=box_extension)

    logits = _compute_logits_from_mask(mask) if use_mask else None

    mask, scores, logits = predictor.predict(
        point_coords=point_coords, point_labels=point_labels,
        mask_input=logits, box=box,
        multimask_output=multimask_output, return_logits=return_logits
    )

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    else:
        return mask


def segment_from_box(
    predictor: SamPredictor,
    box: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    box_extension: float = 0.0,
):
    """Segmentation from a box prompt.

    Args:
        predictor: The segment anything predictor.
        box: The box prompt.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
        return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
        box_extension: Relative factor used to enlarge the bounding box prompt.
            By default, does not enlarge the bounding box.

    Returns:
        The binary segmentation mask.
        If `return_all` is set, the scores and logits of the prediction are returned in addition.
    """
    predictor, tile, box, shape = _initialize_predictor(
        predictor, image_embeddings, i, box, _box_to_tile
    )
    mask, scores, logits = predictor.predict(
        box=_process_box(box, shape, box_extension=box_extension), multimask_output=multimask_output
    )

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    else:
        return mask


def segment_from_box_and_points(
    predictor: SamPredictor,
    box: np.ndarray,
    points: np.ndarray,
    labels: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
):
    """Segmentation from a box prompt and point prompts.

    Args:
        predictor: The segment anything predictor.
        box: The box prompt.
        points: The point prompts, given in the image coordinate system.
        labels: The point labels, either positive or negative.
        image_embeddings: Optional precomputed image embeddings.
            Has to be passed if the predictor is not yet initialized.
        i: Index for the image data. Required if the input data has three spatial dimensions
            or a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
        return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.

    Returns:
        The binary segmentation mask.
        If `return_all` is set, the scores and logits of the prediction are returned in addition.

    Raises:
        RuntimeError: If the box and the points map to different tiles for tiled image embeddings.
    """
    def box_and_points_to_tile(prompts, shape, tile_shape, halo):
        box, points, labels = prompts
        tile_id, tile, point_prompts = _points_to_tile((points, labels), shape, tile_shape, halo)
        points, labels = point_prompts
        tile_id_box, tile, box = _box_to_tile(box, shape, tile_shape, halo)
        if tile_id_box != tile_id:
            raise RuntimeError(f"Inconsistent tile ids for box and point annotations: {tile_id_box} != {tile_id}.")
        return tile_id, tile, (box, points, labels)

    predictor, tile, prompts, shape = _initialize_predictor(
        predictor, image_embeddings, i, (box, points, labels), box_and_points_to_tile
    )
    box, points, labels = prompts

    mask, scores, logits = predictor.predict(
        point_coords=points[:, ::-1],  # SAM has reversed XY conventions
        point_labels=labels,
        box=_process_box(box, shape),
        multimask_output=multimask_output
    )

    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    if return_all:
        return mask, scores, logits
    else:
        return mask
def
segment_from_points( predictor: segment_anything.predictor.SamPredictor, points: numpy.ndarray, labels: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, multimask_output: bool = False, return_all: bool = False, use_best_multimask: Optional[bool] = None):
def segment_from_points(
    predictor: SamPredictor,
    points: np.ndarray,
    labels: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    use_best_multimask: Optional[bool] = None,
):
    """Run Segment Anything with point prompts.

    Args:
        predictor: The segment anything predictor.
        points: The point prompts, given in the image coordinate system.
        labels: The labels (positive or negative) of the point prompts.
        image_embeddings: Optional precomputed image embeddings.
            They have to be passed if the predictor is not yet initialized.
        i: Index for the image data. It is required for input data with three spatial dimensions
            or with a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple masks instead of a single one. Defaults to 'False'.
        return_all: Whether to return the scores and logits in addition to the mask. Defaults to 'False'.
        use_best_multimask: Whether to predict multiple masks and keep only the one with the highest score.
            If not given, this is done for a single positive point prompt and not otherwise.

    Returns:
        The binary segmentation mask, followed by the scores and logits if `return_all` is set.
    """
    predictor, tile, (points, labels), shape = _initialize_predictor(
        predictor, image_embeddings, i, (points, labels), _points_to_tile
    )

    # Without an explicit choice, the best mask is only selected for a single positive point.
    if use_best_multimask is None:
        use_best_multimask = len(points) == 1 and labels[0] == 1

    # SAM expects the coordinates in XY order, so the coordinate axis is reversed.
    mask, scores, logits = predictor.predict(
        point_coords=points[:, ::-1],
        point_labels=labels,
        multimask_output=multimask_output or use_best_multimask,
    )

    # Keep the highest scoring mask (with its leading axis); scores and logits are left unchanged.
    if use_best_multimask:
        mask = mask[np.argmax(scores)][None]

    # For tiled embeddings the prediction only covers the tile and is written into a full-size mask.
    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    return (mask, scores, logits) if return_all else mask
Segmentation from point prompts.
Arguments:
- predictor: The segment anything predictor.
- points: The point prompts given in the image coordinate system.
- labels: The labels (positive or negative) associated with the points.
- image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
- i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
- multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
- return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
- use_best_multimask: Whether to use multimask output and then choose the best mask. By default this is used for a single positive point and not otherwise.
Returns:
The binary segmentation mask.
def
segment_from_mask( predictor: segment_anything.predictor.SamPredictor, mask: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, use_box: bool = True, use_mask: bool = True, use_points: bool = False, original_size: Optional[Tuple[int, ...]] = None, multimask_output: bool = False, return_all: bool = False, return_logits: bool = False, box_extension: float = 0.0, box: Optional[numpy.ndarray] = None, points: Optional[numpy.ndarray] = None, labels: Optional[numpy.ndarray] = None, use_single_point: bool = False):
def segment_from_mask(
    predictor: SamPredictor,
    mask: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    use_box: bool = True,
    use_mask: bool = True,
    use_points: bool = False,
    original_size: Optional[Tuple[int, ...]] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    return_logits: bool = False,
    box_extension: float = 0.0,
    box: Optional[np.ndarray] = None,
    points: Optional[np.ndarray] = None,
    labels: Optional[np.ndarray] = None,
    use_single_point: bool = False,
):
    """Run Segment Anything with prompts derived from a mask.

    Args:
        predictor: The segment anything predictor.
        mask: The mask from which the prompts are derived.
        image_embeddings: Optional precomputed image embeddings.
            They have to be passed if the predictor is not yet initialized.
        i: Index for the image data. It is required for input data with three spatial dimensions
            or with a time dimension and two spatial dimensions.
        use_box: Whether to derive a bounding box prompt from the mask. Defaults to 'True'.
        use_mask: Whether to use the mask itself as prompt. Defaults to 'True'.
        use_points: Whether to derive point prompts from the mask. Defaults to 'False'.
        original_size: The full image shape. Pass it if the mask is downsampled
            compared to the original image.
        multimask_output: Whether to return multiple masks instead of a single one. Defaults to 'False'.
        return_all: Whether to return the scores and logits in addition to the mask. Defaults to 'False'.
        return_logits: Passed on to the predictor as `return_logits`. Defaults to 'False'.
        box_extension: Relative factor for enlarging the bounding box prompt.
            The box is not enlarged by default.
        box: Precomputed bounding box. It is used instead of the box derived from the mask.
        points: Precomputed point prompts. They are used instead of points derived from the mask.
        labels: The positive / negative labels for the precomputed point prompts.
        use_single_point: Whether to derive only a single point from the mask, if `use_points` is set.

    Returns:
        The binary segmentation mask, followed by the scores and logits if `return_all` is set.
    """
    def _prompts_to_tile(all_prompts, shape, tile_shape, halo):
        # Map all prompts to a tile and make sure that they agree on the tile.
        tile_mask, tile_box, tile_points, tile_labels = all_prompts
        tile_id, tile, tile_mask = _mask_to_tile(tile_mask, shape, tile_shape, halo)

        if tile_points is not None:
            point_tile_id, tile, mapped_points = _points_to_tile(
                (tile_points, tile_labels), shape, tile_shape, halo
            )
            if point_tile_id != tile_id:
                raise RuntimeError(f"Inconsistent tile ids for mask and point prompts: {point_tile_id} != {tile_id}.")
            tile_points, tile_labels = mapped_points

        if tile_box is not None:
            box_tile_id, tile, tile_box = _box_to_tile(tile_box, shape, tile_shape, halo)
            if box_tile_id != tile_id:
                raise RuntimeError(f"Inconsistent tile ids for mask and box prompts: {box_tile_id} != {tile_id}.")

        return tile_id, tile, (tile_mask, tile_box, tile_points, tile_labels)

    predictor, tile, prompts, shape = _initialize_predictor(
        predictor, image_embeddings, i, (mask, box, points, labels), _prompts_to_tile
    )
    mask, box, points, labels = prompts

    # Point prompts: precomputed points take precedence over points derived from the mask.
    point_coords, point_labels = None, None
    if points is not None:
        if labels is None:
            raise ValueError("If points are passed you also need to pass labels.")
        point_coords, point_labels = points, labels
    elif use_points and mask.sum() != 0:
        point_coords, point_labels = _compute_points_from_mask(
            mask, original_size=original_size, box_extension=box_extension,
            use_single_point=use_single_point,
        )

    # Box prompt: a precomputed box is only post-processed, otherwise the box is
    # derived from the mask if requested and if the mask is not empty.
    if box is not None:
        box = _process_box(box, mask.shape, original_size=original_size, box_extension=box_extension)
    elif use_box and mask.sum() != 0:
        box = _compute_box_from_mask(mask, original_size=original_size, box_extension=box_extension)

    # Mask prompt: SAM takes the mask as low-resolution logits.
    mask_input = _compute_logits_from_mask(mask) if use_mask else None

    mask, scores, logits = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        mask_input=mask_input,
        box=box,
        multimask_output=multimask_output,
        return_logits=return_logits,
    )

    # For tiled embeddings the prediction only covers the tile and is written into a full-size mask.
    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    return (mask, scores, logits) if return_all else mask
Segmentation from a mask prompt.
Arguments:
- predictor: The segment anything predictor.
- mask: The mask used to derive prompts.
- image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
- i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
- use_box: Whether to derive the bounding box prompt from the mask. By default, set to 'True'.
- use_mask: Whether to use the mask itself as prompt. By default, set to 'True'.
- use_points: Whether to derive point prompts from the mask. By default, set to 'False'.
- original_size: Full image shape. Use this if the mask that is being passed is downsampled compared to the original image.
- multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
- return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
- box_extension: Relative factor used to enlarge the bounding box prompt. By default, does not enlarge the bounding box.
- box: Precomputed bounding box.
- points: Precomputed point prompts.
- labels: Positive/negative labels corresponding to the point prompts.
- use_single_point: Whether to derive just a single point from the mask. In case use_points is true.
Returns:
The binary segmentation mask.
def
segment_from_box( predictor: segment_anything.predictor.SamPredictor, box: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, multimask_output: bool = False, return_all: bool = False, box_extension: float = 0.0):
def segment_from_box(
    predictor: SamPredictor,
    box: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
    box_extension: float = 0.0,
):
    """Run Segment Anything with a box prompt.

    Args:
        predictor: The segment anything predictor.
        box: The box prompt.
        image_embeddings: Optional precomputed image embeddings.
            They have to be passed if the predictor is not yet initialized.
        i: Index for the image data. It is required for input data with three spatial dimensions
            or with a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple masks instead of a single one. Defaults to 'False'.
        return_all: Whether to return the scores and logits in addition to the mask. Defaults to 'False'.
        box_extension: Relative factor for enlarging the bounding box prompt.
            The box is not enlarged by default.

    Returns:
        The binary segmentation mask, followed by the scores and logits if `return_all` is set.
    """
    predictor, tile, box, shape = _initialize_predictor(predictor, image_embeddings, i, box, _box_to_tile)

    # Extend and clip the box and bring it into the format expected by SAM.
    box_prompt = _process_box(box, shape, box_extension=box_extension)
    mask, scores, logits = predictor.predict(box=box_prompt, multimask_output=multimask_output)

    # For tiled embeddings the prediction only covers the tile and is written into a full-size mask.
    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    return (mask, scores, logits) if return_all else mask
Segmentation from a box prompt.
Arguments:
- predictor: The segment anything predictor.
- box: The box prompt.
- image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
- i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
- multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
- return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
- box_extension: Relative factor used to enlarge the bounding box prompt. By default, does not enlarge the bounding box.
Returns:
The binary segmentation mask.
def
segment_from_box_and_points( predictor: segment_anything.predictor.SamPredictor, box: numpy.ndarray, points: numpy.ndarray, labels: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, multimask_output: bool = False, return_all: bool = False):
def segment_from_box_and_points(
    predictor: SamPredictor,
    box: np.ndarray,
    points: np.ndarray,
    labels: np.ndarray,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    i: Optional[int] = None,
    multimask_output: bool = False,
    return_all: bool = False,
):
    """Run Segment Anything with a box prompt combined with point prompts.

    Args:
        predictor: The segment anything predictor.
        box: The box prompt.
        points: The point prompts, given in the image coordinate system.
        labels: The labels (positive or negative) of the point prompts.
        image_embeddings: Optional precomputed image embeddings.
            They have to be passed if the predictor is not yet initialized.
        i: Index for the image data. It is required for input data with three spatial dimensions
            or with a time dimension and two spatial dimensions.
        multimask_output: Whether to return multiple masks instead of a single one. Defaults to 'False'.
        return_all: Whether to return the scores and logits in addition to the mask. Defaults to 'False'.

    Returns:
        The binary segmentation mask, followed by the scores and logits if `return_all` is set.
    """
    def _prompts_to_tile(all_prompts, shape, tile_shape, halo):
        # Map the points and the box to a tile and make sure that they agree on the tile.
        tile_box, tile_points, tile_labels = all_prompts
        tile_id, tile, (tile_points, tile_labels) = _points_to_tile(
            (tile_points, tile_labels), shape, tile_shape, halo
        )
        box_tile_id, tile, tile_box = _box_to_tile(tile_box, shape, tile_shape, halo)
        if box_tile_id != tile_id:
            raise RuntimeError(f"Inconsistent tile ids for box and point annotations: {box_tile_id} != {tile_id}.")
        return tile_id, tile, (tile_box, tile_points, tile_labels)

    predictor, tile, (box, points, labels), shape = _initialize_predictor(
        predictor, image_embeddings, i, (box, points, labels), _prompts_to_tile
    )

    # SAM expects the point coordinates in XY order, so the coordinate axis is reversed.
    box_prompt = _process_box(box, shape)
    mask, scores, logits = predictor.predict(
        point_coords=points[:, ::-1],
        point_labels=labels,
        box=box_prompt,
        multimask_output=multimask_output,
    )

    # For tiled embeddings the prediction only covers the tile and is written into a full-size mask.
    if tile is not None:
        mask = _tile_to_full_mask(mask, shape, tile)

    return (mask, scores, logits) if return_all else mask
Segmentation from a box prompt and point prompts.
Arguments:
- predictor: The segment anything predictor.
- box: The box prompt.
- points: The point prompts, given in the image coordinate system.
- labels: The point labels, either positive or negative.
- image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
- i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
- multimask_output: Whether to return multiple or just a single mask. By default, set to 'False'.
- return_all: Whether to return the score and logits in addition to the mask. By default, set to 'False'.
Returns:
The binary segmentation mask.