micro_sam.inference

import os
import gc
from typing import Any, Dict, List, Optional, Union, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from nifty.tools import blocking
import nifty.ground_truth as ngt

import segment_anything.utils.amg as amg_utils
from segment_anything import SamPredictor
from segment_anything.utils.transforms import ResizeLongestSide
try:
    from napari.utils import progress as tqdm
except ImportError:
    from tqdm import tqdm

from . import util
from ._vendored import batched_mask_to_box


def _validate_inputs(
    boxes, points, point_labels, multimasking, return_instance_segmentation, segmentation_ids, logits_masks
):
    if multimasking and (segmentation_ids is not None) and (not return_instance_segmentation):
        raise NotImplementedError(
            "Multimasking with fixed segmentation ids is only supported when returning an instance segmentation."
        )

    if (points is None) != (point_labels is None):
        raise ValueError(
            "If you have point prompts both `points` and `point_labels` have to be passed, "
            "but you passed only one of them."
        )

    have_points = points is not None
    have_boxes = boxes is not None
    have_logits = logits_masks is not None
    if (not have_points) and (not have_boxes):
        raise ValueError("Point and/or box prompts have to be passed, you passed neither.")

    if have_points and (len(point_labels) != len(points)):
        raise ValueError(
            f"The number of point coordinates and labels does not match: {len(point_labels)} != {len(points)}"
        )

    if (have_points and have_boxes) and (len(points) != len(boxes)):
        raise ValueError(
            f"The number of point and box prompts does not match: {len(points)} != {len(boxes)}"
        )

    if have_logits:
        if have_points and (len(logits_masks) != len(point_labels)):
            raise ValueError(
                f"The number of points and logits does not match: {len(points)} != {len(logits_masks)}"
            )
        elif have_boxes and (len(logits_masks) != len(boxes)):
            raise ValueError(
                f"The number of boxes and logits does not match: {len(boxes)} != {len(logits_masks)}"
            )

    n_prompts = boxes.shape[0] if have_boxes else points.shape[0]

    if (segmentation_ids is not None) and (len(segmentation_ids) != n_prompts):
        raise ValueError(
            f"The number of segmentation ids and prompts does not match: {len(segmentation_ids)} != {n_prompts}"
        )

    return n_prompts, have_boxes, have_points, have_logits


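For reference, a minimal sketch of prompt arrays that pass this validation, with the shapes documented in `batched_inference` below (the values here are arbitrary):

import numpy as np

boxes = np.array([[10, 10, 50, 50], [60, 20, 90, 70]])  # N_PROMPTS x 4, [MIN_X, MIN_Y, MAX_X, MAX_Y]
points = np.array([[[30, 30]], [[75, 45]]])             # N_PROMPTS x 1 x 2, coordinates [X, Y]
point_labels = np.array([[1], [1]])                     # N_PROMPTS x 1, 1 = positive prompt
n_prompts, have_boxes, have_points, have_logits = _validate_inputs(
    boxes, points, point_labels, multimasking=False,
    return_instance_segmentation=True, segmentation_ids=None, logits_masks=None,
)
# n_prompts == 2; have_boxes and have_points are True, have_logits is False.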
def _local_otsu_threshold(images: torch.Tensor, window_size: int = 31, num_bins: int = 64, eps: float = 1e-6):
    x = images
    B, _, H, W = x.shape
    device = x.device

    # Work in float32 for stability even if the input is fp16.
    x = x.to(torch.float32)

    # --- per-image min/max for normalization to [0, 1] ---
    x_flat = x.view(B, -1)
    x_min = x_flat.min(dim=1).values.view(B, 1, 1, 1)
    x_max = x_flat.max(dim=1).values.view(B, 1, 1, 1)
    x_range = (x_max - x_min).clamp_min(eps)

    x_norm = (x - x_min) / x_range  # (B,1,H,W), in [0,1]

    # --- extract local patches via unfold ---
    pad = window_size // 2
    patches = F.unfold(x_norm, kernel_size=window_size, padding=pad)  # (B, P, L)
    # P = window_size * window_size, L = H * W
    B_, P, L = patches.shape

    # --- quantize to bins ---
    bin_idx = (patches * (num_bins - 1)).long().clamp(0, num_bins - 1)  # (B, P, L)

    # --- build histograms per patch ---
    # one_hot: (B, L, num_bins)
    one_hot = torch.zeros(B, L, num_bins, device=device, dtype=torch.float32)
    idx = bin_idx.transpose(1, 2)  # (B, L, P)
    src = torch.ones_like(idx, dtype=one_hot.dtype)  # (B, L, P)
    one_hot.scatter_add_(2, idx, src)
    # hist: (B, num_bins, L)
    hist = one_hot.permute(0, 2, 1)

    # --- Otsu per patch (vectorized) ---
    # p: (B, bins, L)
    p = hist / hist.sum(dim=1, keepdim=True).clamp_min(eps)

    bins = torch.arange(num_bins, device=device, dtype=torch.float32).view(1, num_bins, 1)

    omega1 = torch.cumsum(p, dim=1)              # (B, bins, L)
    mu = torch.cumsum(p * bins, dim=1)           # (B, bins, L)
    mu_T = mu[:, -1:, :]                         # (B, 1, L)

    omega2 = 1.0 - omega1

    mu1 = mu / omega1.clamp_min(eps)
    mu2 = (mu_T - mu) / omega2.clamp_min(eps)

    sigma_b2 = omega1 * omega2 * (mu1 - mu2) ** 2  # (B, bins, L)

    # The argmax over bins gives the local threshold bin per patch.
    t_bin = torch.argmax(sigma_b2, dim=1)        # (B, L)
    t_norm = t_bin.to(torch.float32) / (num_bins - 1)  # normalized [0,1]

    # --- map thresholds back to the original intensity scale (per-image) ---
    # x_min, x_range: (B,1,1,1) -> flatten batch dims
    thr_vals = x_min.view(B, 1) + t_norm * x_range.view(B, 1)  # (B, L)
    # Clamp to >= 0 because foreground is positive.
    thr_vals = thr_vals.clamp_min(0.0)

    thresholds = thr_vals.view(B, H, W)
    # Take the spatial max over the thresholds.
    thresholds = torch.amax(thresholds, dim=(1, 2), keepdim=True)
    return thresholds


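To illustrate the interface of this helper: it maps a batch of logit maps of shape (B, 1, H, W) to one threshold per image of shape (B, 1, 1), the spatial maximum over the local Otsu thresholds, which then broadcasts against the (B, H, W) mask tensor during binarization. A small sketch:

import torch

logits = torch.randn(4, 1, 256, 256)        # a batch of predicted mask logits
thresholds = _local_otsu_threshold(logits)  # shape (4, 1, 1), one threshold per image
binary = logits.squeeze(1) > thresholds     # broadcasts to a (4, 256, 256) boolean mask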
def _process_masks_for_batch(batch_masks, batch_ious, batch_logits, return_highres_logits, mask_threshold):
    batch_data = amg_utils.MaskData(masks=batch_masks.flatten(0, 1), iou_preds=batch_ious.flatten(0, 1))
    batch_data["logits"] = batch_masks.clone() if return_highres_logits else batch_logits
    # TODO: probably the best heuristic: choose the lowest threshold which still has stable masks.
    # To do this: go through the thresholds, starting from 0 in smallish increments, compute the stability score,
    # and select the threshold before the stability score starts to drop.
    if mask_threshold == "auto":
        thresholds = _local_otsu_threshold(batch_logits)
        batch_data["stability_scores"] = amg_utils.calculate_stability_score(batch_data["masks"], thresholds, 1.0)
        batch_data["masks"] = (batch_data["masks"] > thresholds).type(torch.bool)
    else:
        batch_data["stability_scores"] = amg_utils.calculate_stability_score(batch_data["masks"], mask_threshold, 1.0)
        batch_data["masks"] = (batch_data["masks"] > mask_threshold).type(torch.bool)
    batch_data["boxes"] = batched_mask_to_box(batch_data["masks"])
    return batch_data


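The stability score computed here comes from `segment_anything.utils.amg`: with a threshold offset of 1.0 it is the IoU between the mask binarized at `threshold + 1` and at `threshold - 1`. Since the stricter mask is contained in the looser one, this reduces to a ratio of pixel counts; a rough sketch of the idea (up to dtype details of the library implementation):

import torch

logits, t = torch.randn(2, 1, 256, 256), 0.0
strict = (logits > t + 1.0).sum(dim=(-2, -1)).float()  # pixels that survive the stricter threshold
loose = (logits > t - 1.0).sum(dim=(-2, -1)).float()   # pixels that survive the looser threshold
stability = strict / loose  # same values as amg_utils.calculate_stability_score(logits, t, 1.0)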
@torch.no_grad()
def batched_inference(
    predictor: SamPredictor,
    image: Optional[np.ndarray],
    batch_size: int,
    boxes: Optional[np.ndarray] = None,
    points: Optional[np.ndarray] = None,
    point_labels: Optional[np.ndarray] = None,
    multimasking: bool = False,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    return_instance_segmentation: bool = True,
    segmentation_ids: Optional[list] = None,
    reduce_multimasking: bool = True,
    logits_masks: Optional[torch.Tensor] = None,
    verbose_embeddings: bool = True,
    mask_threshold: Optional[Union[float, str]] = None,
    return_highres_logits: bool = False,
    i: Optional[int] = None,
) -> Union[List[List[Dict[str, Any]]], np.ndarray]:
    """Run batched inference for input prompts.

    Args:
        predictor: The Segment Anything predictor.
        image: The input image. If None, we assume that the image embeddings have already been computed.
        batch_size: The batch size to use for inference.
        boxes: The box prompts. Array of shape N_PROMPTS x 4.
            The bounding boxes are represented by [MIN_X, MIN_Y, MAX_X, MAX_Y].
        points: The point prompt coordinates. Array of shape N_PROMPTS x 1 x 2.
            The points are represented by their coordinates [X, Y], which are given in the last dimension.
        point_labels: The point prompt labels. Array of shape N_PROMPTS x 1.
            The labels are either 0 (negative prompt) or 1 (positive prompt).
        multimasking: Whether to predict with 3 or 1 mask. By default, set to 'False'.
        embedding_path: Cache path for the image embeddings. By default, computed on-the-fly.
        return_instance_segmentation: Whether to return an instance segmentation
            or the individual mask data. By default, set to 'True'.
        segmentation_ids: Fixed segmentation ids to assign to the masks
            derived from the prompts.
        reduce_multimasking: Whether to choose the most likely mask (the one with the
            highest predicted IOU) from multimasking. By default, set to 'True'.
        logits_masks: The low-resolution logits masks from a previous segmentation,
            used as additional mask prompts. Array of shape N_PROMPTS x 1 x 256 x 256.
        verbose_embeddings: Whether to show progress outputs of computing image embeddings.
            By default, set to 'True'.
        mask_threshold: The threshold for binarizing masks based on the predicted values.
            If None, the default threshold 0 is used. If "auto" is passed then the threshold is
            determined with a local Otsu filter.
        return_highres_logits: Whether to return high-resolution logits.
        i: Index for the image data. Required if `image` has three spatial dimensions
            or a time dimension and two spatial dimensions.

    Returns:
        The predicted segmentation masks.
    """
    n_prompts, have_boxes, have_points, have_logits = _validate_inputs(
        boxes, points, point_labels, multimasking, return_instance_segmentation, segmentation_ids, logits_masks
    )

    # Compute the image embeddings.
    if image is None:  # This means the image embeddings are computed already.
        # Calling get_image_embedding will throw an error if the embeddings have not yet been computed.
        predictor.get_image_embedding()
    else:
        image_embeddings = util.precompute_image_embeddings(
            predictor, image, embedding_path, verbose=verbose_embeddings, i=i,
        )
        util.set_precomputed(predictor, image_embeddings)

    # Determine the number of batches.
    n_batches = int(np.ceil(float(n_prompts) / batch_size))

    # Preprocess the prompts.
    device = predictor.device
    transform_function = ResizeLongestSide(1024)
    image_shape = predictor.original_size
    if have_boxes:
        boxes = transform_function.apply_boxes(boxes, image_shape)
        boxes = torch.tensor(boxes, dtype=torch.float32).to(device)
    if have_points:
        points = transform_function.apply_coords(points, image_shape)
        points = torch.tensor(points, dtype=torch.float32).to(device)
        point_labels = torch.tensor(point_labels, dtype=torch.float32).to(device)

    masks = amg_utils.MaskData()
    mask_threshold = predictor.model.mask_threshold if mask_threshold is None else mask_threshold
    for batch_idx in range(n_batches):
        batch_start = batch_idx * batch_size
        batch_stop = min((batch_idx + 1) * batch_size, n_prompts)

        batch_boxes = boxes[batch_start:batch_stop] if have_boxes else None
        batch_points = points[batch_start:batch_stop] if have_points else None
        batch_labels = point_labels[batch_start:batch_stop] if have_points else None
        batch_logits = logits_masks[batch_start:batch_stop] if have_logits else None

        batch_masks, batch_ious, batch_logits = predictor.predict_torch(
            point_coords=batch_points,
            point_labels=batch_labels,
            boxes=batch_boxes,
            mask_input=batch_logits,
            multimask_output=multimasking,
            return_logits=True,
        )

        # If we use multimasking and want to reduce its masks,
        # then we need to select the most likely mask (according to the predicted IOU) here.
        if reduce_multimasking and multimasking:
            _, max_index = batch_ious.max(axis=1)
            batch_masks = torch.cat([batch_masks[j, max_id][None] for j, max_id in enumerate(max_index)]).unsqueeze(1)
            batch_ious = torch.cat([batch_ious[j, max_id][None] for j, max_id in enumerate(max_index)]).unsqueeze(1)
            batch_logits = torch.cat([batch_logits[j, max_id][None] for j, max_id in enumerate(max_index)]).unsqueeze(1)

        batch_data = _process_masks_for_batch(
            batch_masks, batch_ious, batch_logits, return_highres_logits, mask_threshold
        )
        masks.cat(batch_data)

    # Mask data to records.
    masks = [
        {
            "segmentation": masks["masks"][idx],
            "area": masks["masks"][idx].sum(),
            "bbox": amg_utils.box_xyxy_to_xywh(masks["boxes"][idx]).tolist(),
            "predicted_iou": masks["iou_preds"][idx].item(),
            "stability_score": masks["stability_scores"][idx].item(),
            "seg_id": idx + 1 if segmentation_ids is None else int(segmentation_ids[idx]),
            "logits": masks["logits"][idx],
        }
        for idx in range(len(masks["masks"]))
    ]

    if return_instance_segmentation:
        masks = util.mask_data_to_segmentation(masks, min_object_size=0)
    return masks


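A usage sketch for prompt-based segmentation of a single 2D image (the model type and prompt values are placeholders):

import numpy as np
from micro_sam.util import get_sam_model
from micro_sam.inference import batched_inference

predictor = get_sam_model(model_type="vit_b")
image = np.random.randint(0, 255, (512, 512), dtype="uint8")  # stand-in for real image data
boxes = np.array([[50, 60, 120, 140], [200, 220, 280, 300]])  # one box per object, [MIN_X, MIN_Y, MAX_X, MAX_Y]
segmentation = batched_inference(predictor, image, batch_size=2, boxes=boxes)
# 'segmentation' is a label image where the prompt at index idx gets id idx + 1.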
def _require_tiled_embeddings(
    predictor, image, image_embeddings, embedding_path, tile_shape, halo, verbose_embeddings
):
    if image_embeddings is None:
        assert image is not None
        assert (tile_shape is not None) and (halo is not None)
        shape = image.shape
        image_embeddings = util.precompute_image_embeddings(
            predictor, image, embedding_path, ndim=2, tile_shape=tile_shape, halo=halo, verbose=verbose_embeddings
        )
    else:  # This means the image embeddings are computed already.
        attrs = image_embeddings["features"].attrs
        tile_shape_, halo_ = attrs["tile_shape"], attrs["halo"]
        shape = attrs["shape"]
        if tile_shape is None:
            tile_shape = tile_shape_
        elif any(ts != ts_ for ts, ts_ in zip(tile_shape, tile_shape_)):
            raise ValueError(f"Incompatible tile shapes: {tile_shape} != {tile_shape_}")
        if halo is None:
            halo = halo_
        elif any(h != h_ for h, h_ in zip(halo, halo_)):
            raise ValueError(f"Incompatible halos: {halo} != {halo_}")

    return image_embeddings, shape, tile_shape, halo


def _merge_segmentations(this_seg, prev_seg, overlap_threshold=0.75):
    # Discard new ids with too much overlap to the previous segmentation.
    ovlp = ngt.overlap(this_seg, prev_seg)
    ids = np.unique(this_seg)
    if ids[0] == 0:
        ids = ids[1:]
    discard_ids = []
    for seg_id in ids:
        ovlp_ids, ovlp_vals = ovlp.overlapArraysNormalized(seg_id, True)
        ovlp_vals = ovlp_vals[ovlp_ids != 0]
        if ovlp_vals.size > 0 and ovlp_vals[0] > overlap_threshold:
            discard_ids.append(seg_id)
    this_seg[np.isin(this_seg, discard_ids)] = 0

    # Make sure the previous segmentation is fully preserved.
    captured = prev_seg != 0
    this_seg[captured] = prev_seg[captured]
    return this_seg


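To make the merge semantics concrete, a toy example (this requires nifty, as imported above): a new id that is mostly covered by the previous segmentation is dropped in favor of the previous ids, while new ids over background are kept:

import numpy as np

prev = np.zeros((4, 4), dtype="uint32")
prev[:2] = 1                           # the previous tile already claimed the top half
this = np.zeros((4, 4), dtype="uint32")
this[:2], this[2:] = 5, 6              # the new tile duplicates id 1 (as 5) and adds id 6
merged = _merge_segmentations(this.copy(), prev)
# merged keeps the previous id 1 in the top half and the new id 6 in the bottom half.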
# Note: we merge with a very simple first-come-first-served strategy.
# This could also be improved by merging according to IoUs,
# but that would add a lot of complexity.
def _stitch_segmentation(masks, tile_ids, tiling, halo, output_shape, verbose=False):
    assert len(masks) == len(tile_ids), f"{len(masks)}, {len(tile_ids)}"
    segmentation = np.zeros(output_shape, dtype="uint32")

    for tile_id, this_seg in tqdm(zip(tile_ids, masks), total=len(tile_ids), desc="Stitch tiles", disable=not verbose):
        tile = tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock
        bb = tuple(slice(begin, end) for begin, end in zip(tile.begin, tile.end))
        if tile_id == 0:
            segmentation[bb] = this_seg
        else:
            # Merge the segmentations, discarding new ids with too much overlap.
            prev_seg = segmentation[bb]
            assert prev_seg.shape == this_seg.shape, f"{tile_id}: {prev_seg.shape}, {this_seg.shape}"
            this_seg = _merge_segmentations(this_seg, prev_seg)
            segmentation[bb] = this_seg

    return segmentation


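For reference, a sketch of how the nifty blocking maps a tile id to the slice that is written here; the outer block includes the halo, clipped at the image boundary:

from nifty.tools import blocking

tiling = blocking([0, 0], [1024, 1024], [512, 512])     # a 2 x 2 tile grid
tile = tiling.getBlockWithHalo(0, [64, 64]).outerBlock  # tile 0 plus a 64 pixel halo
bb = tuple(slice(beg, end) for beg, end in zip(tile.begin, tile.end))
# bb == (slice(0, 576), slice(0, 576)): the top-left tile extends 64 pixels into its neighbors.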
@torch.no_grad()
def batched_tiled_inference(
    predictor: SamPredictor,
    image: Optional[np.ndarray],
    batch_size: int,
    image_embeddings: Optional[util.ImageEmbeddings] = None,
    boxes: Optional[np.ndarray] = None,
    points: Optional[np.ndarray] = None,
    point_labels: Optional[np.ndarray] = None,
    multimasking: bool = False,
    embedding_path: Optional[Union[str, os.PathLike]] = None,
    return_instance_segmentation: bool = True,
    reduce_multimasking: bool = True,
    logits_masks: Optional[torch.Tensor] = None,
    verbose_embeddings: bool = True,
    mask_threshold: Optional[Union[float, str]] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    optimize_memory: bool = False,
    i: Optional[int] = None,
    **nms_kwargs,
) -> Union[List[List[Dict[str, Any]]], np.ndarray]:
    """Run batched tiled inference for input prompts.

    Args:
        predictor: The Segment Anything predictor.
        image: The input image. If None, we assume that the image embeddings have already been computed.
        batch_size: The batch size to use for inference.
        image_embeddings: Precomputed tiled image embeddings. If None, they are computed
            from `image`, which requires `tile_shape` and `halo` to be passed.
        boxes: The box prompts. Array of shape N_PROMPTS x 4.
            The bounding boxes are represented by [MIN_X, MIN_Y, MAX_X, MAX_Y].
        points: The point prompt coordinates. Array of shape N_PROMPTS x 1 x 2.
            The points are represented by their coordinates [X, Y], which are given in the last dimension.
        point_labels: The point prompt labels. Array of shape N_PROMPTS x 1.
            The labels are either 0 (negative prompt) or 1 (positive prompt).
        multimasking: Whether to predict with 3 or 1 mask. By default, set to 'False'.
        embedding_path: Cache path for the image embeddings. By default, computed on-the-fly.
        return_instance_segmentation: Whether to return an instance segmentation
            or the individual mask data. By default, set to 'True'.
        reduce_multimasking: Whether to choose the most likely mask (the one with the
            highest predicted IOU) from multimasking. By default, set to 'True'.
        logits_masks: The low-resolution logits masks from a previous segmentation,
            used as additional mask prompts. Array of shape N_PROMPTS x 1 x 256 x 256.
            Not yet supported for tiled inference.
        verbose_embeddings: Whether to show progress outputs of computing image embeddings.
            By default, set to 'True'.
        mask_threshold: The threshold for binarizing masks based on the predicted values.
            If None, the default threshold 0 is used. If "auto" is passed then the threshold is
            determined with a local Otsu filter.
        tile_shape: The tile shape for embedding prediction.
        halo: The overlap between tiles.
        optimize_memory: Whether to optimize the memory usage. If set to True:
            - NMS will be applied directly for each tile to reduce it to a per-tile instance segmentation.
            - The per-tile segmentations will be stitched.
            - The result will be returned as an instance segmentation.
        i: Index for the image data. Required if `image` has three spatial dimensions
            or a time dimension and two spatial dimensions.
        nms_kwargs: Keyword arguments for the NMS operation that is used for optimize_memory=True.
            Does not have any effect for optimize_memory=False.

    Returns:
        The predicted segmentation masks.
    """
    # Validate inputs and get the input prompt summary.
    # Fixed segmentation ids are not supported for tiled inference.
    segmentation_ids = None
    n_prompts, have_boxes, have_points, have_logits = _validate_inputs(
        boxes, points, point_labels, multimasking, return_instance_segmentation, segmentation_ids, logits_masks
    )
    if have_logits:
        raise NotImplementedError("Logits masks are not yet supported for tiled inference.")

    # Get the tiling parameters and compute embeddings if needed.
    image_embeddings, shape, tile_shape, halo = _require_tiled_embeddings(
        predictor, image, image_embeddings, embedding_path, tile_shape, halo, verbose_embeddings
    )

    # Order the prompts by tile and then iterate over the tiles.
    tiling = blocking([0, 0], shape, tile_shape)
    box_to_tile, point_to_tile, label_to_tile, logits_to_tile = {}, {}, {}, {}
    tile_ids = []

    # Box prompts are in the format N x 4. Boxes are stored in the order [MIN_X, MIN_Y, MAX_X, MAX_Y].
    # Points are in the format N x 1 x 2 with coordinate order X, Y.
    # Point labels are in the format N x 1.
    # Mask prompts are in the format N x 1 x 256 x 256.
    for prompt_id in range(n_prompts):
        this_tile_id = None

        if have_boxes:
            box = boxes[prompt_id]
            center = np.array([(box[1] + box[3]) / 2, (box[0] + box[2]) / 2]).round().astype("int").tolist()
            this_tile_id = tiling.coordinatesToBlockId(center)
            tile = tiling.getBlockWithHalo(this_tile_id, list(halo)).outerBlock
            offset = tile.begin
            this_tile_shape = tile.shape
            box_in_tile = np.array(
                [
                    max(box[1] - offset[0], 0), max(box[0] - offset[1], 0),
                    min(box[3] - offset[0], this_tile_shape[0]), min(box[2] - offset[1], this_tile_shape[1])
                ]
            )[None]
            if this_tile_id in box_to_tile:
                box_to_tile[this_tile_id] = np.concatenate([box_to_tile[this_tile_id], box_in_tile])
            else:
                box_to_tile[this_tile_id] = box_in_tile

        if have_points:
            point = points[prompt_id, 0][::-1].round().astype("int").tolist()
            if this_tile_id is None:
                this_tile_id = tiling.coordinatesToBlockId(point)
            else:
                assert this_tile_id == tiling.coordinatesToBlockId(point)
            tile = tiling.getBlockWithHalo(this_tile_id, list(halo)).outerBlock
            offset = tile.begin
            point_in_tile = (points[prompt_id, 0] - np.array(offset)[::-1])[None, None]
            label_in_tile = point_labels[prompt_id][None]
            if this_tile_id in point_to_tile:
                point_to_tile[this_tile_id] = np.concatenate([point_to_tile[this_tile_id], point_in_tile])
                label_to_tile[this_tile_id] = np.concatenate([label_to_tile[this_tile_id], label_in_tile])
            else:
                point_to_tile[this_tile_id] = point_in_tile
                label_to_tile[this_tile_id] = label_in_tile

        # NOTE: logits are not yet supported.
        tile_ids.append(this_tile_id)

    # Find the tiles with prompts.
    tile_ids = sorted(set(tile_ids))

    # Run batched inference for each tile.
    masks = []
    # Additional variable needed for the optimized memory mode.
    id_offset = 0
    for tile_id in tqdm(tile_ids, desc="Run batched inference"):
        # Get the prompts for this tile.
        tile_boxes = box_to_tile.get(tile_id)
        tile_logits = logits_to_tile.get(tile_id)
        tile_points, tile_labels = point_to_tile.get(tile_id), label_to_tile.get(tile_id)

        # Set the correct embeddings, then run inference.
        predictor = util.set_precomputed(predictor, image_embeddings, tile_id=tile_id, i=i)
        this_masks = batched_inference(
            predictor=predictor,
            image=None,
            batch_size=batch_size,
            boxes=tile_boxes,
            points=tile_points,
            point_labels=tile_labels,
            multimasking=multimasking,
            return_instance_segmentation=False,
            segmentation_ids=segmentation_ids,
            reduce_multimasking=reduce_multimasking,
            logits_masks=tile_logits,
            mask_threshold=mask_threshold,
        )

        if optimize_memory:
            # Apply NMS directly to get a segmentation.
            segmentation = util.apply_nms(this_masks, **nms_kwargs)
            fg_mask = segmentation != 0
            segmentation[fg_mask] += id_offset
            id_offset = segmentation.max()
            masks.append(segmentation)
        else:
            # Add the offset for the current tile to the bounding box.
            tile = tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock
            offset = np.array(tile.begin[::-1] + [0, 0])
            this_masks = [
                {**mask, "global_bbox": (np.array(mask["bbox"]) + offset).tolist()} for mask in this_masks
            ]
            masks.extend(this_masks)

        # Try to keep the memory clean.
        del this_masks
        gc.collect()

    if optimize_memory:
        return _stitch_segmentation(masks, tile_ids, tiling, halo, output_shape=shape)

    if return_instance_segmentation:
        masks = util.mask_data_to_segmentation(masks, shape=shape, min_object_size=0)
    return masks
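A usage sketch for tiled prompt-based segmentation of a large 2D image (the model type, tile shape, halo, and prompt values are placeholders):

import numpy as np
from micro_sam.util import get_sam_model
from micro_sam.inference import batched_tiled_inference

predictor = get_sam_model(model_type="vit_b")
image = np.random.randint(0, 255, (2048, 2048), dtype="uint8")  # stand-in for a large image
boxes = np.array([[100, 150, 220, 260], [1500, 1600, 1650, 1720]])
segmentation = batched_tiled_inference(
    predictor, image, batch_size=8, boxes=boxes, tile_shape=(1024, 1024), halo=(256, 256),
)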