micro_sam.inference
import os
import gc
from typing import Any, Dict, List, Optional, Union, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from bioimage_cpp.utils import Blocking, segmentation_overlap

import segment_anything.utils.amg as amg_utils
from segment_anything import SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide
try:
    from napari.utils import progress as tqdm
except ImportError:
    from tqdm import tqdm

from . import util
from ._vendored import batched_mask_to_box


def _validate_inputs(
    boxes, points, point_labels, multimasking, return_instance_segmentation, segmentation_ids, logits_masks
):
    """Check that the prompt inputs are consistent and summarize them.

    Args:
        boxes: The box prompts or None.
        points: The point prompt coordinates or None.
        point_labels: The point prompt labels or None.
        multimasking: Whether multimasking is requested.
        return_instance_segmentation: Whether an instance segmentation is requested.
        segmentation_ids: Fixed segmentation ids or None.
        logits_masks: The logits mask prompts or None.

    Returns:
        The number of prompts, and three flags that indicate whether box prompts,
        point prompts and logits masks were passed.

    Raises:
        NotImplementedError: If segmentation ids are combined with multimasking
            while individual mask data (no instance segmentation) is requested.
        ValueError: If no point or box prompts are given, if only one of `points` and `point_labels`
            is given, or if the number of entries of the different inputs does not match.
    """
    if multimasking and (segmentation_ids is not None) and (not return_instance_segmentation):
        raise NotImplementedError(
            "Passing `segmentation_ids` together with `multimasking=True` and "
            "`return_instance_segmentation=False` is not supported."
        )

    if (points is None) != (point_labels is None):
        raise ValueError(
            "If you have point prompts both `points` and `point_labels` have to be passed, "
            "but you passed only one of them."
        )

    have_points = points is not None
    have_boxes = boxes is not None
    have_logits = logits_masks is not None
    if (not have_points) and (not have_boxes):
        raise ValueError("Point and/or box prompts have to be passed, you passed neither.")

    if have_points and (len(point_labels) != len(points)):
        raise ValueError(
            f"The number of point coordinates and labels does not match: {len(point_labels)} != {len(points)}"
        )

    if (have_points and have_boxes) and (len(points) != len(boxes)):
        raise ValueError(
            f"The number of point and box prompts does not match: {len(points)} != {len(boxes)}"
        )

    if have_logits:
        if have_points and (len(logits_masks) != len(point_labels)):
            # Report both lengths (the previous message interpolated the boolean comparison result).
            raise ValueError(
                f"The number of point and logits does not match: {len(points)} != {len(logits_masks)}"
            )
        elif have_boxes and (len(logits_masks) != len(boxes)):
            raise ValueError(
                f"The number of boxes and logits does not match: {len(boxes)} != {len(logits_masks)}"
            )

    n_prompts = boxes.shape[0] if have_boxes else points.shape[0]

    if (segmentation_ids is not None) and (len(segmentation_ids) != n_prompts):
        raise ValueError(
            f"The number of segmentation ids and prompts does not match: {len(segmentation_ids)} != {n_prompts}"
        )

    return n_prompts, have_boxes, have_points, have_logits


def _local_otsu_threshold(images: torch.Tensor, window_size: int = 31, num_bins: int = 64, eps: float = 1e-6):
    """Compute one threshold per image from Otsu thresholds of local windows.

    Each image is normalized to [0, 1], a histogram is built for the window around every pixel,
    the Otsu criterion (between-class variance) is maximized per window, and the largest of the
    resulting thresholds (clamped to be non-negative) is returned for each image.

    Args:
        images: The input tensor with four dimensions (B, C, H, W).
        window_size: The side length of the local window. It is expected to be odd, so that the
            number of windows equals H * W; otherwise the final reshape to (B, H, W) fails.
        num_bins: The number of histogram bins.
        eps: Small constant to avoid division by zero.

    Returns:
        The thresholds, a float32 tensor of shape (B, 1, 1).
    """
    x = images
    B, _, H, W = x.shape
    device = x.device

    # Work in float32 for stability even if input is fp16
    x = x.to(torch.float32)

    # per-image min/max for normalization to [0, 1]
    # (reshape instead of view, so that non-contiguous inputs are supported as well)
    x_flat = x.reshape(B, -1)
    x_min = x_flat.min(dim=1).values.view(B, 1, 1, 1)
    x_max = x_flat.max(dim=1).values.view(B, 1, 1, 1)
    x_range = (x_max - x_min).clamp_min(eps)

    x_norm = (x - x_min) / x_range  # (B,1,H,W), in [0,1]

    # extract local patches via unfold
    pad = window_size // 2
    patches = F.unfold(x_norm, kernel_size=window_size, padding=pad)  # (B, P, L)
    # P = window_size * window_size, L = H * W
    L = patches.shape[-1]

    # quantize to bins
    bin_idx = (patches * (num_bins - 1)).long().clamp(0, num_bins - 1)  # (B, P, L)

    # build histograms per patch
    # one_hot: (B, L, num_bins)
    one_hot = torch.zeros(B, L, num_bins, device=device, dtype=torch.float32)
    idx = bin_idx.transpose(1, 2)  # (B, L, P)
    src = torch.ones_like(idx, dtype=one_hot.dtype)  # (B, L, P)
    one_hot.scatter_add_(2, idx, src)
    # hist: (B, num_bins, L)
    hist = one_hot.permute(0, 2, 1)

    # Otsu per patch (vectorized)
    # p: (B, bins, L)
    p = hist / hist.sum(dim=1, keepdim=True).clamp_min(eps)

    bins = torch.arange(num_bins, device=device, dtype=torch.float32).view(1, num_bins, 1)

    omega1 = torch.cumsum(p, dim=1)  # (B, bins, L)
    mu = torch.cumsum(p * bins, dim=1)  # (B, bins, L)
    mu_T = mu[:, -1:, :]  # (B, 1, L)

    omega2 = 1.0 - omega1

    mu1 = mu / omega1.clamp_min(eps)
    mu2 = (mu_T - mu) / omega2.clamp_min(eps)

    sigma_b2 = omega1 * omega2 * (mu1 - mu2) ** 2  # (B, bins, L)

    # argmax over bins gives local threshold bin per patch
    t_bin = torch.argmax(sigma_b2, dim=1)  # (B, L)
    t_norm = t_bin.to(torch.float32) / (num_bins - 1)  # normalized [0,1]

    # map thresholds back to original intensity scale (per-image)
    # x_min, x_range: (B,1,1,1) -> flatten batch dims
    thr_vals = x_min.view(B, 1) + t_norm * x_range.view(B, 1)  # (B, L)
    # clamp to >= 0 because foreground is positive
    thr_vals = thr_vals.clamp_min(0.0)

    thresholds = thr_vals.view(B, H, W)
    # Take the spatial max over the thresholds.
    thresholds = torch.amax(thresholds, dim=(1, 2), keepdims=True)
    return thresholds


def _process_masks_for_batch(batch_masks, batch_ious, batch_logits, return_highres_logits, mask_threshold):
    """Binarize the predicted masks of one batch and derive stability scores and bounding boxes.

    Args:
        batch_masks: The mask predictions (not yet binarized) for the batch.
        batch_ious: The predicted IOUs for the batch.
        batch_logits: The low-resolution logits for the batch.
        return_highres_logits: Whether to store a copy of `batch_masks` as logits instead of `batch_logits`.
        mask_threshold: The threshold for binarization. If "auto" is passed then the threshold
            is derived from `batch_logits` with a local otsu filter.

    Returns:
        The mask data with the entries "masks", "iou_preds", "logits", "stability_scores" and "boxes".
    """
    batch_data = amg_utils.MaskData(masks=batch_masks.flatten(0, 1), iou_preds=batch_ious.flatten(0, 1))
    batch_data["logits"] = batch_masks.clone() if return_highres_logits else batch_logits
    # TODO: probably the best heuristic: choose the lowest threshold which still has stable masks.
    # To do this: go through the thresholds, starting from 0 in smallish increments, compute stability score,
    # select the threshold before the stability score starts to drop.
    if mask_threshold == "auto":
        # One threshold per batch entry, shape (B, 1, 1), which is broadcast against the masks.
        thresholds = _local_otsu_threshold(batch_logits)
        batch_data["stability_scores"] = amg_utils.calculate_stability_score(batch_data["masks"], thresholds, 1.0)
        batch_data["masks"] = (batch_data["masks"] > thresholds).type(torch.bool)
    else:
        batch_data["stability_scores"] = amg_utils.calculate_stability_score(batch_data["masks"], mask_threshold, 1.0)
        batch_data["masks"] = (batch_data["masks"] > mask_threshold).type(torch.bool)
    batch_data["boxes"] = batched_mask_to_box(batch_data["masks"])
    return batch_data


@torch.no_grad()
def batched_inference(
    predictor: SamPredictor,
    image: Optional[np.ndarray],
    batch_size: int,
    boxes: Optional[np.ndarray] = None,
    points: Optional[np.ndarray] = None,
    point_labels: Optional[np.ndarray] = None,
    multimasking: bool = False,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    return_instance_segmentation: bool = True,
    segmentation_ids: Optional[list] = None,
    reduce_multimasking: bool = True,
    logits_masks: Optional[torch.Tensor] = None,
    verbose_embeddings: bool = True,
    mask_threshold: Optional[Union[float, str]] = None,
    return_highres_logits: bool = False,
    i: Optional[int] = None,
) -> Union[List[List[Dict[str, Any]]], np.ndarray]:
    """Run batched inference for input prompts.

    Args:
        predictor: The Segment Anything predictor.
        image: The input image. If None, we assume that the image embeddings have already been computed.
        batch_size: The batch size to use for inference.
        boxes: The box prompts. Array of shape N_PROMPTS x 4.
            The bounding boxes are represented by [MIN_X, MIN_Y, MAX_X, MAX_Y].
        points: The point prompt coordinates. Array of shape N_PROMPTS x 1 x 2.
            The points are represented by their coordinates [X, Y], which are given in the last dimension.
        point_labels: The point prompt labels. Array of shape N_PROMPTS x 1.
            The labels are either 0 (negative prompt) or 1 (positive prompt).
        multimasking: Whether to predict with 3 or 1 mask. By default, set to 'False'.
        embedding_path: Cache path for the image embeddings. By default, computed on-the-fly.
        return_instance_segmentation: Whether to return an instance segmentation
            or the individual mask data. By default, set to 'True'.
        segmentation_ids: Fixed segmentation ids to assign to the masks
            derived from the prompts.
        reduce_multimasking: Whether to choose the most likely masks with
            highest ious from multimasking. By default, set to 'True'.
        logits_masks: The logits masks from a previous segmentation, which are passed
            as mask prompts. Array of shape N_PROMPTS x 1 x 256 x 256.
        verbose_embeddings: Whether to show progress outputs of computing image embeddings.
            By default, set to 'True'.
        mask_threshold: The threshold for binarizing masks based on the predicted values.
            If None, the default threshold of the predictor's model is used.
            If "auto" is passed then the threshold is determined with a local otsu filter.
        return_highres_logits: Whether to return high-resolution logits.
        i: Index for the image data. Required if `image` has three spatial dimensions
            or a time dimension and two spatial dimensions.

    Returns:
        The predicted segmentation masks.
    """
    n_prompts, have_boxes, have_points, have_logits = _validate_inputs(
        boxes, points, point_labels, multimasking, return_instance_segmentation, segmentation_ids, logits_masks
    )

    # Compute the image embeddings.
    if image is None:  # This means the image embeddings are computed already.
        # Call get image embeddings, this will throw an error if they have not yet been computed.
        predictor.get_image_embedding()
    else:
        input_ = image if i is None else image[i]
        image_embeddings = util.precompute_image_embeddings(
            predictor, input_, embedding_path, verbose=verbose_embeddings
        )
        util.set_precomputed(predictor, image_embeddings)

    # Determine the number of batches.
    n_batches = int(np.ceil(float(n_prompts) / batch_size))

    # Preprocess the prompts.
    device = predictor.device
    transform_function = ResizeLongestSide(1024)
    image_shape = predictor.original_size
    if have_boxes:
        boxes = transform_function.apply_boxes(boxes, image_shape)
        boxes = torch.tensor(boxes, dtype=torch.float32).to(device)
    if have_points:
        points = transform_function.apply_coords(points, image_shape)
        points = torch.tensor(points, dtype=torch.float32).to(device)
        point_labels = torch.tensor(point_labels, dtype=torch.float32).to(device)

    masks = amg_utils.MaskData()
    mask_threshold = predictor.model.mask_threshold if mask_threshold is None else mask_threshold
    for batch_idx in range(n_batches):
        batch_start = batch_idx * batch_size
        batch_stop = min((batch_idx + 1) * batch_size, n_prompts)

        batch_boxes = boxes[batch_start:batch_stop] if have_boxes else None
        batch_points = points[batch_start:batch_stop] if have_points else None
        batch_labels = point_labels[batch_start:batch_stop] if have_points else None
        batch_mask_input = logits_masks[batch_start:batch_stop] if have_logits else None

        batch_masks, batch_ious, batch_logits = predictor.predict_torch(
            point_coords=batch_points,
            point_labels=batch_labels,
            boxes=batch_boxes,
            mask_input=batch_mask_input,
            multimask_output=multimasking,
            return_logits=True,
        )

        # If we expect to reduce the masks from multimasking and use multi-masking,
        # then we need to select the most likely mask (according to the predicted IOU) here.
        if reduce_multimasking and multimasking:
            _, max_index = batch_ious.max(dim=1)
            # Select the best mask per prompt with tensor indexing and keep a singleton mask axis.
            rows = torch.arange(max_index.shape[0], device=max_index.device)
            batch_masks = batch_masks[rows, max_index].unsqueeze(1)
            batch_ious = batch_ious[rows, max_index].unsqueeze(1)
            batch_logits = batch_logits[rows, max_index].unsqueeze(1)

        batch_data = _process_masks_for_batch(
            batch_masks, batch_ious, batch_logits, return_highres_logits, mask_threshold
        )
        masks.cat(batch_data)

    # Mask data to records.
    masks = [
        {
            "segmentation": masks["masks"][idx],
            "area": masks["masks"][idx].sum(),
            "bbox": amg_utils.box_xyxy_to_xywh(masks["boxes"][idx]).tolist(),
            "predicted_iou": masks["iou_preds"][idx].item(),
            "stability_score": masks["stability_scores"][idx].item(),
            "seg_id": idx + 1 if segmentation_ids is None else int(segmentation_ids[idx]),
            "logits": masks["logits"][idx]
        }
        for idx in range(len(masks["masks"]))
    ]

    if return_instance_segmentation:
        masks = util.mask_data_to_segmentation(masks, min_object_size=0)
    return masks


def _require_tiled_embeddings(
    predictor, image, image_embeddings, embedding_path, tile_shape, halo, verbose_embeddings
):
    """Return tiled image embeddings together with the image shape and the tiling parameters.

    If `image_embeddings` is None the embeddings are computed for `image`, which requires that
    `tile_shape` and `halo` are given. Otherwise the tiling parameters are read from the attributes
    of the precomputed embeddings and checked against the passed values.

    Returns:
        The image embeddings, the image shape, the tile shape and the halo.

    Raises:
        ValueError: If the passed tile shape or halo does not match the precomputed embeddings.
    """
    if image_embeddings is None:
        assert image is not None
        assert (tile_shape is not None) and (halo is not None)
        shape = image.shape[:2]
        image_embeddings = util.precompute_image_embeddings(
            predictor, image, embedding_path, ndim=2, tile_shape=tile_shape, halo=halo, verbose=verbose_embeddings
        )
    else:  # This means the image embeddings are computed already.
        attrs = image_embeddings["features"].attrs
        tile_shape_, halo_ = attrs["tile_shape"], attrs["halo"]
        shape = attrs["shape"]
        if tile_shape is None:
            tile_shape = tile_shape_
        elif any(ts != ts_ for ts, ts_ in zip(tile_shape, tile_shape_)):
            raise ValueError(f"Incompatible tile shapes: {tile_shape} != {tile_shape_}")
        if halo is None:
            halo = halo_
        elif any(h != h_ for h, h_ in zip(halo, halo_)):
            raise ValueError(f"Incompatible halos: {halo} != {halo_}")

    return image_embeddings, shape, tile_shape, halo


def _merge_segmentations(this_seg, prev_seg, overlap_threshold=0.75):
    """Merge the segmentation of a new tile into the already stitched segmentation.

    Objects in `this_seg` whose largest overlap fraction with an object in `prev_seg` exceeds
    `overlap_threshold` are discarded. All objects from `prev_seg` are preserved.
    Note that `this_seg` is modified in-place.

    Args:
        this_seg: The segmentation of the new tile.
        prev_seg: The previous segmentation in the same region.
        overlap_threshold: The overlap fraction above which a new object is discarded.

    Returns:
        The merged segmentation.
    """
    # Discard new ids with too much overlap.
    ovlp = segmentation_overlap(this_seg, prev_seg)
    ids = np.unique(this_seg)
    ids = ids[ids != 0]
    discard_ids = []
    for seg_id in ids:
        ovlp_table = ovlp.overlaps_for_label_a(seg_id, normalize=True)
        ovlp_ids, ovlp_vals = ovlp_table["label"], ovlp_table["fraction"]
        ovlp_vals = ovlp_vals[ovlp_ids != 0]
        # Use the maximum, so that we do not rely on the overlap table being sorted.
        if ovlp_vals.size > 0 and ovlp_vals.max() > overlap_threshold:
            discard_ids.append(seg_id)
    # Actually remove the discarded ids (they were previously collected but never applied).
    if discard_ids:
        this_seg[np.isin(this_seg, discard_ids)] = 0

    # Make sure the previous segmentation is fully preserved.
    captured = prev_seg != 0
    this_seg[captured] = prev_seg[captured]
    return this_seg


# Note: we merge with a very simple first come first serve strategy.
# This could also be improved by merging according to IoUs.
# (But would add a lot of complexity)
def _stitch_segmentation(masks, tile_ids, tiling, halo, output_shape, verbose=False):
    """Stitch per-tile segmentations into a single segmentation.

    Args:
        masks: The segmentations for the individual tiles.
        tile_ids: The tile ids that correspond to the entries of `masks`.
        tiling: The blocking that defines the tiles.
        halo: The halo that was used for the tiles.
        output_shape: The shape of the stitched segmentation.
        verbose: Whether to show a progress bar.

    Returns:
        The stitched segmentation with data type uint32.
    """
    assert len(masks) == len(tile_ids), f"{len(masks)}, {len(tile_ids)}"
    segmentation = np.zeros(output_shape, dtype="uint32")

    for tile_id, this_seg in tqdm(
        zip(tile_ids, masks), desc="Stitch tiles", total=len(tile_ids), disable=not verbose
    ):
        # Use the same (snake_case) blocking API as the rest of this file.
        tile = tiling.get_block_with_halo(tile_id, list(halo)).outer_block
        bb = tuple(slice(begin, end) for begin, end in zip(tile.begin, tile.end))
        if tile_id == 0:
            segmentation[bb] = this_seg
        else:
            # Merge the segmentation, discarding ids with too much overlap.
            prev_seg = segmentation[bb]
            assert prev_seg.shape == this_seg.shape, f"{tile_id}: {prev_seg.shape}, {this_seg.shape}"
            this_seg = _merge_segmentations(this_seg, prev_seg)
            segmentation[bb] = this_seg

    return segmentation


@torch.no_grad()
def batched_tiled_inference(
    predictor: SamPredictor,
    image: Optional[np.ndarray],
    batch_size: int,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    boxes: Optional[np.ndarray] = None,
    points: Optional[np.ndarray] = None,
    point_labels: Optional[np.ndarray] = None,
    multimasking: bool = False,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    return_instance_segmentation: bool = True,
    reduce_multimasking: bool = True,
    logits_masks: Optional[torch.Tensor] = None,
    verbose_embeddings: bool = True,
    mask_threshold: Optional[Union[float, str]] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    optimize_memory: bool = False,
    i: Optional[int] = None,
    **nms_kwargs,
) -> Union[List[List[Dict[str, Any]]], np.ndarray]:
    """Run batched inference for input prompts with tiled image embeddings.

    Args:
        predictor: The Segment Anything predictor.
        image: The input image. If None, we assume that the image embeddings have already been computed.
        batch_size: The batch size to use for inference.
        image_embeddings: The precomputed tiled image embeddings.
            If None, the embeddings are computed from `image`.
        boxes: The box prompts. Array of shape N_PROMPTS x 4.
            The bounding boxes are represented by [MIN_X, MIN_Y, MAX_X, MAX_Y].
        points: The point prompt coordinates. Array of shape N_PROMPTS x 1 x 2.
            The points are represented by their coordinates [X, Y], which are given in the last dimension.
        point_labels: The point prompt labels. Array of shape N_PROMPTS x 1.
            The labels are either 0 (negative prompt) or 1 (positive prompt).
        multimasking: Whether to predict with 3 or 1 mask. By default, set to 'False'.
        embedding_path: Cache path for the image embeddings. By default, computed on-the-fly.
        return_instance_segmentation: Whether to return an instance segmentation
            or the individual mask data. By default, set to 'True'.
        reduce_multimasking: Whether to choose the most likely masks with
            highest ious from multimasking. By default, set to 'True'.
        logits_masks: The logits masks. Array of shape N_PROMPTS x 1 x 256 x 256.
            Logits masks are not yet supported for tiled inference.
        verbose_embeddings: Whether to show progress outputs of computing image embeddings.
            By default, set to 'True'.
        mask_threshold: The threshold for binarizing masks based on the predicted values.
            If None, the default threshold of the predictor's model is used.
            If "auto" is passed then the threshold is determined with a local otsu filter.
        tile_shape: The tile shape for embedding prediction.
        halo: The overlap between tiles.
        optimize_memory: Whether to optimize the memory usage. If set to True:
            - NMS will be applied directly for each tile to reduce it to a per-tile instance segmentation.
            - The per-tile segmentations will be stitched.
            - The result will be returned as an instance segmentation.
        i: Index for the image data. Required if `image` has three spatial dimensions
            or a time dimension and two spatial dimensions.
        nms_kwargs: Keyword arguments for the NMS operations that is used for optimize_memory=True.
            Does not have any effect for optimize_memory=False.

    Returns:
        The predicted segmentation masks.

    Raises:
        NotImplementedError: If `logits_masks` are passed.
    """
    # Validate inputs and get input prompt summary.
    # Fixed segmentation ids are not exposed by this function.
    segmentation_ids = None
    n_prompts, have_boxes, have_points, have_logits = _validate_inputs(
        boxes, points, point_labels, multimasking, return_instance_segmentation, segmentation_ids, logits_masks
    )
    if have_logits:
        raise NotImplementedError

    # Get the tiling parameters and compute embeddings if needed.
    image_embeddings, shape, tile_shape, halo = _require_tiled_embeddings(
        predictor, image, image_embeddings, embedding_path, tile_shape, halo, verbose_embeddings
    )

    # Order the prompts by tile and then iterate over the tiles.
    tiling = Blocking([0, 0], shape, tile_shape)
    box_to_tile, point_to_tile, label_to_tile, logits_to_tile = {}, {}, {}, {}
    tile_ids = []

    # Box prompts are in the format N x 4. Boxes are stored in order [MIN_X, MIN_Y, MAX_X, MAX_Y].
    # Points are in the Format N x 1 x 2 with coordinate order X, Y.
    # Point labels are in the format N x 1.
    # Mask prompts are in the format N x 1 x 256 x 256.
    for prompt_id in range(n_prompts):
        this_tile_id = None

        if have_boxes:
            box = boxes[prompt_id]
            # The box center in (Y, X) order determines the tile this prompt is assigned to.
            center = np.array([(box[1] + box[3]) / 2, (box[0] + box[2]) / 2]).round().astype("int").tolist()
            this_tile_id = tiling.coordinates_to_block_id(center)
            tile = tiling.get_block_with_halo(this_tile_id, list(halo)).outer_block
            offset = tile.begin
            this_tile_shape = tile.shape
            # Shift the box to tile coordinates and clip it to the tile.
            # NOTE(review): the box is assembled as [Y0, X0, Y1, X1] here, while `batched_inference`
            # documents [MIN_X, MIN_Y, MAX_X, MAX_Y] - confirm the intended coordinate order.
            box_in_tile = np.array(
                [
                    max(box[1] - offset[0], 0), max(box[0] - offset[1], 0),
                    min(box[3] - offset[0], this_tile_shape[0]), min(box[2] - offset[1], this_tile_shape[1])
                ]
            )[None]
            if this_tile_id in box_to_tile:
                box_to_tile[this_tile_id] = np.concatenate([box_to_tile[this_tile_id], box_in_tile])
            else:
                box_to_tile[this_tile_id] = box_in_tile

        if have_points:
            point = points[prompt_id, 0][::-1].round().astype("int").tolist()
            if this_tile_id is None:
                this_tile_id = tiling.coordinates_to_block_id(point)
            else:
                # Box and point of the same prompt have to fall into the same tile.
                assert this_tile_id == tiling.coordinates_to_block_id(point)
            tile = tiling.get_block_with_halo(this_tile_id, list(halo)).outer_block
            offset = tile.begin
            point_in_tile = (points[prompt_id, 0] - np.array(offset)[::-1])[None, None]
            label_in_tile = point_labels[prompt_id][None]
            if this_tile_id in point_to_tile:
                point_to_tile[this_tile_id] = np.concatenate([point_to_tile[this_tile_id], point_in_tile])
                label_to_tile[this_tile_id] = np.concatenate([label_to_tile[this_tile_id], label_in_tile])
            else:
                point_to_tile[this_tile_id] = point_in_tile
                label_to_tile[this_tile_id] = label_in_tile

        # NOTE: logits are not yet supported.
        tile_ids.append(this_tile_id)

    # Find the tiles with prompts.
    tile_ids = sorted(set(tile_ids))

    # Run batched inference for each tile.
    masks = []
    # Additional variables needed for optimized memory mode.
    id_offset = 0
    for tile_id in tqdm(tile_ids, desc="Run batched inference"):
        # Get the prompts for this tile.
        tile_boxes = box_to_tile.get(tile_id)
        tile_logits = logits_to_tile.get(tile_id)
        tile_points, tile_labels = point_to_tile.get(tile_id), label_to_tile.get(tile_id)

        # Set the correct embeddings, run inference.
        predictor = util.set_precomputed(predictor, image_embeddings, tile_id=tile_id, i=i)
        this_masks = batched_inference(
            predictor=predictor,
            image=None,
            batch_size=batch_size,
            boxes=tile_boxes,
            points=tile_points,
            point_labels=tile_labels,
            multimasking=multimasking,
            return_instance_segmentation=False,
            segmentation_ids=segmentation_ids,
            reduce_multimasking=reduce_multimasking,
            logits_masks=tile_logits,
            mask_threshold=mask_threshold,
        )

        if optimize_memory:
            # Apply NMS directly to get a segmentation.
            segmentation = util.apply_nms(this_masks, **nms_kwargs)
            fg_mask = segmentation != 0
            segmentation[fg_mask] += id_offset
            # Never decrease the offset: a tile without any object has a maximum of 0,
            # which would otherwise reset the offset and lead to duplicate ids in later tiles.
            id_offset = max(id_offset, int(segmentation.max()))
            masks.append(segmentation)
        else:
            # Add the offset for the current tile to the bounding box.
            tile = tiling.get_block_with_halo(tile_id, list(halo)).outer_block
            # The bounding box is in XYWH format, so only the first two entries are shifted.
            offset = np.array(list(tile.begin)[::-1] + [0, 0])
            this_masks = [
                {**mask, "global_bbox": (np.array(mask["bbox"]) + offset).tolist()} for mask in this_masks
            ]
            masks.extend(this_masks)

        # Try to keep the memory clean.
        del this_masks
        gc.collect()

    if optimize_memory:
        return _stitch_segmentation(masks, tile_ids, tiling, halo, output_shape=shape)

    if return_instance_segmentation:
        masks = util.mask_data_to_segmentation(masks, shape=shape, min_object_size=0)
    return masks
@torch.no_grad()
def
batched_inference( predictor: segment_anything.predictor.SamPredictor, image: Optional[numpy.ndarray], batch_size: int, boxes: Optional[numpy.ndarray] = None, points: Optional[numpy.ndarray] = None, point_labels: Optional[numpy.ndarray] = None, multimasking: bool = False, embedding_path: Union[str, os.PathLike, NoneType] = None, return_instance_segmentation: bool = True, segmentation_ids: Optional[list] = None, reduce_multimasking: bool = True, logits_masks: Optional[torch.Tensor] = None, verbose_embeddings: bool = True, mask_threshold: Union[float, str, NoneType] = None, return_highres_logits: bool = False, i: Optional[int] = None) -> Union[List[List[Dict[str, Any]]], numpy.ndarray]:
@torch.no_grad()
def batched_inference(
    predictor: SamPredictor,
    image: Optional[np.ndarray],
    batch_size: int,
    boxes: Optional[np.ndarray] = None,
    points: Optional[np.ndarray] = None,
    point_labels: Optional[np.ndarray] = None,
    multimasking: bool = False,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    return_instance_segmentation: bool = True,
    segmentation_ids: Optional[list] = None,
    reduce_multimasking: bool = True,
    logits_masks: Optional[torch.Tensor] = None,
    verbose_embeddings: bool = True,
    mask_threshold: Optional[Union[float, str]] = None,
    return_highres_logits: bool = False,
    i: Optional[int] = None,
) -> Union[List[List[Dict[str, Any]]], np.ndarray]:
    """Predict masks for a set of prompts, processing the prompts in batches.

    Args:
        predictor: The Segment Anything predictor.
        image: The input image. If None, the image embeddings must already be set in the predictor.
        batch_size: The number of prompts that are processed per batch.
        boxes: The box prompts. Array of shape N_PROMPTS x 4.
            The bounding boxes are represented by [MIN_X, MIN_Y, MAX_X, MAX_Y].
        points: The point prompt coordinates. Array of shape N_PROMPTS x 1 x 2.
            The points are represented by their coordinates [X, Y], which are given in the last dimension.
        point_labels: The point prompt labels. Array of shape N_PROMPTS x 1.
            The labels are either 0 (negative prompt) or 1 (positive prompt).
        multimasking: Whether to predict with 3 or 1 mask. By default, set to 'False'.
        embedding_path: Cache path for the image embeddings. By default, computed on-the-fly.
        return_instance_segmentation: Whether to return an instance segmentation
            or the individual mask data. By default, set to 'True'.
        segmentation_ids: Fixed segmentation ids to assign to the masks
            derived from the prompts.
        reduce_multimasking: Whether to keep only the mask with the highest
            predicted iou per prompt when multimasking is used. By default, set to 'True'.
        logits_masks: The logits masks from a previous segmentation, which are passed
            as mask prompts. Array of shape N_PROMPTS x 1 x 256 x 256.
        verbose_embeddings: Whether to show progress outputs of computing image embeddings.
            By default, set to 'True'.
        mask_threshold: The threshold for binarizing masks based on the predicted values.
            If None, the default threshold of the predictor's model is used.
            If "auto" is passed then the threshold is determined with a local otsu filter.
        return_highres_logits: Whether to return high-resolution logits.
        i: Index for the image data. Required if `image` has three spatial dimensions
            or a time dimension and two spatial dimensions.

    Returns:
        The predicted segmentation masks.
    """
    n_prompts, have_boxes, have_points, have_logits = _validate_inputs(
        boxes, points, point_labels, multimasking, return_instance_segmentation, segmentation_ids, logits_masks
    )

    # Make sure that the predictor holds the image embeddings.
    if image is not None:
        selected_image = image[i] if i is not None else image
        embeddings = util.precompute_image_embeddings(
            predictor, selected_image, embedding_path, verbose=verbose_embeddings
        )
        util.set_precomputed(predictor, embeddings)
    else:
        # The embeddings must have been computed before; this call raises an error otherwise.
        predictor.get_image_embedding()

    n_batches = int(np.ceil(float(n_prompts) / batch_size))

    # Bring the prompts into the coordinate space of the model and move them to the device.
    device = predictor.device
    to_model_space = ResizeLongestSide(1024)
    original_shape = predictor.original_size
    if have_boxes:
        boxes = torch.tensor(
            to_model_space.apply_boxes(boxes, original_shape), dtype=torch.float32
        ).to(device)
    if have_points:
        points = torch.tensor(
            to_model_space.apply_coords(points, original_shape), dtype=torch.float32
        ).to(device)
        point_labels = torch.tensor(point_labels, dtype=torch.float32).to(device)

    mask_data = amg_utils.MaskData()
    if mask_threshold is None:
        mask_threshold = predictor.model.mask_threshold

    for batch_idx in range(n_batches):
        start = batch_idx * batch_size
        stop = min(start + batch_size, n_prompts)
        window = slice(start, stop)

        batch_masks, batch_ious, batch_logits = predictor.predict_torch(
            point_coords=points[window] if have_points else None,
            point_labels=point_labels[window] if have_points else None,
            boxes=boxes[window] if have_boxes else None,
            mask_input=logits_masks[window] if have_logits else None,
            multimask_output=multimasking,
            return_logits=True,
        )

        # With multimasking we keep only the mask with the highest predicted IOU per prompt
        # (if requested), and restore the singleton mask axis afterwards.
        if multimasking and reduce_multimasking:
            _, best = batch_ious.max(dim=1)
            rows = torch.arange(best.shape[0], device=best.device)
            batch_masks = batch_masks[rows, best].unsqueeze(1)
            batch_ious = batch_ious[rows, best].unsqueeze(1)
            batch_logits = batch_logits[rows, best].unsqueeze(1)

        mask_data.cat(
            _process_masks_for_batch(
                batch_masks, batch_ious, batch_logits, return_highres_logits, mask_threshold
            )
        )

    # Convert the mask data to one record per mask.
    records = []
    for idx in range(len(mask_data["masks"])):
        this_mask = mask_data["masks"][idx]
        seg_id = idx + 1 if segmentation_ids is None else int(segmentation_ids[idx])
        records.append(
            {
                "segmentation": this_mask,
                "area": this_mask.sum(),
                "bbox": amg_utils.box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
                "predicted_iou": mask_data["iou_preds"][idx].item(),
                "stability_score": mask_data["stability_scores"][idx].item(),
                "seg_id": seg_id,
                "logits": mask_data["logits"][idx],
            }
        )

    if not return_instance_segmentation:
        return records
    return util.mask_data_to_segmentation(records, min_object_size=0)
Run batched inference for input prompts.
Arguments:
- predictor: The Segment Anything predictor.
- image: The input image. If None, we assume that the image embeddings have already been computed.
- batch_size: The batch size to use for inference.
- boxes: The box prompts. Array of shape N_PROMPTS x 4. The bounding boxes are represented by [MIN_X, MIN_Y, MAX_X, MAX_Y].
- points: The point prompt coordinates. Array of shape N_PROMPTS x 1 x 2. The points are represented by their coordinates [X, Y], which are given in the last dimension.
- point_labels: The point prompt labels. Array of shape N_PROMPTS x 1. The labels are either 0 (negative prompt) or 1 (positive prompt).
- multimasking: Whether to predict with 3 or 1 mask. By default, set to 'False'.
- embedding_path: Cache path for the image embeddings. By default, computed on-the-fly.
- return_instance_segmentation: Whether to return an instance segmentation or the individual mask data. By default, set to 'True'.
- segmentation_ids: Fixed segmentation ids to assign to the masks derived from the prompts.
- reduce_multimasking: Whether to choose the most likely masks with highest ious from multimasking. By default, set to 'True'.
- logits_masks: The logits masks. Array of shape N_PROMPTS x 1 x 256 x 256. Whether to use the logits masks from previous segmentation.
- verbose_embeddings: Whether to show progress outputs of computing image embeddings. By default, set to 'True'.
- mask_threshold: The threshold for binarizing masks based on the predicted values. If None, the default threshold 0 is used. If "auto" is passed then the threshold is determined with a local otsu filter.
- return_highres_logits: Whether to return high-resolution logits.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
Returns:
The predicted segmentation masks.
@torch.no_grad()
def
batched_tiled_inference( predictor: segment_anything.predictor.SamPredictor, image: Optional[numpy.ndarray], batch_size: int, image_embeddings: Optional[Dict[str, Any]] = None, boxes: Optional[numpy.ndarray] = None, points: Optional[numpy.ndarray] = None, point_labels: Optional[numpy.ndarray] = None, multimasking: bool = False, embedding_path: Union[str, os.PathLike, NoneType] = None, return_instance_segmentation: bool = True, reduce_multimasking: bool = True, logits_masks: Optional[torch.Tensor] = None, verbose_embeddings: bool = True, mask_threshold: Union[float, str, NoneType] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, optimize_memory: bool = False, i: Optional[int] = None, **nms_kwargs) -> Union[List[List[Dict[str, Any]]], numpy.ndarray]:
@torch.no_grad()
def batched_tiled_inference(
    predictor: SamPredictor,
    image: Optional[np.ndarray],
    batch_size: int,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    boxes: Optional[np.ndarray] = None,
    points: Optional[np.ndarray] = None,
    point_labels: Optional[np.ndarray] = None,
    multimasking: bool = False,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    return_instance_segmentation: bool = True,
    reduce_multimasking: bool = True,
    logits_masks: Optional[torch.Tensor] = None,
    verbose_embeddings: bool = True,
    mask_threshold: Optional[Union[float, str]] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    optimize_memory: bool = False,
    i: Optional[int] = None,
    **nms_kwargs,
) -> Union[List[List[Dict[str, Any]]], np.ndarray]:
    """Run batched inference for input prompts on a tiled image.

    Each prompt is assigned to one tile (via the box center and/or the point coordinate),
    translated into the coordinate system of that tile (including its halo), and then
    `batched_inference` is run once per tile that received at least one prompt.

    Args:
        predictor: The Segment Anything predictor.
        image: The input image. If None, we assume that the image embeddings have already been computed.
        batch_size: The batch size to use for inference.
        image_embeddings: Optional precomputed image embeddings. They are forwarded to
            `_require_tiled_embeddings` together with `image` and `embedding_path`.
        boxes: The box prompts. Array of shape N_PROMPTS x 4.
            The bounding boxes are represented by [MIN_X, MIN_Y, MAX_X, MAX_Y].
        points: The point prompt coordinates. Array of shape N_PROMPTS x 1 x 2.
            The points are represented by their coordinates [X, Y], which are given in the last dimension.
        point_labels: The point prompt labels. Array of shape N_PROMPTS x 1.
            The labels are either 0 (negative prompt) or 1 (positive prompt).
        multimasking: Whether to predict with 3 or 1 mask. By default, set to 'False'.
        embedding_path: Cache path for the image embeddings. By default, computed on-the-fly.
        return_instance_segmentation: Whether to return an instance segmentation
            or the individual mask data. By default, set to 'True'.
        reduce_multimasking: Whether to choose the most likely masks with
            highest ious from multimasking. By default, set to 'True'.
        logits_masks: The logits masks. Array of shape N_PROMPTS x 1 x 256 x 256.
            Not yet supported for tiled inference; passing it raises a NotImplementedError.
        verbose_embeddings: Whether to show progress outputs of computing image embeddings.
            By default, set to 'True'.
        mask_threshold: The threshold for binarizing masks based on the predicted values.
            If None, the default threshold 0 is used. If "auto" is passed then the threshold is
            determined with a local otsu filter.
        tile_shape: The tile shape for embedding prediction.
        halo: The overlap between tiles.
        optimize_memory: Whether to optimize the memory usage. If set to True:
            - NMS will be applied directly for each tile to reduce it to a per-tile instance segmentation.
            - The per-tile segmentations will be stitched.
            - The result will be returned as an instance segmentation.
        i: Index for the image data. Required if `image` has three spatial dimensions
            or a time dimension and two spatial dimensions.
        nms_kwargs: Keyword arguments for the NMS operation that is used for optimize_memory=True.
            Does not have any effect for optimize_memory=False.

    Returns:
        The predicted segmentation masks. This is the stitched segmentation for `optimize_memory=True`,
        otherwise the result of `util.mask_data_to_segmentation` if `return_instance_segmentation=True`,
        otherwise the list of per-mask records, each extended by a "global_bbox" entry.

    Raises:
        NotImplementedError: If `logits_masks` are passed.
        ValueError: If the prompts are inconsistent (see `_validate_inputs`) or if the box and
            point prompt of the same object are assigned to different tiles.
    """
    # Fixed segmentation ids are not exposed by the tiled interface.
    segmentation_ids = None
    n_prompts, have_boxes, have_points, have_logits = _validate_inputs(
        boxes, points, point_labels, multimasking, return_instance_segmentation, segmentation_ids, logits_masks
    )
    if have_logits:
        raise NotImplementedError("Mask prompts (logits_masks) are not yet supported for tiled inference.")

    # Get the tiling parameters and compute embeddings if needed.
    image_embeddings, shape, tile_shape, halo = _require_tiled_embeddings(
        predictor, image, image_embeddings, embedding_path, tile_shape, halo, verbose_embeddings
    )

    tiling = Blocking([0, 0], shape, tile_shape)
    halo_list = list(halo)

    # Collect the prompts per tile. The rows are gathered in lists and turned into arrays
    # once per tile afterwards, which avoids re-allocating the arrays for every prompt.
    # Box prompts are in the format N x 4, stored in order [MIN_X, MIN_Y, MAX_X, MAX_Y].
    # Points are in the format N x 1 x 2 with coordinate order X, Y.
    # Point labels are in the format N x 1.
    boxes_per_tile, points_per_tile, labels_per_tile = {}, {}, {}
    tiles_with_prompts = set()

    for prompt_id in range(n_prompts):
        this_tile_id = None

        if have_boxes:
            min_x, min_y, max_x, max_y = boxes[prompt_id]
            # The tiling operates in (Y, X) axis order, so the center is given as [Y, X].
            center = np.array([(min_y + max_y) / 2, (min_x + max_x) / 2]).round().astype("int").tolist()
            this_tile_id = tiling.coordinates_to_block_id(center)
            tile = tiling.get_block_with_halo(this_tile_id, halo_list).outer_block
            offset_y, offset_x = tile.begin[0], tile.begin[1]
            tile_height, tile_width = tile.shape[0], tile.shape[1]
            # Shift the box into the tile and clip it to the tile extent. The result keeps the
            # [MIN_X, MIN_Y, MAX_X, MAX_Y] order that is documented for the box prompts.
            box_in_tile = np.array(
                [
                    max(min_x - offset_x, 0), max(min_y - offset_y, 0),
                    min(max_x - offset_x, tile_width), min(max_y - offset_y, tile_height),
                ]
            )
            boxes_per_tile.setdefault(this_tile_id, []).append(box_in_tile)

        if have_points:
            point_xy = points[prompt_id, 0]
            point_tile_id = tiling.coordinates_to_block_id(point_xy[::-1].round().astype("int").tolist())
            if this_tile_id is None:
                this_tile_id = point_tile_id
            elif this_tile_id != point_tile_id:
                # Explicit check instead of an assert, which would be skipped with 'python -O'.
                raise ValueError(
                    f"The box and point prompt with index {prompt_id} are assigned to different tiles: "
                    f"{this_tile_id} != {point_tile_id}"
                )
            tile = tiling.get_block_with_halo(this_tile_id, halo_list).outer_block
            # The tile offset is given as (Y, X), the points are stored as (X, Y).
            point_in_tile = (point_xy - np.array(tile.begin)[::-1])[None, None]
            points_per_tile.setdefault(this_tile_id, []).append(point_in_tile)
            labels_per_tile.setdefault(this_tile_id, []).append(point_labels[prompt_id][None])

        # NOTE: logits are not yet supported.
        tiles_with_prompts.add(this_tile_id)

    box_to_tile = {tile_id: np.stack(rows) for tile_id, rows in boxes_per_tile.items()}
    point_to_tile = {tile_id: np.concatenate(rows) for tile_id, rows in points_per_tile.items()}
    label_to_tile = {tile_id: np.concatenate(rows) for tile_id, rows in labels_per_tile.items()}

    # Only the tiles with prompts are processed.
    tile_ids = sorted(tiles_with_prompts)

    # Run batched inference for each tile.
    masks = []
    # Largest id that was assigned so far, only needed for the optimized memory mode.
    id_offset = 0
    for tile_id in tqdm(tile_ids, desc="Run batched inference"):
        # Set the correct embeddings, run inference.
        predictor = util.set_precomputed(predictor, image_embeddings, tile_id=tile_id, i=i)
        this_masks = batched_inference(
            predictor=predictor,
            image=None,
            batch_size=batch_size,
            boxes=box_to_tile.get(tile_id),
            points=point_to_tile.get(tile_id),
            point_labels=label_to_tile.get(tile_id),
            multimasking=multimasking,
            return_instance_segmentation=False,
            segmentation_ids=segmentation_ids,
            reduce_multimasking=reduce_multimasking,
            logits_masks=None,
            mask_threshold=mask_threshold,
        )

        if optimize_memory:
            # Apply NMS directly to get a segmentation.
            segmentation = util.apply_nms(this_masks, **nms_kwargs)
            fg_mask = segmentation != 0
            segmentation[fg_mask] += id_offset
            # Never decrease the offset: a tile without foreground has a maximum of 0,
            # which would otherwise make the following tiles re-use ids of previous tiles.
            id_offset = max(id_offset, int(segmentation.max()))
            masks.append(segmentation)
        else:
            # Add the offset for the current tile to the bounding box.
            tile = tiling.get_block_with_halo(tile_id, halo_list).outer_block
            # 'bbox' is stored as [X, Y, W, H], so only the first two entries are shifted.
            offset = np.array(list(tile.begin)[::-1] + [0, 0])
            masks.extend(
                {**mask, "global_bbox": (np.array(mask["bbox"]) + offset).tolist()} for mask in this_masks
            )

        # Try to keep the memory clean.
        del this_masks
        gc.collect()

    if optimize_memory:
        return _stitch_segmentation(masks, tile_ids, tiling, halo, output_shape=shape)

    if return_instance_segmentation:
        masks = util.mask_data_to_segmentation(masks, shape=shape, min_object_size=0)
    return masks
Run batched inference for input prompts.
Arguments:
- predictor: The Segment Anything predictor.
- image: The input image. If None, we assume that the image embeddings have already been computed.
- batch_size: The batch size to use for inference.
- boxes: The box prompts. Array of shape N_PROMPTS x 4. The bounding boxes are represented by [MIN_X, MIN_Y, MAX_X, MAX_Y].
- points: The point prompt coordinates. Array of shape N_PROMPTS x 1 x 2. The points are represented by their coordinates [X, Y], which are given in the last dimension.
- point_labels: The point prompt labels. Array of shape N_PROMPTS x 1. The labels are either 0 (negative prompt) or 1 (positive prompt).
- multimasking: Whether to predict with 3 or 1 mask. By default, set to 'False'.
- embedding_path: Cache path for the image embeddings. By default, computed on-the-fly.
- return_instance_segmentation: Whether to return an instance segmentation or the individual mask data. By default, set to 'True'.
- segmentation_ids: Fixed segmentation ids to assign to the masks derived from the prompts.
- reduce_multimasking: Whether to choose the most likely masks with highest ious from multimasking. By default, set to 'True'.
- logits_masks: The logits masks. Array of shape N_PROMPTS x 1 x 256 x 256. Whether to use the logits masks from previous segmentation.
- verbose_embeddings: Whether to show progress outputs of computing image embeddings. By default, set to 'True'.
- mask_threshold: The threshold for binarizing masks based on the predicted values. If None, the default threshold 0 is used. If "auto" is passed then the threshold is determined with a local otsu filter.
- tile_shape: The tile shape for embedding prediction.
- halo: The overlap between tiles.
- optimize_memory: Whether to optimize the memory usage. If set to True:
- NMS will be applied directly for each tile to reduce it to a per-tile instance segmentation.
- The per-tile segmentations will be stitched.
- The result will be returned as an instance segmentation.
- i: Index for the image data. Required if `image` has three spatial dimensions or a time dimension and two spatial dimensions.
- nms_kwargs: Keyword arguments for the NMS operation that is used for optimize_memory=True. Does not have any effect for optimize_memory=False.
Returns:
The predicted segmentation masks.