micro_sam.prompt_based_segmentation
Functions for prompt-based segmentation with Segment Anything.
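All functions in this module take an initialized SamPredictor, usually together with precomputed image embeddings. The sketch below shows a minimal setup that the examples for the individual functions further down reuse; it assumes get_sam_model and precompute_image_embeddings from micro_sam.util and uses a random placeholder image.

import numpy as np
from micro_sam import util

# Placeholder 2D image; replace with your own image data.
image = np.random.randint(0, 255, size=(512, 512)).astype("uint8")

# Load the Segment Anything model and precompute the image embeddings
# (get_sam_model and precompute_image_embeddings are assumed from micro_sam.util).
predictor = util.get_sam_model(model_type="vit_b")
image_embeddings = util.precompute_image_embeddings(predictor, image)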
1"""Functions for prompt-based segmentation with Segment Anything. 2""" 3 4import warnings 5from typing import Optional, Tuple 6 7import numpy as np 8from skimage.filters import gaussian 9from skimage.feature import peak_local_max 10from skimage.segmentation import find_boundaries 11from scipy.ndimage import distance_transform_edt 12 13import torch 14 15from nifty.tools import blocking 16 17from segment_anything.predictor import SamPredictor 18from segment_anything.utils.transforms import ResizeLongestSide 19 20from . import util 21 22 23# 24# helper functions for translating mask inputs into other prompts 25# 26 27 28# compute the bounding box from a mask. SAM expects the following input: 29# box (np.ndarray or None): A length 4 array given a box prompt to the model, in XYXY format. 30def _compute_box_from_mask(mask, original_size=None, box_extension=0): 31 coords = np.where(mask == 1) 32 min_y, min_x = coords[0].min(), coords[1].min() 33 max_y, max_x = coords[0].max(), coords[1].max() 34 box = np.array([min_y, min_x, max_y + 1, max_x + 1]) 35 return _process_box(box, mask.shape, original_size=original_size, box_extension=box_extension) 36 37 38# sample points from a mask. SAM expects the following point inputs: 39def _compute_points_from_mask(mask, original_size, box_extension, use_single_point=False): 40 box = _compute_box_from_mask(mask, box_extension=box_extension) 41 42 # get slice and offset in python coordinate convention 43 bb = (slice(box[1], box[3]), slice(box[0], box[2])) 44 offset = np.array([box[1], box[0]]) 45 46 # crop the mask and compute distances 47 cropped_mask = mask[bb] 48 object_boundaries = find_boundaries(cropped_mask, mode="outer") 49 distances = gaussian(distance_transform_edt(object_boundaries == 0)) 50 inner_distances = distances.copy() 51 cropped_mask = cropped_mask.astype("bool") 52 inner_distances[~cropped_mask] = 0.0 53 if use_single_point: 54 center = inner_distances.argmax() 55 center = np.unravel_index(center, inner_distances.shape) 56 point_coords = (center + offset)[None] 57 point_labels = np.ones(1, dtype="uint8") 58 return point_coords[:, ::-1], point_labels 59 60 outer_distances = distances.copy() 61 outer_distances[cropped_mask] = 0.0 62 63 # sample positives and negatives from the distance maxima 64 inner_maxima = peak_local_max(inner_distances, exclude_border=False, min_distance=3) 65 outer_maxima = peak_local_max(outer_distances, exclude_border=False, min_distance=5) 66 67 # derive the positive (=inner maxima) and negative (=outer maxima) points 68 point_coords = np.concatenate([inner_maxima, outer_maxima]).astype("float64") 69 point_coords += offset 70 71 if original_size is not None: 72 scale_factor = np.array([ 73 original_size[0] / float(mask.shape[0]), original_size[1] / float(mask.shape[1]) 74 ])[None] 75 point_coords *= scale_factor 76 77 # get the point labels 78 point_labels = np.concatenate( 79 [np.ones(len(inner_maxima), dtype="uint8"), np.zeros(len(outer_maxima), dtype="uint8")] 80 ) 81 return point_coords[:, ::-1], point_labels 82 83 84def _compute_logits_from_mask(mask, eps=1e-3): 85 86 def inv_sigmoid(x): 87 return np.log(x / (1 - x)) 88 89 logits = np.zeros(mask.shape, dtype="float32") 90 logits[mask == 1] = 1 - eps 91 logits[mask == 0] = eps 92 logits = inv_sigmoid(logits) 93 94 # resize to the expected mask shape of SAM (256x256) 95 assert logits.ndim == 2 96 expected_shape = (256, 256) 97 98 if logits.shape == expected_shape: # shape matches, do nothing 99 pass 100 101 elif logits.shape[0] == logits.shape[1]: # shape is square 
102 trafo = ResizeLongestSide(expected_shape[0]) 103 logits = trafo.apply_image_torch(torch.from_numpy(logits[None, None])) 104 logits = logits.numpy().squeeze() 105 106 else: # shape is not square 107 # resize the longest side to expected shape 108 trafo = ResizeLongestSide(expected_shape[0]) 109 logits = trafo.apply_image_torch(torch.from_numpy(logits[None, None])) 110 logits = logits.numpy().squeeze() 111 112 # pad the other side 113 h, w = logits.shape 114 padh = expected_shape[0] - h 115 padw = expected_shape[1] - w 116 # IMPORTANT: need to pad with zero, otherwise SAM doesn't understand the padding 117 pad_width = ((0, padh), (0, padw)) 118 logits = np.pad(logits, pad_width, mode="constant", constant_values=0) 119 120 logits = logits[None] 121 assert logits.shape == (1, 256, 256), f"{logits.shape}" 122 return logits 123 124 125# 126# other helper functions 127# 128 129 130def _process_box(box, shape, original_size=None, box_extension=0): 131 if box_extension == 0: # no extension 132 extension_y, extension_x = 0, 0 133 elif box_extension >= 1: # extension by a fixed factor 134 extension_y, extension_x = box_extension, box_extension 135 else: # extension by fraction of the box len 136 len_y, len_x = box[2] - box[0], box[3] - box[1] 137 extension_y, extension_x = box_extension * len_y, box_extension * len_x 138 139 box = np.array([ 140 max(box[1] - extension_x, 0), max(box[0] - extension_y, 0), 141 min(box[3] + extension_x, shape[1]), min(box[2] + extension_y, shape[0]), 142 ]) 143 144 if original_size is not None: 145 trafo = ResizeLongestSide(max(original_size)) 146 box = trafo.apply_boxes(box[None], (256, 256)).squeeze() 147 148 # round up the bounding box values 149 box = np.round(box).astype(int) 150 151 return box 152 153 154# Select the correct tile based on average of points 155# and bring the points to the coordinate system of the tile. 156# Discard points that are not in the tile and warn if this happens. 
157def _points_to_tile(prompts, shape, tile_shape, halo): 158 points, labels = prompts 159 160 tiling = blocking([0, 0], shape, tile_shape) 161 center = np.mean(points, axis=0).round().astype("int").tolist() 162 tile_id = tiling.coordinatesToBlockId(center) 163 164 tile = tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock 165 offset = tile.begin 166 this_tile_shape = tile.shape 167 168 points_in_tile = points - np.array(offset) 169 labels_in_tile = labels 170 171 valid_point_mask = (points_in_tile >= 0).all(axis=1) 172 valid_point_mask = np.logical_and( 173 valid_point_mask, 174 np.logical_and( 175 points_in_tile[:, 0] < this_tile_shape[0], points_in_tile[:, 1] < this_tile_shape[1] 176 ) 177 ) 178 if not valid_point_mask.all(): 179 points_in_tile = points_in_tile[valid_point_mask] 180 labels_in_tile = labels_in_tile[valid_point_mask] 181 warnings.warn( 182 f"{(~valid_point_mask).sum()} points were not in the tile and are dropped" 183 ) 184 185 return tile_id, tile, (points_in_tile, labels_in_tile) 186 187 188def _box_to_tile(box, shape, tile_shape, halo): 189 tiling = blocking([0, 0], shape, tile_shape) 190 center = np.array([(box[0] + box[2]) / 2, (box[1] + box[3]) / 2]).round().astype("int").tolist() 191 tile_id = tiling.coordinatesToBlockId(center) 192 193 tile = tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock 194 offset = tile.begin 195 this_tile_shape = tile.shape 196 197 box_in_tile = np.array( 198 [ 199 max(box[0] - offset[0], 0), max(box[1] - offset[1], 0), 200 min(box[2] - offset[0], this_tile_shape[0]), min(box[3] - offset[1], this_tile_shape[1]) 201 ] 202 ) 203 204 return tile_id, tile, box_in_tile 205 206 207def _mask_to_tile(mask, shape, tile_shape, halo): 208 tiling = blocking([0, 0], shape, tile_shape) 209 210 coords = np.where(mask) 211 center = np.array([np.mean(coords[0]), np.mean(coords[1])]).round().astype("int").tolist() 212 tile_id = tiling.coordinatesToBlockId(center) 213 214 tile = tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock 215 bb = tuple(slice(beg, end) for beg, end in zip(tile.begin, tile.end)) 216 217 mask_in_tile = mask[bb] 218 return tile_id, tile, mask_in_tile 219 220 221def _initialize_predictor(predictor, image_embeddings, i, prompts, to_tile): 222 tile = None 223 224 # Set the precomputed state for tiled prediction. 225 if image_embeddings is not None and image_embeddings["input_size"] is None: 226 features = image_embeddings["features"] 227 shape, tile_shape, halo = features.attrs["shape"], features.attrs["tile_shape"], features.attrs["halo"] 228 tile_id, tile, prompts = to_tile(prompts, shape, tile_shape, halo) 229 util.set_precomputed(predictor, image_embeddings, i, tile_id=tile_id) 230 231 # Set the precomputed state for normal prediction. 
232 elif image_embeddings is not None: 233 shape = image_embeddings["original_size"] 234 util.set_precomputed(predictor, image_embeddings, i) 235 236 else: 237 shape = predictor.original_size 238 239 return predictor, tile, prompts, shape 240 241 242def _tile_to_full_mask(mask, shape, tile): 243 full_mask = np.zeros(mask.shape[0:1] + tuple(shape), dtype=mask.dtype) 244 bb = tuple(slice(beg, end) for beg, end in zip(tile.begin, tile.end)) 245 full_mask[(slice(None),) + bb] = mask 246 return full_mask 247 248 249# 250# functions for prompted segmentation: 251# - segment_from_points: use point prompts as input 252# - segment_from_mask: use binary mask as input, support conversion to mask, box and point prompts 253# - segment_from_box: use box prompt as input 254# - segment_from_box_and_points: use box and point prompts as input 255# 256 257 258def segment_from_points( 259 predictor: SamPredictor, 260 points: np.ndarray, 261 labels: np.ndarray, 262 image_embeddings: Optional[util.ImageEmbeddings] = None, 263 i: Optional[int] = None, 264 multimask_output: bool = False, 265 return_all: bool = False, 266 use_best_multimask: Optional[bool] = None, 267): 268 """Segmentation from point prompts. 269 270 Args: 271 predictor: The segment anything predictor. 272 points: The point prompts given in the image coordinate system. 273 labels: The labels (positive or negative) associated with the points. 274 image_embeddings: Optional precomputed image embeddings. 275 Has to be passed if the predictor is not yet initialized. 276 i: Index for the image data. Required if the input data has three spatial dimensions 277 or a time dimension and two spatial dimensions. 278 multimask_output: Whether to return multiple or just a single mask. 279 return_all: Whether to return the score and logits in addition to the mask. 280 use_best_multimask: Whether to use multimask output and then choose the best mask. 281 By default this is used for a single positive point and not otherwise. 282 283 Returns: 284 The binary segmentation mask. 285 """ 286 predictor, tile, prompts, shape = _initialize_predictor( 287 predictor, image_embeddings, i, (points, labels), _points_to_tile 288 ) 289 points, labels = prompts 290 291 if use_best_multimask is None: 292 use_best_multimask = len(points) == 1 and labels[0] == 1 293 multimask_output_ = multimask_output or use_best_multimask 294 295 # predict the mask 296 mask, scores, logits = predictor.predict( 297 point_coords=points[:, ::-1], # SAM has reversed XY conventions 298 point_labels=labels, 299 multimask_output=multimask_output_, 300 ) 301 302 if use_best_multimask: 303 best_mask_id = np.argmax(scores) 304 mask = mask[best_mask_id][None] 305 306 if tile is not None: 307 mask = _tile_to_full_mask(mask, shape, tile) 308 309 if return_all: 310 return mask, scores, logits 311 else: 312 return mask 313 314 315def segment_from_mask( 316 predictor: SamPredictor, 317 mask: np.ndarray, 318 image_embeddings: Optional[util.ImageEmbeddings] = None, 319 i: Optional[int] = None, 320 use_box: bool = True, 321 use_mask: bool = True, 322 use_points: bool = False, 323 original_size: Optional[Tuple[int, ...]] = None, 324 multimask_output: bool = False, 325 return_all: bool = False, 326 return_logits: bool = False, 327 box_extension: float = 0.0, 328 box: Optional[np.ndarray] = None, 329 points: Optional[np.ndarray] = None, 330 labels: Optional[np.ndarray] = None, 331 use_single_point: bool = False, 332): 333 """Segmentation from a mask prompt. 334 335 Args: 336 predictor: The segment anything predictor. 
337 mask: The mask used to derive prompts. 338 image_embeddings: Optional precomputed image embeddings. 339 Has to be passed if the predictor is not yet initialized. 340 i: Index for the image data. Required if the input data has three spatial dimensions 341 or a time dimension and two spatial dimensions. 342 use_box: Whether to derive the bounding box prompt from the mask. 343 use_mask: Whether to use the mask itself as prompt. 344 use_points: Whether to derive point prompts from the mask. 345 original_size: Full image shape. Use this if the mask that is being passed 346 downsampled compared to the original image. 347 multimask_output: Whether to return multiple or just a single mask. 348 return_all: Whether to return the score and logits in addition to the mask. 349 box_extension: Relative factor used to enlarge the bounding box prompt. 350 box: Precomputed bounding box. 351 points: Precomputed point prompts. 352 labels: Positive/negative labels corresponding to the point prompts. 353 use_single_point: Whether to derive just a single point from the mask. 354 In case use_points is true. 355 356 Returns: 357 The binary segmentation mask. 358 """ 359 prompts = (mask, box, points, labels) 360 361 def _to_tile(prompts, shape, tile_shape, halo): 362 mask, box, points, labels = prompts 363 tile_id, tile, mask = _mask_to_tile(mask, shape, tile_shape, halo) 364 if points is not None: 365 tile_id_points, tile, point_prompts = _points_to_tile((points, labels), shape, tile_shape, halo) 366 if tile_id_points != tile_id: 367 raise RuntimeError(f"Inconsistent tile ids for mask and point prompts: {tile_id_points} != {tile_id}.") 368 points, labels = point_prompts 369 if box is not None: 370 tile_id_box, tile, box = _box_to_tile(box, shape, tile_shape, halo) 371 if tile_id_box != tile_id: 372 raise RuntimeError(f"Inconsistent tile ids for mask and box prompts: {tile_id_box} != {tile_id}.") 373 return tile_id, tile, (mask, box, points, labels) 374 375 predictor, tile, prompts, shape = _initialize_predictor(predictor, image_embeddings, i, prompts, _to_tile) 376 mask, box, points, labels = prompts 377 378 if points is not None: 379 if labels is None: 380 raise ValueError("If points are passed you also need to pass labels.") 381 point_coords, point_labels = points, labels 382 383 elif use_points and mask.sum() != 0: 384 point_coords, point_labels = _compute_points_from_mask( 385 mask, original_size=original_size, box_extension=box_extension, 386 use_single_point=use_single_point, 387 ) 388 389 else: 390 point_coords, point_labels = None, None 391 392 if box is None: 393 box = _compute_box_from_mask( 394 mask, original_size=original_size, box_extension=box_extension 395 ) if use_box and mask.sum() != 0 else None 396 else: 397 box = _process_box(box, mask.shape, original_size=original_size, box_extension=box_extension) 398 399 logits = _compute_logits_from_mask(mask) if use_mask else None 400 401 mask, scores, logits = predictor.predict( 402 point_coords=point_coords, point_labels=point_labels, 403 mask_input=logits, box=box, 404 multimask_output=multimask_output, return_logits=return_logits 405 ) 406 407 if tile is not None: 408 mask = _tile_to_full_mask(mask, shape, tile) 409 410 if return_all: 411 return mask, scores, logits 412 else: 413 return mask 414 415 416def segment_from_box( 417 predictor: SamPredictor, 418 box: np.ndarray, 419 image_embeddings: Optional[util.ImageEmbeddings] = None, 420 i: Optional[int] = None, 421 multimask_output: bool = False, 422 return_all: bool = False, 423 box_extension: 
float = 0.0, 424): 425 """Segmentation from a box prompt. 426 427 Args: 428 predictor: The segment anything predictor. 429 box: The box prompt. 430 image_embeddings: Optional precomputed image embeddings. 431 Has to be passed if the predictor is not yet initialized. 432 i: Index for the image data. Required if the input data has three spatial dimensions 433 or a time dimension and two spatial dimensions. 434 multimask_output: Whether to return multiple or just a single mask. 435 return_all: Whether to return the score and logits in addition to the mask. 436 box_extension: Relative factor used to enlarge the bounding box prompt. 437 438 Returns: 439 The binary segmentation mask. 440 """ 441 predictor, tile, box, shape = _initialize_predictor( 442 predictor, image_embeddings, i, box, _box_to_tile 443 ) 444 mask, scores, logits = predictor.predict( 445 box=_process_box(box, shape, box_extension=box_extension), multimask_output=multimask_output 446 ) 447 448 if tile is not None: 449 mask = _tile_to_full_mask(mask, shape, tile) 450 451 if return_all: 452 return mask, scores, logits 453 else: 454 return mask 455 456 457def segment_from_box_and_points( 458 predictor: SamPredictor, 459 box: np.ndarray, 460 points: np.ndarray, 461 labels: np.ndarray, 462 image_embeddings: Optional[util.ImageEmbeddings] = None, 463 i: Optional[int] = None, 464 multimask_output: bool = False, 465 return_all: bool = False, 466): 467 """Segmentation from a box prompt and point prompts. 468 469 Args: 470 predictor: The segment anything predictor. 471 box: The box prompt. 472 points: The point prompts, given in the image coordinates system. 473 labels: The point labels, either positive or negative. 474 image_embeddings: Optional precomputed image embeddings. 475 Has to be passed if the predictor is not yet initialized. 476 i: Index for the image data. Required if the input data has three spatial dimensions 477 or a time dimension and two spatial dimensions. 478 multimask_output: Whether to return multiple or just a single mask. 479 return_all: Whether to return the score and logits in addition to the mask. 480 481 Returns: 482 The binary segmentation mask. 483 """ 484 def box_and_points_to_tile(prompts, shape, tile_shape, halo): 485 box, points, labels = prompts 486 tile_id, tile, point_prompts = _points_to_tile((points, labels), shape, tile_shape, halo) 487 points, labels = point_prompts 488 tile_id_box, tile, box = _box_to_tile(box, shape, tile_shape, halo) 489 if tile_id_box != tile_id: 490 raise RuntimeError(f"Inconsistent tile ids for box and point annotations: {tile_id_box} != {tile_id}.") 491 return tile_id, tile, (box, points, labels) 492 493 predictor, tile, prompts, shape = _initialize_predictor( 494 predictor, image_embeddings, i, (box, points, labels), box_and_points_to_tile 495 ) 496 box, points, labels = prompts 497 498 mask, scores, logits = predictor.predict( 499 point_coords=points[:, ::-1], # SAM has reversed XY conventions 500 point_labels=labels, 501 box=_process_box(box, shape), 502 multimask_output=multimask_output 503 ) 504 505 if tile is not None: 506 mask = _tile_to_full_mask(mask, shape, tile) 507 508 if return_all: 509 return mask, scores, logits 510 else: 511 return mask
def segment_from_points(predictor: segment_anything.predictor.SamPredictor, points: numpy.ndarray, labels: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, multimask_output: bool = False, return_all: bool = False, use_best_multimask: Optional[bool] = None):
Segmentation from point prompts.
Arguments:
- predictor: The segment anything predictor.
- points: The point prompts given in the image coordinate system.
- labels: The labels (positive or negative) associated with the points.
- image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
- i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
- multimask_output: Whether to return multiple or just a single mask.
- return_all: Whether to return the score and logits in addition to the mask.
- use_best_multimask: Whether to use multimask output and then choose the best mask. By default this is used for a single positive point and not otherwise.
Returns:
The binary segmentation mask.
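A minimal usage sketch, reusing the predictor and image_embeddings from the setup example at the top of this page. Points are given in (y, x) image coordinates; a label of 1 marks a positive point and 0 a negative point.

from micro_sam.prompt_based_segmentation import segment_from_points

# One positive (label 1) and one negative (label 0) point prompt in (y, x) order.
points = np.array([[256, 256], [100, 100]])
labels = np.array([1, 0])

mask = segment_from_points(predictor, points, labels, image_embeddings=image_embeddings)
print(mask.shape)  # (1, 512, 512): a single binary mask for the whole image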
def segment_from_mask(predictor: segment_anything.predictor.SamPredictor, mask: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, use_box: bool = True, use_mask: bool = True, use_points: bool = False, original_size: Optional[Tuple[int, ...]] = None, multimask_output: bool = False, return_all: bool = False, return_logits: bool = False, box_extension: float = 0.0, box: Optional[numpy.ndarray] = None, points: Optional[numpy.ndarray] = None, labels: Optional[numpy.ndarray] = None, use_single_point: bool = False):
Segmentation from a mask prompt.
Arguments:
- predictor: The segment anything predictor.
- mask: The mask used to derive prompts.
- image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
- i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
- use_box: Whether to derive the bounding box prompt from the mask.
- use_mask: Whether to use the mask itself as prompt.
- use_points: Whether to derive point prompts from the mask.
- original_size: Full image shape. Use this if the mask being passed is downsampled compared to the original image.
- multimask_output: Whether to return multiple or just a single mask.
- return_all: Whether to return the score and logits in addition to the mask.
- box_extension: Relative factor used to enlarge the bounding box prompt.
- box: Precomputed bounding box.
- points: Precomputed point prompts.
- labels: Positive/negative labels corresponding to the point prompts.
- use_single_point: Whether to derive just a single point from the mask. Only used if use_points is true.
Returns:
The binary segmentation mask.
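A sketch of a typical use case: refining a rough binary mask (for example a result from another segmentation tool or from an adjacent slice in a volume) by converting it into box and mask prompts. It reuses the predictor and image_embeddings from the setup example above; the rectangular mask is a placeholder.

from micro_sam.prompt_based_segmentation import segment_from_mask

# A rough binary mask for the object of interest; here just a rectangle as a placeholder.
rough_mask = np.zeros(image.shape, dtype="uint8")
rough_mask[200:300, 220:340] = 1

refined_mask = segment_from_mask(
    predictor, rough_mask, image_embeddings=image_embeddings,
    use_box=True, use_mask=True, use_points=False,
)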
def segment_from_box(predictor: segment_anything.predictor.SamPredictor, box: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, multimask_output: bool = False, return_all: bool = False, box_extension: float = 0.0):
Segmentation from a box prompt.
Arguments:
- predictor: The segment anything predictor.
- box: The box prompt.
- image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
- i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
- multimask_output: Whether to return multiple or just a single mask.
- return_all: Whether to return the score and logits in addition to the mask.
- box_extension: Relative factor used to enlarge the bounding box prompt.
Returns:
The binary segmentation mask.
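A minimal usage sketch, again reusing the setup example above. Following the conventions of the helper functions in this module, the box is given in (y, x) image coordinates as [y_min, x_min, y_max, x_max]; a box_extension between 0 and 1 enlarges the box by that fraction of its side lengths.

from micro_sam.prompt_based_segmentation import segment_from_box

# Box prompt as [y_min, x_min, y_max, x_max] in image coordinates.
box = np.array([200, 220, 300, 340])

mask = segment_from_box(
    predictor, box, image_embeddings=image_embeddings,
    box_extension=0.05,  # enlarge the box by 5% of its side lengths
)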
def segment_from_box_and_points(predictor: segment_anything.predictor.SamPredictor, box: numpy.ndarray, points: numpy.ndarray, labels: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, multimask_output: bool = False, return_all: bool = False):
Segmentation from a box prompt and point prompts.
Arguments:
- predictor: The segment anything predictor.
- box: The box prompt.
- points: The point prompts, given in the image coordinate system.
- labels: The point labels, either positive or negative.
- image_embeddings: Optional precomputed image embeddings. Has to be passed if the predictor is not yet initialized.
- i: Index for the image data. Required if the input data has three spatial dimensions or a time dimension and two spatial dimensions.
- multimask_output: Whether to return multiple or just a single mask.
- return_all: Whether to return the score and logits in addition to the mask.
Returns:
The binary segmentation mask.
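A minimal usage sketch combining both prompt types, reusing the setup example above. The box and point conventions are the same as for segment_from_box and segment_from_points; here a negative point inside the box is used to exclude a neighboring structure.

from micro_sam.prompt_based_segmentation import segment_from_box_and_points

box = np.array([200, 220, 300, 340])         # [y_min, x_min, y_max, x_max]
points = np.array([[250, 280], [290, 330]])  # point prompts in (y, x) order
labels = np.array([1, 0])                    # one positive and one negative point

mask = segment_from_box_and_points(predictor, box, points, labels, image_embeddings=image_embeddings)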