micro_sam.inference
import os
import gc
from typing import Any, Dict, List, Optional, Union, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from nifty.tools import blocking
import nifty.ground_truth as ngt

import segment_anything.utils.amg as amg_utils
from segment_anything import SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide
try:
    from napari.utils import progress as tqdm
except ImportError:
    from tqdm import tqdm

from . import util
from ._vendored import batched_mask_to_box


def _validate_inputs(
    boxes, points, point_labels, multimasking, return_instance_segmentation, segmentation_ids, logits_masks
):
    if multimasking and (segmentation_ids is not None) and (not return_instance_segmentation):
        raise NotImplementedError

    if (points is None) != (point_labels is None):
        raise ValueError(
            "If you have point prompts both `points` and `point_labels` have to be passed, "
            "but you passed only one of them."
        )

    have_points = points is not None
    have_boxes = boxes is not None
    have_logits = logits_masks is not None
    if (not have_points) and (not have_boxes):
        raise ValueError("Point and/or box prompts have to be passed, you passed neither.")

    if have_points and (len(point_labels) != len(points)):
        raise ValueError(
            f"The number of point coordinates and labels does not match: {len(point_labels)} != {len(points)}"
        )

    if (have_points and have_boxes) and (len(points) != len(boxes)):
        raise ValueError(
            f"The number of point and box prompts does not match: {len(points)} != {len(boxes)}"
        )

    if have_logits:
        if have_points and (len(logits_masks) != len(point_labels)):
            raise ValueError(
                f"The number of points and logits does not match: {len(points)} != {len(logits_masks)}"
            )
        elif have_boxes and (len(logits_masks) != len(boxes)):
            raise ValueError(
                f"The number of boxes and logits does not match: {len(boxes)} != {len(logits_masks)}"
            )

    n_prompts = boxes.shape[0] if have_boxes else points.shape[0]

    if (segmentation_ids is not None) and (len(segmentation_ids) != n_prompts):
        raise ValueError(
            f"The number of segmentation ids and prompts does not match: {len(segmentation_ids)} != {n_prompts}"
        )

    return n_prompts, have_boxes, have_points, have_logits


def _local_otsu_threshold(images: torch.Tensor, window_size: int = 31, num_bins: int = 64, eps: float = 1e-6):
    x = images
    B, _, H, W = x.shape
    device = x.device

    # Work in float32 for stability even if the input is fp16.
    x = x.to(torch.float32)

    # --- per-image min/max for normalization to [0, 1] ---
    x_flat = x.view(B, -1)
    x_min = x_flat.min(dim=1).values.view(B, 1, 1, 1)
    x_max = x_flat.max(dim=1).values.view(B, 1, 1, 1)
    x_range = (x_max - x_min).clamp_min(eps)

    x_norm = (x - x_min) / x_range  # (B, 1, H, W), in [0, 1]

    # --- extract local patches via unfold ---
    pad = window_size // 2
    patches = F.unfold(x_norm, kernel_size=window_size, padding=pad)  # (B, P, L)
    # P = window_size * window_size, L = H * W
    B_, P, L = patches.shape

    # --- quantize to bins ---
    bin_idx = (patches * (num_bins - 1)).long().clamp(0, num_bins - 1)  # (B, P, L)

    # --- build histograms per patch ---
    # one_hot: (B, L, num_bins)
    one_hot = torch.zeros(B, L, num_bins, device=device, dtype=torch.float32)
    idx = bin_idx.transpose(1, 2)  # (B, L, P)
    src = torch.ones_like(idx, dtype=one_hot.dtype)  # (B, L, P)
    one_hot.scatter_add_(2, idx, src)
    # hist: (B, num_bins, L)
    hist = one_hot.permute(0, 2, 1)

    # --- Otsu per patch (vectorized) ---
    # p: (B, bins, L)
    p = hist / hist.sum(dim=1, keepdim=True).clamp_min(eps)

    bins = torch.arange(num_bins, device=device, dtype=torch.float32).view(1, num_bins, 1)

    omega1 = torch.cumsum(p, dim=1)  # (B, bins, L)
    mu = torch.cumsum(p * bins, dim=1)  # (B, bins, L)
    mu_T = mu[:, -1:, :]  # (B, 1, L)

    omega2 = 1.0 - omega1

    mu1 = mu / omega1.clamp_min(eps)
    mu2 = (mu_T - mu) / omega2.clamp_min(eps)

    sigma_b2 = omega1 * omega2 * (mu1 - mu2) ** 2  # (B, bins, L)

    # The argmax over the bins gives the local threshold bin per patch.
    t_bin = torch.argmax(sigma_b2, dim=1)  # (B, L)
    t_norm = t_bin.to(torch.float32) / (num_bins - 1)  # normalized to [0, 1]

    # --- map thresholds back to the original intensity scale (per-image) ---
    # x_min, x_range: (B, 1, 1, 1) -> flatten the batch dims
    thr_vals = x_min.view(B, 1) + t_norm * x_range.view(B, 1)  # (B, L)
    # Clamp to >= 0 because the foreground is positive.
    thr_vals = thr_vals.clamp_min(0.0)

    thresholds = thr_vals.view(B, H, W)
    # Take the spatial max over the thresholds.
    thresholds = torch.amax(thresholds, dim=(1, 2), keepdim=True)
    return thresholds


def _process_masks_for_batch(batch_masks, batch_ious, batch_logits, return_highres_logits, mask_threshold):
    batch_data = amg_utils.MaskData(masks=batch_masks.flatten(0, 1), iou_preds=batch_ious.flatten(0, 1))
    batch_data["logits"] = batch_masks.clone() if return_highres_logits else batch_logits
    # TODO: probably the best heuristic: choose the lowest threshold that still yields stable masks.
    # To do this: go through the thresholds, starting from 0 in smallish increments, compute the stability score,
    # and select the threshold just before the stability score starts to drop.
    if mask_threshold == "auto":
        thresholds = _local_otsu_threshold(batch_logits)
        batch_data["stability_scores"] = amg_utils.calculate_stability_score(batch_data["masks"], thresholds, 1.0)
        batch_data["masks"] = (batch_data["masks"] > thresholds).type(torch.bool)
    else:
        batch_data["stability_scores"] = amg_utils.calculate_stability_score(batch_data["masks"], mask_threshold, 1.0)
        batch_data["masks"] = (batch_data["masks"] > mask_threshold).type(torch.bool)
    batch_data["boxes"] = batched_mask_to_box(batch_data["masks"])
    return batch_data


@torch.no_grad()
def batched_inference(
    predictor: SamPredictor,
    image: Optional[np.ndarray],
    batch_size: int,
    boxes: Optional[np.ndarray] = None,
    points: Optional[np.ndarray] = None,
    point_labels: Optional[np.ndarray] = None,
    multimasking: bool = False,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    return_instance_segmentation: bool = True,
    segmentation_ids: Optional[list] = None,
    reduce_multimasking: bool = True,
    logits_masks: Optional[torch.Tensor] = None,
    verbose_embeddings: bool = True,
    mask_threshold: Optional[Union[float, str]] = None,
    return_highres_logits: bool = False,
    i: Optional[int] = None,
) -> Union[List[List[Dict[str, Any]]], np.ndarray]:
    """Run batched inference for input prompts.

    Args:
        predictor: The Segment Anything predictor.
        image: The input image. If None, we assume that the image embeddings have already been computed.
        batch_size: The batch size to use for inference.
        boxes: The box prompts. Array of shape N_PROMPTS x 4.
            The bounding boxes are represented by [MIN_X, MIN_Y, MAX_X, MAX_Y].
        points: The point prompt coordinates. Array of shape N_PROMPTS x 1 x 2.
            The points are represented by their coordinates [X, Y], which are given in the last dimension.
        point_labels: The point prompt labels. Array of shape N_PROMPTS x 1.
            The labels are either 0 (negative prompt) or 1 (positive prompt).
        multimasking: Whether to predict with 3 or 1 mask. By default, set to 'False'.
        embedding_path: Cache path for the image embeddings. By default, computed on-the-fly.
        return_instance_segmentation: Whether to return an instance segmentation
            or the individual mask data. By default, set to 'True'.
        segmentation_ids: Fixed segmentation ids to assign to the masks
            derived from the prompts.
        reduce_multimasking: Whether to choose the most likely mask (the one with the
            highest predicted IoU) from multimasking. By default, set to 'True'.
        logits_masks: The logits masks from a previous segmentation, used as mask prompts.
            Array of shape N_PROMPTS x 1 x 256 x 256.
        verbose_embeddings: Whether to show progress outputs when computing image embeddings.
            By default, set to 'True'.
        mask_threshold: The threshold for binarizing masks based on the predicted values.
            If None, the default threshold 0 is used. If "auto" is passed then the threshold is
            determined with a local Otsu filter.
        return_highres_logits: Whether to return high-resolution logits.
        i: Index for the image data. Required if `image` has three spatial dimensions
            or a time dimension and two spatial dimensions.

    Returns:
        The predicted segmentation masks.
    """
    n_prompts, have_boxes, have_points, have_logits = _validate_inputs(
        boxes, points, point_labels, multimasking, return_instance_segmentation, segmentation_ids, logits_masks
    )

    # Compute the image embeddings.
    if image is None:  # This means the image embeddings are computed already.
        # Call get_image_embedding; this will throw an error if the embeddings have not yet been computed.
        predictor.get_image_embedding()
    else:
        image_embeddings = util.precompute_image_embeddings(
            predictor, image, embedding_path, verbose=verbose_embeddings, i=i,
        )
        util.set_precomputed(predictor, image_embeddings)

    # Determine the number of batches.
    n_batches = int(np.ceil(float(n_prompts) / batch_size))

    # Preprocess the prompts.
    device = predictor.device
    transform_function = ResizeLongestSide(1024)
    image_shape = predictor.original_size
    if have_boxes:
        boxes = transform_function.apply_boxes(boxes, image_shape)
        boxes = torch.tensor(boxes, dtype=torch.float32).to(device)
    if have_points:
        points = transform_function.apply_coords(points, image_shape)
        points = torch.tensor(points, dtype=torch.float32).to(device)
        point_labels = torch.tensor(point_labels, dtype=torch.float32).to(device)

    masks = amg_utils.MaskData()
    mask_threshold = predictor.model.mask_threshold if mask_threshold is None else mask_threshold
    for batch_idx in range(n_batches):
        batch_start = batch_idx * batch_size
        batch_stop = min((batch_idx + 1) * batch_size, n_prompts)

        batch_boxes = boxes[batch_start:batch_stop] if have_boxes else None
        batch_points = points[batch_start:batch_stop] if have_points else None
        batch_labels = point_labels[batch_start:batch_stop] if have_points else None
        batch_logits = logits_masks[batch_start:batch_stop] if have_logits else None

        batch_masks, batch_ious, batch_logits = predictor.predict_torch(
            point_coords=batch_points,
            point_labels=batch_labels,
            boxes=batch_boxes,
            mask_input=batch_logits,
            multimask_output=multimasking,
            return_logits=True,
        )

        # If multimasking is used and we want to reduce its output to a single mask per prompt,
        # then we select the most likely mask (according to the predicted IoU) here.
        if reduce_multimasking and multimasking:
            _, max_index = batch_ious.max(axis=1)
            # Note: we use 'j' as the loop variable to avoid shadowing the parameter 'i'.
            batch_masks = torch.cat([batch_masks[j, max_id][None] for j, max_id in enumerate(max_index)]).unsqueeze(1)
            batch_ious = torch.cat([batch_ious[j, max_id][None] for j, max_id in enumerate(max_index)]).unsqueeze(1)
            batch_logits = torch.cat([batch_logits[j, max_id][None] for j, max_id in enumerate(max_index)]).unsqueeze(1)

        batch_data = _process_masks_for_batch(
            batch_masks, batch_ious, batch_logits, return_highres_logits, mask_threshold
        )
        masks.cat(batch_data)

    # Mask data to records.
    masks = [
        {
            "segmentation": masks["masks"][idx],
            "area": masks["masks"][idx].sum(),
            "bbox": amg_utils.box_xyxy_to_xywh(masks["boxes"][idx]).tolist(),
            "predicted_iou": masks["iou_preds"][idx].item(),
            "stability_score": masks["stability_scores"][idx].item(),
            "seg_id": idx + 1 if segmentation_ids is None else int(segmentation_ids[idx]),
            "logits": masks["logits"][idx],
        }
        for idx in range(len(masks["masks"]))
    ]

    if return_instance_segmentation:
        masks = util.mask_data_to_segmentation(masks, min_object_size=0)
    return masks


def _require_tiled_embeddings(
    predictor, image, image_embeddings, embedding_path, tile_shape, halo, verbose_embeddings
):
    if image_embeddings is None:
        assert image is not None
        assert (tile_shape is not None) and (halo is not None)
        shape = image.shape
        image_embeddings = util.precompute_image_embeddings(
            predictor, image, embedding_path, ndim=2, tile_shape=tile_shape, halo=halo, verbose=verbose_embeddings
        )
    else:  # This means the image embeddings are computed already.
        attrs = image_embeddings["features"].attrs
        tile_shape_, halo_ = attrs["tile_shape"], attrs["halo"]
        shape = attrs["shape"]
        if tile_shape is None:
            tile_shape = tile_shape_
        elif any(ts != ts_ for ts, ts_ in zip(tile_shape, tile_shape_)):
            raise ValueError(f"Incompatible tile shapes: {tile_shape} != {tile_shape_}")
        if halo is None:
            halo = halo_
        elif any(h != h_ for h, h_ in zip(halo, halo_)):
            raise ValueError(f"Incompatible halos: {halo} != {halo_}")

    return image_embeddings, shape, tile_shape, halo


def _merge_segmentations(this_seg, prev_seg, overlap_threshold=0.75):
    # Discard new ids with too much overlap.
    ovlp = ngt.overlap(this_seg, prev_seg)
    ids = np.unique(this_seg)
    if ids[0] == 0:
        ids = ids[1:]
    discard_ids = []
    for seg_id in ids:
        ovlp_ids, ovlp_vals = ovlp.overlapArraysNormalized(seg_id, True)
        ovlp_vals = ovlp_vals[ovlp_ids != 0]
        if ovlp_vals.size > 0 and ovlp_vals[0] > overlap_threshold:
            discard_ids.append(seg_id)
    # Zero out the new ids that overlap too much with the previous segmentation.
    this_seg[np.isin(this_seg, discard_ids)] = 0

    # Make sure the previous segmentation is fully preserved.
    captured = prev_seg != 0
    this_seg[captured] = prev_seg[captured]
    return this_seg


# Note: we merge with a very simple first-come-first-serve strategy.
# This could be improved by merging according to IoUs,
# but that would add a lot of complexity.
def _stitch_segmentation(masks, tile_ids, tiling, halo, output_shape, verbose=False):
    assert len(masks) == len(tile_ids), f"{len(masks)}, {len(tile_ids)}"
    segmentation = np.zeros(output_shape, dtype="uint32")

    for tile_id, this_seg in tqdm(zip(tile_ids, masks), desc="Stitch tiles", disable=not verbose):
        tile = tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock
        bb = tuple(slice(begin, end) for begin, end in zip(tile.begin, tile.end))
        if tile_id == 0:
            segmentation[bb] = this_seg
        else:
            # Merge the segmentation, discarding ids with too much overlap.
            prev_seg = segmentation[bb]
            assert prev_seg.shape == this_seg.shape, f"{tile_id}: {prev_seg.shape}, {this_seg.shape}"
            this_seg = _merge_segmentations(this_seg, prev_seg)
            segmentation[bb] = this_seg

    return segmentation


@torch.no_grad()
def batched_tiled_inference(
    predictor: SamPredictor,
    image: Optional[np.ndarray],
    batch_size: int,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    boxes: Optional[np.ndarray] = None,
    points: Optional[np.ndarray] = None,
    point_labels: Optional[np.ndarray] = None,
    multimasking: bool = False,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    return_instance_segmentation: bool = True,
    reduce_multimasking: bool = True,
    logits_masks: Optional[torch.Tensor] = None,
    verbose_embeddings: bool = True,
    mask_threshold: Optional[Union[float, str]] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    optimize_memory: bool = False,
    i: Optional[int] = None,
    **nms_kwargs,
) -> Union[List[List[Dict[str, Any]]], np.ndarray]:
    """Run batched tiled inference for input prompts.

    Args:
        predictor: The Segment Anything predictor.
        image: The input image. If None, we assume that the image embeddings have already been computed.
        batch_size: The batch size to use for inference.
        image_embeddings: The precomputed tiled image embeddings.
            If None, the embeddings are computed from `image`.
        boxes: The box prompts. Array of shape N_PROMPTS x 4.
            The bounding boxes are represented by [MIN_X, MIN_Y, MAX_X, MAX_Y].
        points: The point prompt coordinates. Array of shape N_PROMPTS x 1 x 2.
            The points are represented by their coordinates [X, Y], which are given in the last dimension.
        point_labels: The point prompt labels. Array of shape N_PROMPTS x 1.
            The labels are either 0 (negative prompt) or 1 (positive prompt).
        multimasking: Whether to predict with 3 or 1 mask. By default, set to 'False'.
        embedding_path: Cache path for the image embeddings. By default, computed on-the-fly.
        return_instance_segmentation: Whether to return an instance segmentation
            or the individual mask data. By default, set to 'True'.
        reduce_multimasking: Whether to choose the most likely mask (the one with the
            highest predicted IoU) from multimasking. By default, set to 'True'.
        logits_masks: The logits masks from a previous segmentation, used as mask prompts.
            Array of shape N_PROMPTS x 1 x 256 x 256. Not yet supported for tiled inference.
        verbose_embeddings: Whether to show progress outputs when computing image embeddings.
            By default, set to 'True'.
        mask_threshold: The threshold for binarizing masks based on the predicted values.
            If None, the default threshold 0 is used. If "auto" is passed then the threshold is
            determined with a local Otsu filter.
        tile_shape: The tile shape for embedding prediction.
        halo: The overlap between tiles.
        optimize_memory: Whether to optimize the memory usage. If set to True:
            - NMS will be applied directly for each tile to reduce it to a per-tile instance segmentation.
            - The per-tile segmentations will be stitched.
            - The result will be returned as an instance segmentation.
        i: Index for the image data. Required if `image` has three spatial dimensions
            or a time dimension and two spatial dimensions.
        nms_kwargs: Keyword arguments for the NMS operation that is used for optimize_memory=True.
            Does not have any effect for optimize_memory=False.

    Returns:
        The predicted segmentation masks.
    """
    # Validate the inputs and get the input prompt summary.
    segmentation_ids = None
    n_prompts, have_boxes, have_points, have_logits = _validate_inputs(
        boxes, points, point_labels, multimasking, return_instance_segmentation, segmentation_ids, logits_masks
    )
    if have_logits:
        raise NotImplementedError

    # Get the tiling parameters and compute the embeddings if needed.
    image_embeddings, shape, tile_shape, halo = _require_tiled_embeddings(
        predictor, image, image_embeddings, embedding_path, tile_shape, halo, verbose_embeddings
    )

    # Order the prompts by tile and then iterate over the tiles.
    tiling = blocking([0, 0], shape, tile_shape)
    box_to_tile, point_to_tile, label_to_tile, logits_to_tile = {}, {}, {}, {}
    tile_ids = []

    # Box prompts are in the format N x 4. Boxes are stored in the order [MIN_X, MIN_Y, MAX_X, MAX_Y].
    # Points are in the format N x 1 x 2 with coordinate order X, Y.
    # Point labels are in the format N x 1.
    # Mask prompts are in the format N x 1 x 256 x 256.
    for prompt_id in range(n_prompts):
        this_tile_id = None

        if have_boxes:
            box = boxes[prompt_id]
            # The box center in [Y, X] order determines the tile this prompt is assigned to.
            center = np.array([(box[1] + box[3]) / 2, (box[0] + box[2]) / 2]).round().astype("int").tolist()
            this_tile_id = tiling.coordinatesToBlockId(center)
            tile = tiling.getBlockWithHalo(this_tile_id, list(halo)).outerBlock
            offset = tile.begin
            this_tile_shape = tile.shape
            box_in_tile = np.array(
                [
                    max(box[1] - offset[0], 0), max(box[0] - offset[1], 0),
                    min(box[3] - offset[0], this_tile_shape[0]), min(box[2] - offset[1], this_tile_shape[1])
                ]
            )[None]
            if this_tile_id in box_to_tile:
                box_to_tile[this_tile_id] = np.concatenate([box_to_tile[this_tile_id], box_in_tile])
            else:
                box_to_tile[this_tile_id] = box_in_tile

        if have_points:
            point = points[prompt_id, 0][::-1].round().astype("int").tolist()
            if this_tile_id is None:
                this_tile_id = tiling.coordinatesToBlockId(point)
            else:
                assert this_tile_id == tiling.coordinatesToBlockId(point)
            tile = tiling.getBlockWithHalo(this_tile_id, list(halo)).outerBlock
            offset = tile.begin
            point_in_tile = (points[prompt_id, 0] - np.array(offset)[::-1])[None, None]
            label_in_tile = point_labels[prompt_id][None]
            if this_tile_id in point_to_tile:
                point_to_tile[this_tile_id] = np.concatenate([point_to_tile[this_tile_id], point_in_tile])
                label_to_tile[this_tile_id] = np.concatenate([label_to_tile[this_tile_id], label_in_tile])
            else:
                point_to_tile[this_tile_id] = point_in_tile
                label_to_tile[this_tile_id] = label_in_tile

        # NOTE: logits are not yet supported.
        tile_ids.append(this_tile_id)

    # Find the tiles with prompts.
    tile_ids = sorted(set(tile_ids))

    # Run batched inference for each tile.
    masks = []
    # Additional variable needed for the optimized memory mode.
    id_offset = 0
    for tile_id in tqdm(tile_ids, desc="Run batched inference"):
        # Get the prompts for this tile.
        tile_boxes = box_to_tile.get(tile_id)
        tile_logits = logits_to_tile.get(tile_id)
        tile_points, tile_labels = point_to_tile.get(tile_id), label_to_tile.get(tile_id)

        # Set the correct embeddings, run inference.
        predictor = util.set_precomputed(predictor, image_embeddings, tile_id=tile_id, i=i)
        this_masks = batched_inference(
            predictor=predictor,
            image=None,
            batch_size=batch_size,
            boxes=tile_boxes,
            points=tile_points,
            point_labels=tile_labels,
            multimasking=multimasking,
            return_instance_segmentation=False,
            segmentation_ids=segmentation_ids,
            reduce_multimasking=reduce_multimasking,
            logits_masks=tile_logits,
            mask_threshold=mask_threshold,
        )

        if optimize_memory:
            # Apply NMS directly to get a segmentation.
            segmentation = util.apply_nms(this_masks, **nms_kwargs)
            fg_mask = segmentation != 0
            segmentation[fg_mask] += id_offset
            id_offset = segmentation.max()
            masks.append(segmentation)
        else:
            # Add the offset for the current tile to the bounding box.
            tile = tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock
            offset = np.array(tile.begin[::-1] + [0, 0])
            this_masks = [
                {**mask, "global_bbox": (np.array(mask["bbox"]) + offset).tolist()} for mask in this_masks
            ]
            masks.extend(this_masks)

        # Try to keep the memory clean.
        del this_masks
        gc.collect()

    if optimize_memory:
        return _stitch_segmentation(masks, tile_ids, tiling, halo, output_shape=shape)

    if return_instance_segmentation:
        masks = util.mask_data_to_segmentation(masks, shape=shape, min_object_size=0)
    return masks
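The "auto" mask threshold is computed by the private helper _local_otsu_threshold: it histograms the logits in sliding windows, runs Otsu's method per window in a fully vectorized fashion, and takes the non-negative spatial maximum as the per-mask threshold. A minimal sketch of exercising it in isolation; the helper is private, so this import is for illustration only and the synthetic logits are made up:

import torch
from micro_sam.inference import _local_otsu_threshold

# Synthetic logits for two masks: background around -8, foreground around +6.
# (Real SAM logits are 256 x 256; we use 64 x 64 here to keep the demo light.)
logits = torch.full((2, 1, 64, 64), -8.0)
logits[0, 0, 16:32, 16:32] = 6.0
logits[1, 0, 8:50, 8:25] = 6.0

# One threshold per mask, shape (2, 1, 1), clamped to >= 0.
thresholds = _local_otsu_threshold(logits, window_size=31, num_bins=64)

# Binarize the logits with the per-mask thresholds, as done in _process_masks_for_batch.
binary = logits > thresholds.view(-1, 1, 1, 1)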
@torch.no_grad()
def batched_inference(
    predictor: segment_anything.predictor.SamPredictor,
    image: Optional[numpy.ndarray],
    batch_size: int,
    boxes: Optional[numpy.ndarray] = None,
    points: Optional[numpy.ndarray] = None,
    point_labels: Optional[numpy.ndarray] = None,
    multimasking: bool = False,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    return_instance_segmentation: bool = True,
    segmentation_ids: Optional[list] = None,
    reduce_multimasking: bool = True,
    logits_masks: Optional[torch.Tensor] = None,
    verbose_embeddings: bool = True,
    mask_threshold: Optional[Union[float, str]] = None,
    return_highres_logits: bool = False,
    i: Optional[int] = None,
) -> Union[List[List[Dict[str, Any]]], numpy.ndarray]:
Run batched inference for input prompts.
Arguments:
- predictor: The Segment Anything predictor.
- image: The input image. If None, we assume that the image embeddings have already been computed.
- batch_size: The batch size to use for inference.
- boxes: The box prompts. Array of shape N_PROMPTS x 4. The bounding boxes are represented by [MIN_X, MIN_Y, MAX_X, MAX_Y].
- points: The point prompt coordinates. Array of shape N_PROMPTS x 1 x 2. The points are represented by their coordinates [X, Y], which are given in the last dimension.
- point_labels: The point prompt labels. Array of shape N_PROMPTS x 1. The labels are either 0 (negative prompt) or 1 (positive prompt).
- multimasking: Whether to predict with 3 or 1 mask. By default, set to 'False'.
- embedding_path: Cache path for the image embeddings. By default, computed on-the-fly.
- return_instance_segmentation: Whether to return an instance segmentation or the individual mask data. By default, set to 'True'.
- segmentation_ids: Fixed segmentation ids to assign to the masks derived from the prompts.
- reduce_multimasking: Whether to choose the most likely mask (the one with the highest predicted IoU) from multimasking. By default, set to 'True'.
- logits_masks: The logits masks from a previous segmentation, used as mask prompts. Array of shape N_PROMPTS x 1 x 256 x 256.
- verbose_embeddings: Whether to show progress outputs of computing image embeddings. By default, set to 'True'.
- mask_threshold: The threshold for binarizing masks based on the predicted values. If None, the default threshold 0 is used. If "auto" is passed then the threshold is determined with a local Otsu filter.
- return_highres_logits: Whether to return high-resolution logits.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
Returns:
The predicted segmentation masks.
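A minimal usage sketch for batched_inference. The model loading via micro_sam.util.get_sam_model and the prompt values are illustrative assumptions, not fixed requirements:

import numpy as np
from micro_sam.util import get_sam_model
from micro_sam.inference import batched_inference

predictor = get_sam_model(model_type="vit_b")

# A 2D image; in practice this would be your microscopy data.
image = np.random.randint(0, 256, size=(512, 512), dtype="uint8")

# Two box prompts in [MIN_X, MIN_Y, MAX_X, MAX_Y] order ...
boxes = np.array([
    [50, 60, 120, 140],
    [200, 220, 300, 330],
])
# ... and one positive point prompt per object (shape N_PROMPTS x 1 x 2, [X, Y] order).
points = np.array([[[85, 100]], [[250, 275]]], dtype="float64")
point_labels = np.ones((2, 1), dtype="int64")

# With return_instance_segmentation=True (the default) this returns a 2D label image
# in which the two prompted objects have the ids 1 and 2.
segmentation = batched_inference(
    predictor, image, batch_size=2,
    boxes=boxes, points=points, point_labels=point_labels,
)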
@torch.no_grad()
def batched_tiled_inference(
    predictor: segment_anything.predictor.SamPredictor,
    image: Optional[numpy.ndarray],
    batch_size: int,
    image_embeddings: Optional[Dict[str, Any]] = None,
    boxes: Optional[numpy.ndarray] = None,
    points: Optional[numpy.ndarray] = None,
    point_labels: Optional[numpy.ndarray] = None,
    multimasking: bool = False,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    return_instance_segmentation: bool = True,
    reduce_multimasking: bool = True,
    logits_masks: Optional[torch.Tensor] = None,
    verbose_embeddings: bool = True,
    mask_threshold: Optional[Union[float, str]] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    optimize_memory: bool = False,
    i: Optional[int] = None,
    **nms_kwargs,
) -> Union[List[List[Dict[str, Any]]], numpy.ndarray]:
Run batched tiled inference for input prompts.
Arguments:
- predictor: The Segment Anything predictor.
- image: The input image. If None, we assume that the image embeddings have already been computed.
- batch_size: The batch size to use for inference.
- image_embeddings: The precomputed tiled image embeddings. If None, the embeddings are computed from `image`.
- boxes: The box prompts. Array of shape N_PROMPTS x 4. The bounding boxes are represented by [MIN_X, MIN_Y, MAX_X, MAX_Y].
- points: The point prompt coordinates. Array of shape N_PROMPTS x 1 x 2. The points are represented by their coordinates [X, Y], which are given in the last dimension.
- point_labels: The point prompt labels. Array of shape N_PROMPTS x 1. The labels are either 0 (negative prompt) or 1 (positive prompt).
- multimasking: Whether to predict with 3 or 1 mask. By default, set to 'False'.
- embedding_path: Cache path for the image embeddings. By default, computed on-the-fly.
- return_instance_segmentation: Whether to return an instance segmentation or the individual mask data. By default, set to 'True'.
- reduce_multimasking: Whether to choose the most likely mask (the one with the highest predicted IoU) from multimasking. By default, set to 'True'.
- logits_masks: The logits masks from a previous segmentation, used as mask prompts. Array of shape N_PROMPTS x 1 x 256 x 256. Not yet supported for tiled inference.
- verbose_embeddings: Whether to show progress outputs of computing image embeddings. By default, set to 'True'.
- mask_threshold: The threshold for binarizing masks based on the predicted values. If None, the default threshold 0 is used. If "auto" is passed then the threshold is determined with a local Otsu filter.
- tile_shape: The tile shape for embedding prediction.
- halo: The overlap between tiles.
- optimize_memory: Whether to optimize the memory usage. If set to True:
- NMS will be applied directly for each tile to reduce it to a per-tile instance segmentation.
- The per-tile segmentations will be stitched.
- The result will be returned as an instance segmentation.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- nms_kwargs: Keyword arguments for the NMS operation that is used for optimize_memory=True. Does not have any effect for optimize_memory=False.
Returns:
The predicted segmentation masks.
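A minimal usage sketch for batched_tiled_inference. As above, the model loading and the prompt values are illustrative assumptions; note that each prompt is assigned to a tile via its box or point center:

import numpy as np
from micro_sam.util import get_sam_model
from micro_sam.inference import batched_tiled_inference

predictor = get_sam_model(model_type="vit_b")

# A large image, for which the embeddings are computed tile by tile.
image = np.random.randint(0, 256, size=(2048, 2048), dtype="uint8")

# Box prompts in [MIN_X, MIN_Y, MAX_X, MAX_Y] order, one per object.
boxes = np.array([
    [100, 150, 260, 300],
    [1400, 1500, 1600, 1700],
])

# Returns an instance segmentation of shape (2048, 2048).
# Pass embedding_path to cache the tiled embeddings, or set
# optimize_memory=True to reduce peak memory via per-tile NMS and stitching.
segmentation = batched_tiled_inference(
    predictor, image, batch_size=4,
    boxes=boxes,
    tile_shape=(1024, 1024), halo=(256, 256),
)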