micro_sam.instance_segmentation

Automated instance segmentation functionality. The classes implemented here extend the automatic instance segmentation from Segment Anything: https://computational-cell-analytics.github.io/micro-sam/micro_sam.html

   1"""Automated instance segmentation functionality.
   2The classes implemented here extend the automatic instance segmentation from Segment Anything:
   3https://computational-cell-analytics.github.io/micro-sam/micro_sam.html
   4"""
   5
   6import os
   7import shutil
   8import tempfile
   9import warnings
  10from abc import ABC
  11from contextlib import contextmanager
  12from copy import deepcopy
  13from collections import OrderedDict
  14from typing import Any, Dict, Literal, List, Optional, Tuple, Union
  15
  16import numpy as np
  17import zarr
  18from skimage.measure import regionprops
  19from skimage.segmentation import find_boundaries
  20
  21from bioimage_cpp import filters as filter_impl
  22
  23import torch
  24from torchvision.ops.boxes import batched_nms, box_area
  25
  26from torch_em.model import UNETR
  27from torch_em.util.segmentation import watershed_from_center_and_boundary_distances
  28
  29import elf.parallel as parallel_impl
  30from elf.parallel.filters import apply_filter
  31from elf.wrapper.base import MultiTransformationWrapper
  32from elf.wrapper.generic import ThresholdWrapper
  33
  34from bioimage_cpp.utils import Blocking
  35
  36import segment_anything.utils.amg as amg_utils
  37from segment_anything.predictor import SamPredictor
  38
  39from . import util
  40from .inference import batched_inference, batched_tiled_inference
  41from ._vendored import batched_mask_to_box, mask_to_rle_pytorch
  42
# Default segmentation mode for models that come with a segmentation decoder.
# NOTE(review): meaning inferred from the constant's name - confirm against the code that reads it.
# We may change this to 'apg' in version 1.8.
DEFAULT_SEGMENTATION_MODE_WITH_DECODER = "ais"
  45
  46#
  47# Utility Functionality
  48#
  49
  50
  51class _FakeInput:
  52    def __init__(self, shape):
  53        self.shape = shape
  54
  55    def __getitem__(self, index):
  56        block_shape = tuple(ind.stop - ind.start for ind in index)
  57        return np.zeros(block_shape, dtype="float32")
  58
  59
  60#
  61# Classes for automatic instance segmentation
  62#
  63
  64
  65class AMGBase(ABC):
  66    """Base class for the automatic mask generators.
  67    """
  68    def __init__(self):
  69        # the state that has to be computed by the 'initialize' method of the child classes
  70        self._is_initialized = False
  71        self._crop_list = None
  72        self._crop_boxes = None
  73        self._original_size = None
  74
  75    @property
  76    def is_initialized(self):
  77        """Whether the mask generator has already been initialized.
  78        """
  79        return self._is_initialized
  80
  81    @property
  82    def crop_list(self):
  83        """The list of mask data after initialization.
  84        """
  85        return self._crop_list
  86
  87    @property
  88    def crop_boxes(self):
  89        """The list of crop boxes.
  90        """
  91        return self._crop_boxes
  92
  93    @property
  94    def original_size(self):
  95        """The original image size.
  96        """
  97        return self._original_size
  98
  99    def _postprocess_batch(
 100        self,
 101        data,
 102        crop_box,
 103        original_size,
 104        pred_iou_thresh,
 105        stability_score_thresh,
 106        box_nms_thresh,
 107    ):
 108        orig_h, orig_w = original_size
 109
 110        # filter by predicted IoU
 111        if pred_iou_thresh > 0.0:
 112            keep_mask = data["iou_preds"] > pred_iou_thresh
 113            data.filter(keep_mask)
 114
 115        # filter by stability score
 116        if stability_score_thresh > 0.0:
 117            keep_mask = data["stability_score"] >= stability_score_thresh
 118            data.filter(keep_mask)
 119
 120        # filter boxes that touch crop boundaries
 121        keep_mask = ~amg_utils.is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
 122        if not torch.all(keep_mask):
 123            data.filter(keep_mask)
 124
 125        # remove duplicates within this crop.
 126        keep_by_nms = batched_nms(
 127            data["boxes"].float(),
 128            data["iou_preds"],
 129            torch.zeros_like(data["boxes"][:, 0]),  # categories
 130            iou_threshold=box_nms_thresh,
 131        )
 132        data.filter(keep_by_nms)
 133
 134        # return to the original image frame
 135        data["boxes"] = amg_utils.uncrop_boxes_xyxy(data["boxes"], crop_box)
 136        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
 137        # the data from embedding based segmentation doesn't have the points
 138        # so we skip if the corresponding key can't be found
 139        try:
 140            data["points"] = amg_utils.uncrop_points(data["points"], crop_box)
 141        except KeyError:
 142            pass
 143
 144        return data
 145
 146    def _postprocess_small_regions(self, mask_data, min_area, nms_thresh):
 147
 148        if len(mask_data["rles"]) == 0:
 149            return mask_data
 150
 151        # filter small disconnected regions and holes
 152        new_masks = []
 153        scores = []
 154        for rle in mask_data["rles"]:
 155            mask = amg_utils.rle_to_mask(rle)
 156
 157            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="holes")
 158            unchanged = not changed
 159            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="islands")
 160            unchanged = unchanged and not changed
 161
 162            new_masks.append(torch.as_tensor(mask, dtype=torch.int).unsqueeze(0))
 163            # give score=0 to changed masks and score=1 to unchanged masks
 164            # so NMS will prefer ones that didn't need postprocessing
 165            scores.append(float(unchanged))
 166
 167        # recalculate boxes and remove any new duplicates
 168        masks = torch.cat(new_masks, dim=0)
 169        boxes = batched_mask_to_box(masks.to(torch.bool))  # Casting this to boolean as we work with one-hot labels.
 170        keep_by_nms = batched_nms(
 171            boxes.float(),
 172            torch.as_tensor(scores, dtype=torch.float),
 173            torch.zeros_like(boxes[:, 0]),  # categories
 174            iou_threshold=nms_thresh,
 175        )
 176
 177        # only recalculate RLEs for masks that have changed
 178        for i_mask in keep_by_nms:
 179            if scores[i_mask] == 0.0:
 180                mask_torch = masks[i_mask].unsqueeze(0)
 181                # mask_data["rles"][i_mask] = amg_utils.mask_to_rle_pytorch(mask_torch)[0]
 182                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
 183                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
 184        mask_data.filter(keep_by_nms)
 185
 186        return mask_data
 187
 188    def _postprocess_masks(self, mask_data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode):
 189        # filter small disconnected regions and holes in masks
 190        if min_mask_region_area > 0:
 191            mask_data = self._postprocess_small_regions(
 192                mask_data,
 193                min_mask_region_area,
 194                max(box_nms_thresh, crop_nms_thresh),
 195            )
 196
 197        # encode masks
 198        if output_mode == "coco_rle":
 199            mask_data["segmentations"] = [amg_utils.coco_encode_rle(rle) for rle in mask_data["rles"]]
 200        # We require the binary mask output for creating the instance-seg output.
 201        elif output_mode in ("binary_mask", "instance_segmentation"):
 202            mask_data["segmentations"] = [amg_utils.rle_to_mask(rle) for rle in mask_data["rles"]]
 203        elif output_mode == "rle":
 204            mask_data["segmentations"] = mask_data["rles"]
 205        else:
 206            raise ValueError(f"Invalid output mode {output_mode}.")
 207
 208        # write mask records
 209        curr_anns = []
 210        for idx in range(len(mask_data["segmentations"])):
 211            ann = {
 212                "segmentation": mask_data["segmentations"][idx],
 213                "area": amg_utils.area_from_rle(mask_data["rles"][idx]),
 214                "bbox": amg_utils.box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
 215                "predicted_iou": mask_data["iou_preds"][idx].item(),
 216                "stability_score": mask_data["stability_score"][idx].item(),
 217                "crop_box": amg_utils.box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
 218            }
 219            # the data from embedding based segmentation doesn't have the points
 220            # so we skip if the corresponding key can't be found
 221            try:
 222                ann["point_coords"] = [mask_data["points"][idx].tolist()]
 223            except KeyError:
 224                pass
 225            curr_anns.append(ann)
 226
 227        return curr_anns
 228
 229    def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None):
 230        orig_h, orig_w = original_size
 231
 232        # serialize predictions and store in MaskData
 233        data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1))
 234        if points is not None:
 235            data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0), dtype=torch.float)
 236
 237        del masks
 238
 239        # calculate the stability scores
 240        data["stability_score"] = amg_utils.calculate_stability_score(
 241            data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset
 242        )
 243
 244        # threshold masks and calculate boxes
 245        data["masks"] = data["masks"] > self._predictor.model.mask_threshold
 246        data["masks"] = data["masks"].type(torch.bool)
 247        data["boxes"] = batched_mask_to_box(data["masks"])
 248
 249        # compress to RLE
 250        data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
 251        # data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"])
 252        data["rles"] = mask_to_rle_pytorch(data["masks"])
 253        del data["masks"]
 254
 255        return data
 256
 257    def get_state(self) -> Dict[str, Any]:
 258        """Get the initialized state of the mask generator.
 259
 260        Returns:
 261            State of the mask generator.
 262        """
 263        if not self.is_initialized:
 264            raise RuntimeError("The state has not been computed yet. Call initialize first.")
 265
 266        return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}
 267
 268    def set_state(self, state: Dict[str, Any]) -> None:
 269        """Set the state of the mask generator.
 270
 271        Args:
 272            state: The state of the mask generator, e.g. from serialized state.
 273        """
 274        self._crop_list = state["crop_list"]
 275        self._crop_boxes = state["crop_boxes"]
 276        self._original_size = state["original_size"]
 277        self._is_initialized = True
 278
 279    def clear_state(self):
 280        """Clear the state of the mask generator.
 281        """
 282        self._crop_list = None
 283        self._crop_boxes = None
 284        self._original_size = None
 285        self._is_initialized = False
 286
 287
 288class AutomaticMaskGenerator(AMGBase):
 289    """Generates an instance segmentation without prompts, using a point grid.
 290
 291    This class implements the same logic as
 292    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
 293    It decouples the computationally expensive steps of generating masks from the cheap post-processing operation
 294    to filter these masks to enable grid search and interactively changing the post-processing.
 295
 296    Use this class as follows:
 297    ```python
 298    amg = AutomaticMaskGenerator(predictor)
 299    amg.initialize(image)  # Initialize the masks, this takes care of all expensive computations.
 300    masks = amg.generate(pred_iou_thresh=0.8)  # Generate the masks. This is fast and enables testing parameters
 301    ```
 302
 303    Args:
 304        predictor: The segment anything predictor.
 305        points_per_side: The number of points to be sampled along one side of the image.
 306            If None, `point_grids` must provide explicit point sampling. By default, set to '32'.
 307        points_per_batch: The number of points run simultaneously by the model.
 308            Higher numbers may be faster but use more GPU memory.
 309            By default, selects '64' for all devices except 'mps' (selects '16' for performance reasons).
 310        crop_n_layers: If >0, the mask prediction will be run again on crops of the image.
 311            By default, set to '0'.
 312        crop_overlap_ratio: Sets the degree to which crops overlap. By default, set to '512 / 1500'.
 313        crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops.
 314            By default, set to '1'.
 315        point_grids: A list over explicit grids of points used for sampling masks.
 316            Normalized to [0, 1] with respect to the image coordinate system.
 317        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
 318            By default, set to '1.0'.
 319    """
 320    def __init__(
 321        self,
 322        predictor: SamPredictor,
 323        points_per_side: Optional[int] = 32,
 324        points_per_batch: Optional[int] = None,
 325        crop_n_layers: int = 0,
 326        crop_overlap_ratio: float = 512 / 1500,
 327        crop_n_points_downscale_factor: int = 1,
 328        point_grids: Optional[List[np.ndarray]] = None,
 329        stability_score_offset: float = 1.0,
 330    ):
 331        super().__init__()
 332
 333        if points_per_side is not None:
 334            self.point_grids = amg_utils.build_all_layer_point_grids(
 335                points_per_side, crop_n_layers, crop_n_points_downscale_factor,
 336            )
 337        elif point_grids is not None:
 338            self.point_grids = point_grids
 339        else:
 340            raise ValueError("Can't have both points_per_side and point_grid be None or not None.")
 341
 342        self._predictor = predictor
 343        self._points_per_side = points_per_side
 344
 345        # we set the points per batch to 16 for mps for performance reasons
 346        # and otherwise keep them at the default of 64
 347        if points_per_batch is None:
 348            points_per_batch = 16 if str(predictor.device) == "mps" else 64
 349        self._points_per_batch = points_per_batch
 350
 351        self._crop_n_layers = crop_n_layers
 352        self._crop_overlap_ratio = crop_overlap_ratio
 353        self._crop_n_points_downscale_factor = crop_n_points_downscale_factor
 354        self._stability_score_offset = stability_score_offset
 355
 356    def _process_batch(self, points, im_size, crop_box, original_size):
 357        # run model on this batch
 358        transformed_points = self._predictor.transform.apply_coords(points, im_size)
 359        in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float)
 360        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
 361        masks, iou_preds, _ = self._predictor.predict_torch(
 362            point_coords=in_points[:, None, :],
 363            point_labels=in_labels[:, None],
 364            multimask_output=True,
 365            return_logits=True,
 366        )
 367        data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points)
 368        del masks
 369        return data
 370
 371    def _process_crop(self, image, crop_box, crop_layer_idx, precomputed_embeddings, pbar_init=None, pbar_update=None):
 372        # Crop the image and calculate embeddings.
 373        x0, y0, x1, y1 = crop_box
 374        cropped_im = image[y0:y1, x0:x1, :]
 375        cropped_im_size = cropped_im.shape[:2]
 376
 377        if not precomputed_embeddings:
 378            self._predictor.set_image(cropped_im)
 379
 380        # Get the points for this crop.
 381        points_scale = np.array(cropped_im_size)[None, ::-1]
 382        points_for_image = self.point_grids[crop_layer_idx] * points_scale
 383
 384        # Generate masks for this crop in batches.
 385        data = amg_utils.MaskData()
 386        n_batches = len(points_for_image) // self._points_per_batch +\
 387            int(len(points_for_image) % self._points_per_batch != 0)
 388        if pbar_init is not None:
 389            pbar_init(n_batches, "Predict masks for point grid prompts")
 390
 391        for (points,) in amg_utils.batch_iterator(self._points_per_batch, points_for_image):
 392            batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size)
 393            data.cat(batch_data)
 394            del batch_data
 395            if pbar_update is not None:
 396                pbar_update(1)
 397
 398        if not precomputed_embeddings:
 399            self._predictor.reset_image()
 400
 401        return data
 402
 403    @torch.no_grad()
 404    def initialize(
 405        self,
 406        image: np.ndarray,
 407        image_embeddings: Optional[util.ImageEmbeddings] = None,
 408        i: Optional[int] = None,
 409        verbose: bool = False,
 410        pbar_init: Optional[callable] = None,
 411        pbar_update: Optional[callable] = None,
 412    ) -> None:
 413        """Initialize image embeddings and masks for an image.
 414
 415        Args:
 416            image: The input image, volume or timeseries.
 417            image_embeddings: Optional precomputed image embeddings.
 418                See `util.precompute_image_embeddings` for details.
 419            i: Index for the image data. Required if `image` has three spatial dimensions
 420                or a time dimension and two spatial dimensions.
 421            verbose: Whether to print computation progress. By default, set to 'False'.
 422            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 423                Can be used together with pbar_update to handle napari progress bar in other thread.
 424                To enable using this function within a threadworker.
 425            pbar_update: Callback to update an external progress bar.
 426        """
 427        original_size = image.shape[:2]
 428        self._original_size = original_size
 429
 430        crop_boxes, layer_idxs = amg_utils.generate_crop_boxes(
 431            original_size, self._crop_n_layers, self._crop_overlap_ratio
 432        )
 433
 434        # We can set fixed image embeddings if we only have a single crop box (the default setting).
 435        # Otherwise we have to recompute the embeddings for each crop and can't precompute.
 436        if len(crop_boxes) == 1:
 437            if image_embeddings is None:
 438                image_embeddings = util.precompute_image_embeddings(self._predictor, image)
 439            util.set_precomputed(self._predictor, image_embeddings, i=i)
 440            precomputed_embeddings = True
 441        else:
 442            precomputed_embeddings = False
 443
 444        # we need to cast to the image representation that is compatible with SAM
 445        image = util._to_image(image)
 446
 447        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
 448
 449        crop_list = []
 450        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
 451            crop_data = self._process_crop(
 452                image, crop_box, layer_idx,
 453                precomputed_embeddings=precomputed_embeddings,
 454                pbar_init=pbar_init, pbar_update=pbar_update,
 455            )
 456            crop_list.append(crop_data)
 457        pbar_close()
 458
 459        self._is_initialized = True
 460        self._crop_list = crop_list
 461        self._crop_boxes = crop_boxes
 462
 463    @torch.no_grad()
 464    def generate(
 465        self,
 466        pred_iou_thresh: float = 0.88,
 467        stability_score_thresh: float = 0.95,
 468        box_nms_thresh: float = 0.7,
 469        crop_nms_thresh: float = 0.7,
 470        min_mask_region_area: int = 0,
 471        output_mode: str = "instance_segmentation",
 472        with_background: bool = True,
 473    ) -> Union[List[Dict[str, Any]], np.ndarray]:
 474        """Generate instance segmentation for the currently initialized image.
 475
 476        Args:
 477            pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
 478                By default, set to '0.88'.
 479            stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
 480                under changes to the cutoff used to binarize the model prediction. By default, set to '0.95'.
 481            box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
 482                By default, set to '0.7'.
 483            crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
 484                By default, set to '0.7'.
 485            min_mask_region_area: Minimal size for the predicted masks. By default, set to '0'.
 486            output_mode: The form masks are returned in. Possible values are:
 487                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
 488                - 'coco_rle': Return a list of dictionaries with run-length encoded masks in MS COCO format.
 489                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
 490                - 'rle': Return a list of dictionaries with run-length encoded masks.
 491                By default, set to 'instance_segmentation'.
 492            with_background: Whether to remove the largest object, which often covers the background.
 493
 494        Returns:
 495            The segmentation masks.
 496        """
 497        if not self.is_initialized:
 498            raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.")
 499
 500        data = amg_utils.MaskData()
 501        for data_, crop_box in zip(self.crop_list, self.crop_boxes):
 502            crop_data = self._postprocess_batch(
 503                data=deepcopy(data_),
 504                crop_box=crop_box, original_size=self.original_size,
 505                pred_iou_thresh=pred_iou_thresh,
 506                stability_score_thresh=stability_score_thresh,
 507                box_nms_thresh=box_nms_thresh
 508            )
 509            data.cat(crop_data)
 510
 511        if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0:
 512            # Prefer masks from smaller crops
 513            scores = 1 / box_area(data["crop_boxes"])
 514            scores = scores.to(data["boxes"].device)
 515            keep_by_nms = batched_nms(
 516                data["boxes"].float(),
 517                scores,
 518                torch.zeros_like(data["boxes"][:, 0]),  # categories
 519                iou_threshold=crop_nms_thresh,
 520            )
 521            data.filter(keep_by_nms)
 522
 523        data.to_numpy()
 524        masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode)
 525        if output_mode == "instance_segmentation":
 526            shape = next(iter(masks))["segmentation"].shape if len(masks) > 0 else self.original_size
 527            masks = util.mask_data_to_segmentation(
 528                masks, shape=shape, with_background=with_background, merge_exclusively=False
 529            )
 530        return masks
 531
 532
 533# Helper function for tiled embedding computation and checking consistent state.
 534def _process_tiled_embeddings(predictor, image, image_embeddings, tile_shape, halo, verbose, batch_size, mask, i):
 535    if image_embeddings is None:
 536        if tile_shape is None or halo is None:
 537            raise ValueError("To compute tiled embeddings the parameters tile_shape and halo have to be passed.")
 538        image_embeddings = util.precompute_image_embeddings(
 539            predictor, image, tile_shape=tile_shape, halo=halo, verbose=verbose, batch_size=batch_size, mask=mask,
 540        )
 541
 542    # Use tile shape and halo from the precomputed embeddings if not given.
 543    # Otherwise check that they are consistent.
 544    feats = image_embeddings["features"]
 545    tile_shape_, halo_ = tuple(feats.attrs["tile_shape"]), tuple(feats.attrs["halo"])
 546    if tile_shape is None:
 547        tile_shape = tile_shape_
 548    elif tile_shape != tile_shape_:
 549        raise ValueError(
 550            f"Inconsistent tile_shape parameter {tile_shape} with precomputed embeedings: {tile_shape_}."
 551        )
 552    if halo is None:
 553        halo = halo_
 554    elif halo != halo_:
 555        raise ValueError(f"Inconsistent halo parameter {halo} with precomputed embeedings: {halo_}.")
 556
 557    tiles_in_mask = feats.attrs.get("tiles_in_mask", None)
 558    if tiles_in_mask is not None and i is not None:
 559        tiles_in_mask = tiles_in_mask[str(i)]
 560
 561    return image_embeddings, tile_shape, halo, tiles_in_mask
 562
 563
 564class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
 565    """Generates an instance segmentation without prompts, using a point grid.
 566
 567    Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings.
 568
 569    Args:
 570        predictor: The Segment Anything predictor.
 571        points_per_side: The number of points to be sampled along one side of the image.
 572            If None, `point_grids` must provide explicit point sampling. By default, set to '32'.
 573        points_per_batch: The number of points run simultaneously by the model.
 574            Higher numbers may be faster but use more GPU memory. By default, set to '64'.
 575        point_grids: A list over explicit grids of points used for sampling masks.
 576            Normalized to [0, 1] with respect to the image coordinate system.
 577        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
 578            By default, set to '1.0'.
 579    """
 580
 581    # We only expose the arguments that make sense for the tiled mask generator.
 582    # Anything related to crops doesn't make sense, because we re-use that functionality
 583    # for tiling, so these parameters wouldn't have any effect.
 584    def __init__(
 585        self,
 586        predictor: SamPredictor,
 587        points_per_side: Optional[int] = 32,
 588        points_per_batch: int = 64,
 589        point_grids: Optional[List[np.ndarray]] = None,
 590        stability_score_offset: float = 1.0,
 591    ) -> None:
 592        super().__init__(
 593            predictor=predictor,
 594            points_per_side=points_per_side,
 595            points_per_batch=points_per_batch,
 596            point_grids=point_grids,
 597            stability_score_offset=stability_score_offset,
 598        )
 599
 600    @torch.no_grad()
 601    def initialize(
 602        self,
 603        image: np.ndarray,
 604        image_embeddings: Optional[util.ImageEmbeddings] = None,
 605        i: Optional[int] = None,
 606        tile_shape: Optional[Tuple[int, int]] = None,
 607        halo: Optional[Tuple[int, int]] = None,
 608        verbose: bool = False,
 609        pbar_init: Optional[callable] = None,
 610        pbar_update: Optional[callable] = None,
 611        batch_size: int = 1,
 612        mask: Optional[np.typing.ArrayLike] = None,
 613    ) -> None:
 614        """Initialize image embeddings and masks for an image.
 615
 616        Args:
 617            image: The input image, volume or timeseries.
 618            image_embeddings: Optional precomputed image embeddings.
 619                See `util.precompute_image_embeddings` for details.
 620            i: Index for the image data. Required if `image` has three spatial dimensions
 621                or a time dimension and two spatial dimensions.
 622            tile_shape: The tile shape for embedding prediction.
 623            halo: The overlap of between tiles.
 624            verbose: Whether to print computation progress. By default, set to 'False'.
 625            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 626                Can be used together with pbar_update to handle napari progress bar in other thread.
 627                To enable using this function within a threadworker.
 628            pbar_update: Callback to update an external progress bar.
 629            batch_size: The batch size for image embedding prediction. By default, set to '1'.
 630            mask: An optional mask to define areas that are ignored in the segmentation.
 631        """
 632        original_size = image.shape[:2]
 633        self._original_size = original_size
 634
 635        self._image_embeddings, tile_shape, halo, tiles_in_mask = _process_tiled_embeddings(
 636            self._predictor, image, image_embeddings, tile_shape, halo,
 637            verbose=verbose, batch_size=batch_size, mask=mask, i=i,
 638        )
 639
 640        tiling = Blocking([0, 0], original_size, tile_shape)
 641        if tiles_in_mask is None:
 642            n_tiles = tiling.number_of_blocks
 643            tile_ids = range(n_tiles)
 644        else:
 645            n_tiles = len(tiles_in_mask)
 646            tile_ids = tiles_in_mask
 647
 648        # The crop box is always the full local tile.
 649        tiles = [tiling.get_block_with_halo(tile_id, list(halo)).outer_block for tile_id in tile_ids]
 650        crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles]
 651
 652        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
 653        pbar_init(n_tiles, "Compute masks for tile")
 654
 655        # We need to cast to the image representation that is compatible with SAM.
 656        image = util._to_image(image)
 657
 658        mask_data = []
 659        for idx, tile_id in enumerate(tile_ids):
 660            # set the pre-computed embeddings for this tile
 661            features = image_embeddings["features"][str(tile_id)]
 662            tile_embeddings = {
 663                "features": features,
 664                "input_size": features.attrs["input_size"],
 665                "original_size": features.attrs["original_size"],
 666            }
 667            util.set_precomputed(self._predictor, tile_embeddings, i)
 668
 669            # Compute the mask data for this tile and append it
 670            this_mask_data = self._process_crop(
 671                image, crop_box=crop_boxes[idx], crop_layer_idx=0, precomputed_embeddings=True
 672            )
 673            mask_data.append(this_mask_data)
 674            pbar_update(1)
 675        pbar_close()
 676
 677        # set the initialized data
 678        self._is_initialized = True
 679        self._crop_list = mask_data
 680        self._crop_boxes = crop_boxes
 681
 682
 683#
 684# Instance segmentation functionality based on fine-tuned decoder
 685#
 686
 687
 688class DecoderAdapter(torch.nn.Module):
 689    """Adapter to contain the UNETR decoder in a single module.
 690
 691    To apply the decoder on top of pre-computed embeddings for the segmentation functionality.
 692    See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py
 693    """
 694    def __init__(self, unetr: torch.nn.Module):
 695        super().__init__()
 696
 697        self.base = unetr.base
 698        self.out_conv = unetr.out_conv
 699        self.deconv_out = unetr.deconv_out
 700        self.decoder_head = unetr.decoder_head
 701        self.final_activation = unetr.final_activation
 702        self.postprocess_masks = unetr.postprocess_masks
 703
 704        self.decoder = unetr.decoder
 705        self.deconv1 = unetr.deconv1
 706        self.deconv2 = unetr.deconv2
 707        self.deconv3 = unetr.deconv3
 708        self.deconv4 = unetr.deconv4
 709
 710    def _forward_impl(self, input_):
 711        z12 = input_
 712
 713        z9 = self.deconv1(z12)
 714        z6 = self.deconv2(z9)
 715        z3 = self.deconv3(z6)
 716        z0 = self.deconv4(z3)
 717
 718        updated_from_encoder = [z9, z6, z3]
 719
 720        x = self.base(z12)
 721        x = self.decoder(x, encoder_inputs=updated_from_encoder)
 722        x = self.deconv_out(x)
 723
 724        x = torch.cat([x, z0], dim=1)
 725        x = self.decoder_head(x)
 726
 727        x = self.out_conv(x)
 728        if self.final_activation is not None:
 729            x = self.final_activation(x)
 730        return x
 731
 732    def forward(self, input_, input_shape, original_shape):
 733        x = self._forward_impl(input_)
 734        x = self.postprocess_masks(x, input_shape, original_shape)
 735        return x
 736
 737
 738def get_unetr(
 739    image_encoder: torch.nn.Module,
 740    decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None,
 741    device: Optional[Union[str, torch.device]] = None,
 742    out_channels: int = 3,
 743    flexible_load_checkpoint: bool = False,
 744    final_activation: Optional[str] = "Sigmoid",
 745) -> torch.nn.Module:
 746    """Get UNETR model for automatic instance segmentation.
 747
 748    Args:
 749        image_encoder: The image encoder of the SAM model.
 750            This is used as encoder by the UNETR too.
 751        decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
 752        device: The device. By default, automatically chooses the best available device.
 753        out_channels: The number of output channels. By default, set to '3'.
 754        flexible_load_checkpoint: Whether to allow reinitialization of parameters
 755            which could not be found in the provided decoder state. By default, set to 'False'.
 756        final_activation: The activation applied to the network output. By default uses a Sigmoid activation.
 757
 758    Returns:
 759        The UNETR model.
 760    """
 761    device = util.get_device(device)
 762
 763    if decoder_state is None:
 764        use_conv_transpose = False  # By default, we use interpolation for upsampling.
 765    else:
 766        # From the provided pretrained 'decoder_state', we check whether it uses transposed convolutions.
 767        # NOTE: Explanation to the logic below -
 768        # - We do this by looking for parameter names that contain '.block.' within the "decoder.samplers"
 769        #   submodules. This naming convention indicates that transposed convolutions are used,
 770        #   wrapped inside a custom block module.
 771        # - Otherwise '.conv.' appears. It indicates a standard `Conv2d` applied after interpolation for upsampling.
 772        use_conv_transpose = any(".block." in k for k in decoder_state.keys() if k.startswith("decoder.samplers"))
 773
 774    unetr = UNETR(
 775        backbone="sam",
 776        encoder=image_encoder,
 777        out_channels=out_channels,
 778        use_sam_stats=True,
 779        final_activation=final_activation,
 780        use_skip_connection=False,
 781        resize_input=True,
 782        use_conv_transpose=use_conv_transpose,
 783    )
 784
 785    if decoder_state is not None:
 786        unetr_state_dict = unetr.state_dict()
 787        for k, v in unetr_state_dict.items():
 788            if not k.startswith("encoder"):
 789                # Whether allow reinitalization of params, if not found or mismatched.
 790                if flexible_load_checkpoint:
 791                    if k in decoder_state:  # First check whether the key is available in the provided decoder state.
 792                        if v.shape != decoder_state[k].shape:  # Then check if the sizes mismatch.
 793                            warnings.warn(f"Shape of '{k}' did not match. Hence, we reinitialize it.")
 794                            unetr_state_dict[k] = v
 795                        else:
 796                            unetr_state_dict[k] = decoder_state[k]
 797                    else:  # Otherwise, allow it to initialize it.
 798                        warnings.warn(f"Could not find '{k}' in the pretrained state dict. Hence, we reinitialize it.")
 799                        unetr_state_dict[k] = v
 800
 801                else:  # Be strict on finding the parameter in the decoder state.
 802                    if k not in decoder_state:
 803                        raise RuntimeError(f"The parameters for '{k}' could not be found or has a size mismatch.")
 804                    unetr_state_dict[k] = decoder_state[k]
 805
 806        unetr.load_state_dict(unetr_state_dict)
 807
 808    unetr.to(device)
 809    return unetr
 810
 811
 812def get_decoder(
 813    image_encoder: torch.nn.Module,
 814    decoder_state: OrderedDict[str, torch.Tensor],
 815    device: Optional[Union[str, torch.device]] = None,
 816) -> DecoderAdapter:
 817    """Get decoder to predict outputs for automatic instance segmentation
 818
 819    Args:
 820        image_encoder: The image encoder of the SAM model.
 821        decoder_state: State to initialize the weights of the UNETR decoder.
 822        device: The device. By default, automatically chooses the best available device.
 823
 824    Returns:
 825        The decoder for instance segmentation.
 826    """
 827    unetr = get_unetr(image_encoder, decoder_state, device)
 828    return DecoderAdapter(unetr)
 829
 830
 831def get_predictor_and_decoder(
 832    model_type: str,
 833    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
 834    device: Optional[Union[str, torch.device]] = None,
 835    peft_kwargs: Optional[Dict] = None,
 836) -> Tuple[SamPredictor, DecoderAdapter]:
 837    """Load the SAM model (predictor) and instance segmentation decoder.
 838
 839    This requires a checkpoint that contains the state for both predictor
 840    and decoder.
 841
 842    Args:
 843        model_type: The type of the image encoder used in the SAM model.
 844        checkpoint_path: Path to the checkpoint from which to load the data.
 845        device: The device. By default, automatically chooses the best available device.
 846        peft_kwargs: Keyword arguments for the PEFT wrapper class.
 847
 848    Returns:
 849        The SAM predictor.
 850        The decoder for instance segmentation.
 851    """
 852    device = util.get_device(device)
 853    predictor, state = util.get_sam_model(
 854        model_type=model_type,
 855        checkpoint_path=checkpoint_path,
 856        device=device,
 857        return_state=True,
 858        peft_kwargs=peft_kwargs,
 859    )
 860
 861    if "decoder_state" not in state:
 862        raise ValueError(
 863            f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state"
 864        )
 865
 866    decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device)
 867    return predictor, decoder
 868
 869
 870@contextmanager
 871def _array_or_zarr(shape, dtype, chunks, use_zarr=False):
 872    if not use_zarr:
 873        yield np.zeros(shape, dtype=dtype)
 874        return
 875
 876    tmpdir = tempfile.mkdtemp(prefix="tmp-zarr-")
 877    try:
 878        store_path = os.path.join(tmpdir, "tmp.zarr")
 879        root = zarr.open_group(store_path, mode="w")
 880        arr = root.create_dataset(name="data", shape=shape, dtype=dtype, chunks=chunks)
 881        yield arr
 882
 883    finally:
 884        shutil.rmtree(tmpdir, ignore_errors=True)
 885
 886
 887def _watershed_from_center_and_boundary_distances_parallel(
 888    center_distances,
 889    boundary_distances,
 890    foreground_map,
 891    center_distance_threshold,
 892    boundary_distance_threshold,
 893    foreground_threshold,
 894    distance_smoothing,
 895    min_size,
 896    tile_shape,
 897    halo,
 898    n_threads,
 899    verbose=False,
 900    optimize_memory=False,
 901    segmentation=None,
 902):
 903    center_distances = apply_filter(
 904        center_distances, "gaussianSmoothing", sigma=distance_smoothing,
 905        block_shape=tile_shape, n_threads=n_threads
 906    )
 907    boundary_distances = apply_filter(
 908        boundary_distances, "gaussianSmoothing", sigma=distance_smoothing,
 909        block_shape=tile_shape, n_threads=n_threads
 910    )
 911
 912    fg_mask = ThresholdWrapper(foreground_map, foreground_threshold, operator=np.greater)
 913
 914    marker_map = MultiTransformationWrapper(
 915        np.logical_and,
 916        ThresholdWrapper(center_distances, center_distance_threshold, operator=np.less),
 917        ThresholdWrapper(boundary_distances, boundary_distance_threshold, operator=np.less),
 918    )
 919    marker_map = MultiTransformationWrapper(np.logical_and, marker_map, fg_mask)
 920
 921    with _array_or_zarr(marker_map.shape, dtype="uint64", chunks=tile_shape, use_zarr=optimize_memory) as markers:
 922        markers = parallel_impl.label(
 923            marker_map, out=markers, block_shape=tile_shape, n_threads=n_threads, verbose=verbose,
 924        )
 925
 926        if segmentation is None:
 927            segmentation = np.zeros(markers.shape, dtype="uint64")
 928        segmentation = parallel_impl.seeded_watershed(
 929            boundary_distances, seeds=markers, out=segmentation, block_shape=tile_shape,
 930            halo=halo, n_threads=n_threads, verbose=verbose, mask=fg_mask,
 931        )
 932
 933    if min_size > 0:
 934        segmentation = parallel_impl.size_filter(
 935            segmentation, out=segmentation, min_size=min_size,
 936            block_shape=tile_shape, n_threads=n_threads, verbose=verbose
 937        )
 938
 939    return segmentation
 940
 941
 942def _apply_smoothing(foreground, foreground_smoothing, tile_shape, n_threads):
 943    if tile_shape is None:
 944        foreground = filter_impl.gaussian_smoothing(foreground, sigma=foreground_smoothing)
 945    else:
 946        foreground = apply_filter(
 947            foreground, "gaussianSmoothing", sigma=foreground_smoothing,
 948            block_shape=tile_shape, n_threads=n_threads
 949        )
 950    return foreground
 951
 952
 953class InstanceSegmentationWithDecoder:
 954    """Generates an instance segmentation without prompts, using a decoder.
 955
 956    Implements the same interface as `AutomaticMaskGenerator`.
 957
 958    Use this class as follows:
 959    ```python
 960    segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
 961    segmenter.initialize(image)  # Predict the image embeddings and decoder outputs.
 962    masks = segmenter.generate(center_distance_threshold=0.75)  # Generate the instance segmentation.
 963    ```
 964
 965    Args:
 966        predictor: The segment anything predictor.
 967        decoder: The decoder to predict intermediate representations for instance segmentation.
 968    """
 969    def __init__(self, predictor: SamPredictor, decoder: torch.nn.Module) -> None:
 970        self._predictor = predictor
 971        self._decoder = decoder
 972
 973        # The decoder outputs.
 974        self._foreground = None
 975        self._center_distances = None
 976        self._boundary_distances = None
 977
 978        self._is_initialized = False
 979
 980    @property
 981    def is_initialized(self):
 982        """Whether the mask generator has already been initialized.
 983        """
 984        return self._is_initialized
 985
 986    @torch.no_grad()
 987    def initialize(
 988        self,
 989        image: np.ndarray,
 990        image_embeddings: Optional[util.ImageEmbeddings] = None,
 991        i: Optional[int] = None,
 992        verbose: bool = False,
 993        pbar_init: Optional[callable] = None,
 994        pbar_update: Optional[callable] = None,
 995        ndim: int = 2,
 996    ) -> None:
 997        """Initialize image embeddings and decoder predictions for an image.
 998
 999        Args:
1000            image: The input image, volume or timeseries.
1001            image_embeddings: Optional precomputed image embeddings.
1002                See `util.precompute_image_embeddings` for details.
1003            i: Index for the image data. Required if `image` has three spatial dimensions
1004                or a time dimension and two spatial dimensions.
1005            verbose: Whether to be verbose. By default, set to 'False'.
1006            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
1007                Can be used together with pbar_update to handle napari progress bar in other thread.
1008                To enable using this function within a threadworker.
1009            pbar_update: Callback to update an external progress bar.
1010            ndim: The dimensionality of the data. If not given will be deduced from the input data. By default, 2.
1011        """
1012        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
1013        pbar_init(1, "Initialize instance segmentation with decoder")
1014
1015        if image_embeddings is None:
1016            image_embeddings = util.precompute_image_embeddings(
1017                predictor=self._predictor, input_=image, ndim=ndim, verbose=verbose
1018            )
1019
1020        # Get the image embeddings from the predictor.
1021        self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i)
1022        embeddings = self._predictor.features
1023        input_shape = tuple(self._predictor.input_size)
1024        original_shape = tuple(self._predictor.original_size)
1025
1026        # Run prediction with the UNETR decoder.
1027        output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
1028        assert output.shape[0] == 3, f"{output.shape}"
1029        pbar_update(1)
1030        pbar_close()
1031
1032        # Set the state.
1033        self._foreground = output[0]
1034        self._center_distances = output[1]
1035        self._boundary_distances = output[2]
1036        self._i = i
1037        self._is_initialized = True
1038
1039    def _to_masks(self, segmentation, output_mode):
1040        if output_mode != "binary_mask":
1041            raise ValueError(
1042                f"Output mode {output_mode} is not supported. Choose one of 'instance_segmentation', 'binary_masks'"
1043            )
1044
1045        props = regionprops(segmentation)
1046        ndim = segmentation.ndim
1047        assert ndim in (2, 3)
1048
1049        shape = segmentation.shape
1050        if ndim == 2:
1051            crop_box = [0, shape[1], 0, shape[0]]
1052        else:
1053            crop_box = [0, shape[2], 0, shape[1], 0, shape[0]]
1054
1055        # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h]
1056        def to_bbox_2d(bbox):
1057            y0, x0 = bbox[0], bbox[1]
1058            w = bbox[3] - x0
1059            h = bbox[2] - y0
1060            return [x0, w, y0, h]
1061
1062        def to_bbox_3d(bbox):
1063            z0, y0, x0 = bbox[0], bbox[1], bbox[2]
1064            w = bbox[5] - x0
1065            h = bbox[4] - y0
1066            d = bbox[3] - y0
1067            return [x0, w, y0, h, z0, d]
1068
1069        to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d
1070        masks = [
1071            {
1072                "segmentation": segmentation == prop.label,
1073                "area": prop.area,
1074                "bbox": to_bbox(prop.bbox),
1075                "crop_box": crop_box,
1076                "seg_id": prop.label,
1077            } for prop in props
1078        ]
1079        return masks
1080
1081    def generate(
1082        self,
1083        center_distance_threshold: float = 0.5,
1084        boundary_distance_threshold: float = 0.5,
1085        foreground_threshold: float = 0.5,
1086        foreground_smoothing: float = 1.0,
1087        distance_smoothing: float = 1.6,
1088        min_size: int = 0,
1089        output_mode: str = "instance_segmentation",
1090        tile_shape: Optional[Tuple[int, int]] = None,
1091        halo: Optional[Tuple[int, int]] = None,
1092        n_threads: Optional[int] = None,
1093        optimize_memory: bool = False,
1094        segmentation: Optional[np.ndarray] = None,
1095    ) -> Union[List[Dict[str, Any]], np.ndarray]:
1096        """Generate instance segmentation for the currently initialized image.
1097
1098        Args:
1099            center_distance_threshold: Center distance predictions below this value will be
1100                used to find seeds (intersected with thresholded boundary distance predictions).
1101                By default, set to '0.5'.
1102            boundary_distance_threshold: Boundary distance predictions below this value will be
1103                used to find seeds (intersected with thresholded center distance predictions).
1104                By default, set to '0.5'.
1105            foreground_threshold: Foreground predictions above this value will be used as foreground mask.
1106                By default, set to '0.5'.
1107            foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid
1108                checkerboard artifacts in the prediction. By default, set to '1.0'.
1109            distance_smoothing: Sigma value for smoothing the distance predictions.
1110            min_size: Minimal object size in the segmentation result. By default, set to '0'.
1111            output_mode: The form masks are returned in. Possible values are:
1112                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
1113                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
1114                By default, set to 'instance_segmentation'.
1115            tile_shape: Tile shape for parallelizing the instance segmentation post-processing.
1116                This parameter is independent from the tile shape for computing the embeddings.
1117                If not given then post-processing will not be parallelized.
1118            halo: Halo for parallel post-processing. See also `tile_shape`.
1119            n_threads: Number of threads for parallel post-processing. See also `tile_shape`.
1120            optimize_memory: Whether to optimize the memory consumption by allocating intermediate files.
1121            segmentation: Optional pre-allocated segmentation.
1122
1123        Returns:
1124            The segmentation masks.
1125        """
1126        if not self.is_initialized:
1127            raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.")
1128
1129        if foreground_smoothing > 0:
1130            foreground = _apply_smoothing(self._foreground, foreground_smoothing, tile_shape, n_threads)
1131        else:
1132            foreground = self._foreground
1133
1134        if tile_shape is None:
1135            segmentation = watershed_from_center_and_boundary_distances(
1136                center_distances=self._center_distances,
1137                boundary_distances=self._boundary_distances,
1138                foreground_map=foreground,
1139                center_distance_threshold=center_distance_threshold,
1140                boundary_distance_threshold=boundary_distance_threshold,
1141                foreground_threshold=foreground_threshold,
1142                distance_smoothing=distance_smoothing,
1143                min_size=min_size,
1144            )
1145        else:
1146            if halo is None:
1147                raise ValueError("You must pass a value for halo if tile_shape is given.")
1148
1149            # Shards are not thread-safe for parallel writing! So if we have shards we have to use them for tiling.
1150            # This is ok in terms efficiency as GPU tiles are small; shards should still be manegable for the watershed.
1151            if isinstance(segmentation, zarr.Array) and getattr(segmentation, "shards", None) is not None:
1152                tile_shape = segmentation.shards
1153
1154            segmentation = _watershed_from_center_and_boundary_distances_parallel(
1155                center_distances=self._center_distances,
1156                boundary_distances=self._boundary_distances,
1157                foreground_map=foreground,
1158                center_distance_threshold=center_distance_threshold,
1159                boundary_distance_threshold=boundary_distance_threshold,
1160                foreground_threshold=foreground_threshold,
1161                distance_smoothing=distance_smoothing,
1162                min_size=min_size,
1163                tile_shape=tile_shape,
1164                halo=halo,
1165                n_threads=n_threads,
1166                verbose=False,
1167                optimize_memory=optimize_memory,
1168                segmentation=segmentation,
1169            )
1170
1171        if output_mode != "instance_segmentation":
1172            segmentation = self._to_masks(segmentation, output_mode)
1173        return segmentation
1174
1175    def get_state(self) -> Dict[str, Any]:
1176        """Get the initialized state of the instance segmenter.
1177
1178        Returns:
1179            Instance segmentation state.
1180        """
1181        if not self.is_initialized:
1182            raise RuntimeError("The state has not been computed yet. Call initialize first.")
1183
1184        return {
1185            "foreground": self._foreground,
1186            "center_distances": self._center_distances,
1187            "boundary_distances": self._boundary_distances,
1188        }
1189
1190    def set_state(self, state: Dict[str, Any]) -> None:
1191        """Set the state of the instance segmenter.
1192
1193        Args:
1194            state: The instance segmentation state
1195        """
1196        self._foreground = state["foreground"]
1197        self._center_distances = state["center_distances"]
1198        self._boundary_distances = state["boundary_distances"]
1199        self._is_initialized = True
1200
1201    def clear_state(self):
1202        """Clear the state of the instance segmenter.
1203        """
1204        self._foreground = None
1205        self._center_distances = None
1206        self._boundary_distances = None
1207        self._is_initialized = False
1208
1209
class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder):
    """Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings.
    """

    # Apply the decoder in a batched fashion, and then perform the resizing independently per output.
    # This is necessary, because the individual tiles may have different tile shapes due to border tiles.
    def _predict_decoder(self, batched_embeddings, input_shapes, original_shapes):
        batched_embeddings = torch.cat(batched_embeddings)
        output = self._decoder._forward_impl(batched_embeddings)

        batched_output = []
        for x, input_shape, original_shape in zip(output, input_shapes, original_shapes):
            # The post-processing is applied per tile, so the batch axis is added and removed again.
            x = self._decoder.postprocess_masks(x.unsqueeze(0), input_shape, original_shape).squeeze(0)
            batched_output.append(x.cpu().numpy())
        return batched_output

    @torch.no_grad()
    def initialize(
        self,
        image: np.ndarray,
        image_embeddings: Optional[util.ImageEmbeddings] = None,
        i: Optional[int] = None,
        tile_shape: Optional[Tuple[int, int]] = None,
        halo: Optional[Tuple[int, int]] = None,
        verbose: bool = False,
        pbar_init: Optional[callable] = None,
        pbar_update: Optional[callable] = None,
        batch_size: int = 1,
        mask: Optional[np.typing.ArrayLike] = None,
    ) -> None:
        """Initialize image embeddings and decoder predictions for an image.

        Args:
            image: The input image, volume or timeseries.
            image_embeddings: Optional precomputed image embeddings.
                See `util.precompute_image_embeddings` for details.
            i: Index for the image data. Required if `image` has three spatial dimensions
                or a time dimension and two spatial dimensions.
            tile_shape: Shape of the tiles for precomputing image embeddings.
            halo: Overlap of the tiles for tiled precomputation of image embeddings.
            verbose: Whether to be verbose. This is passed on to the processing of the tiled embeddings
                and to the progress bar handling. By default, set to 'False'.
            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
                Can be used together with pbar_update to handle napari progress bar in other thread.
                To enable using this function within a threadworker.
            pbar_update: Callback to update an external progress bar.
            batch_size: The batch size for image embedding computation and segmentation decoder prediction.
            mask: An optional mask to define areas that are ignored in the segmentation.
        """
        original_size = image.shape[:2]
        self._image_embeddings, tile_shape, halo, tiles_in_mask = _process_tiled_embeddings(
            self._predictor, image, image_embeddings, tile_shape, halo,
            verbose=verbose, batch_size=batch_size, mask=mask, i=i,
        )
        tiling = Blocking([0, 0], original_size, tile_shape)

        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)

        # The predictions for the full image. Areas that are not covered by any processed tile stay zero.
        foreground = np.zeros(original_size, dtype="float32")
        center_distances = np.zeros(original_size, dtype="float32")
        boundary_distances = np.zeros(original_size, dtype="float32")

        msg = "Initialize tiled instance segmentation with decoder"
        if tiles_in_mask is None:
            n_tiles = tiling.number_of_blocks
            all_tile_ids = list(range(n_tiles))
        else:
            n_tiles = len(tiles_in_mask)
            all_tile_ids = tiles_in_mask
            msg += " and mask"

        pbar_init(n_tiles, msg)
        if n_tiles == 0:
            # 'np.array_split' raises an error for zero sections, so the case
            # that there are no tiles to process is handled explicitly.
            tile_ids_for_batches = []
        else:
            n_batches = int(np.ceil(n_tiles / batch_size))
            tile_ids_for_batches = np.array_split(all_tile_ids, n_batches)

        for tile_ids in tile_ids_for_batches:
            batched_embeddings, input_shapes, original_shapes = [], [], []
            for tile_id in tile_ids:
                # Get the image embeddings from the predictor for this tile.
                self._predictor = util.set_precomputed(self._predictor, self._image_embeddings, i=i, tile_id=tile_id)

                batched_embeddings.append(self._predictor.features)
                input_shapes.append(tuple(self._predictor.input_size))
                original_shapes.append(tuple(self._predictor.original_size))

            batched_output = self._predict_decoder(batched_embeddings, input_shapes, original_shapes)

            for output_id, tile_id in enumerate(tile_ids):
                output = batched_output[output_id]
                assert output.shape[0] == 3

                # Set the predictions in the output for this tile.
                # Only the inner block (the tile without its halo) is written to the full image.
                block = tiling.get_block_with_halo(tile_id, halo=list(halo))
                local_bb = tuple(
                    slice(beg, end) for beg, end in zip(block.inner_block_local.begin, block.inner_block_local.end)
                )
                inner_bb = tuple(slice(beg, end) for beg, end in zip(block.inner_block.begin, block.inner_block.end))

                foreground[inner_bb] = output[0][local_bb]
                center_distances[inner_bb] = output[1][local_bb]
                boundary_distances[inner_bb] = output[2][local_bb]
                pbar_update(1)

        pbar_close()

        # Set the state.
        self._i = i
        self._foreground = foreground
        self._center_distances = center_distances
        self._boundary_distances = boundary_distances
        self._is_initialized = True
1320
1321
def _get_centers(segmentation, avoid_image_border=True):
    """Compute one center coordinate per object of a 2d segmentation.

    The center of an object is the position within the object with the largest value
    of the distance transform that is computed from the object boundaries.

    Returns:
        An array with one coordinate per object, in the axis order of the segmentation.
    """
    # Not sure if 'find_boundaries' has to be parallelized.
    # The input for the distance transform is False on the object boundaries and True everywhere else.
    non_boundary = find_boundaries(segmentation, mode="outer") == 0

    # Mark the image border like a boundary, in order to avoid centers on the border.
    if avoid_image_border:
        for border in (np.s_[0, :], np.s_[:, 0], np.s_[-1, :], np.s_[:, -1]):
            non_boundary[border] = False

    # Compute the distance transform to the object boundaries.
    distances = parallel_impl.distance_transform(non_boundary, halo=(16, 16), block_shape=(512, 512))

    # Find the coordinate of the maximal distance for each object.
    centers = []
    for prop in regionprops(segmentation):
        bbox = prop.bbox
        roi = (slice(bbox[0], bbox[2]), slice(bbox[1], bbox[3]))

        # Restrict the distances to the bounding box of the object and zero them outside of the object.
        object_mask = segmentation[roi] == prop.label
        object_distances = distances[roi].copy()
        object_distances[~object_mask] = 0

        # The position of the maximum, first within the bounding box and then within the full image.
        local_center = np.unravel_index(np.argmax(object_distances), object_distances.shape)
        centers.append(tuple(coord + axis_slice.start for coord, axis_slice in zip(local_center, roi)))

    return np.array(centers)
1356
1357
def _derive_point_prompts(
    foreground: np.ndarray,
    center_distances: np.ndarray,
    boundary_distances: np.ndarray,
    foreground_threshold: float = 0.5,
    center_distance_threshold: float = 0.5,
    boundary_distance_threshold: float = 0.5,
):
    """Derive point prompts from the decoder predictions.

    The predictions are thresholded and combined, the result is labeled via connected components,
    and the center of each component is used as a positive point prompt.

    Returns:
        A dictionary with the entries 'points' and 'point_labels',
        or None if no component was found.
    """
    # Candidate pixels have both distances below their threshold ...
    seed_mask = np.logical_and(
        center_distances < center_distance_threshold, boundary_distances < boundary_distance_threshold,
    )
    # ... and are not part of the background.
    seed_mask[foreground < foreground_threshold] = 0

    components = np.zeros_like(seed_mask, dtype="uint32")
    components = parallel_impl.label(seed_mask, out=components, block_shape=(512, 512))

    centers = _get_centers(components)
    if len(centers) == 0:
        return None

    # Reverse the coordinate order and add an axis, so that there is a single point per prompt.
    return {
        "points": centers[:, None, ::-1],
        "point_labels": np.ones((len(centers), 1)),
    }
1380
1381
1382def _derive_box_prompts(predictions, box_extension):
1383    shape = predictions[0]["segmentation"].shape
1384    bboxes = [pred["bbox"] for pred in predictions]
1385    prompts = [[
1386        max(x - w * box_extension, 0),
1387        max(y - h * box_extension, 0),
1388        min(x + (1 + box_extension) * w, shape[0]),
1389        min(y + (1 + box_extension) * h, shape[1]),
1390    ] for (x, y, w, h) in bboxes]
1391    return {"boxes": np.array(prompts)}
1392
1393
class AutomaticPromptGenerator(InstanceSegmentationWithDecoder):
    """Generates an instance segmentation automatically, using prompts that are derived from a decoder.

    This class is used in the same way as `InstanceSegmentationWithDecoder` and `AutomaticMaskGenerator`.

    Args:
        predictor: The segment anything predictor.
        decoder: The decoder whose predictions are used to derive the prompts
            for automatic instance segmentation.
    """
    def generate(
        self,
        min_size: int = 25,
        center_distance_threshold: float = 0.5,
        boundary_distance_threshold: float = 0.5,
        foreground_threshold: float = 0.5,
        multimasking: bool = False,
        batch_size: int = 32,
        nms_threshold: float = 0.9,
        intersection_over_min: bool = False,
        output_mode: str = "instance_segmentation",
        mask_threshold: Optional[Union[float, str]] = None,
        refine_with_box_prompts: bool = False,
        prompt_function: Optional[callable] = None,
    ) -> Union[List[Dict[str, Any]], np.ndarray]:
        """Generate the instance segmentation for the currently initialized image.

        Prompts are derived from the foreground and distance predictions of the segmentation
        decoder: the predictions are thresholded and intersected, connected components are
        computed, and the component centers are used as point prompts. The masks predicted
        for these prompts are filtered via NMS and merged into a segmentation.

        Args:
            min_size: Minimal object size in the segmentation result. By default, set to '25'.
            center_distance_threshold: The threshold for the center distance predictions.
            boundary_distance_threshold: The threshold for the boundary distance predictions.
            foreground_threshold: The threshold for the foreground predictions.
            multimasking: Whether to use multi-mask prediction for turning the prompts into masks.
            batch_size: The batch size for parallelizing the prediction based on prompts.
            nms_threshold: The threshold for non-maximum suppression (NMS).
            intersection_over_min: Whether to use the minimum area of the two objects or the
                intersection over union area (default) in NMS.
            output_mode: The form masks are returned in. Possible values are:
                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
                By default, set to 'instance_segmentation'.
            mask_threshold: The threshold for turning logits into masks in `micro_sam.inference.batched_inference`.
            refine_with_box_prompts: Whether to refine the mask outputs with another round of box prompts
                derived from the segmentations after point prompts.
            prompt_function: A custom function for deriving prompts from the segmentation decoder predictions.
                If given, the default prompt derivation procedure is not used. Must have the following signature:
                ```
                    def prompt_function(foreground, center_distances, boundary_distances, **kwargs)
                ```
                where `foreground`, `center_distances`, and `boundary_distances` are the respective
                predictions from the segmentation decoder. It must return a dictionary containing
                either point, box, or mask prompts in a format compatible with `micro_sam.inference.batched_inference`.
                It may return None to signal that no prompts could be derived.

        Returns:
            The instance segmentation masks.

        Raises:
            RuntimeError: If the generator has not been initialized.
        """
        if not self.is_initialized:
            raise RuntimeError("AutomaticPromptGenerator has not been initialized. Call initialize first.")

        # 1.) Derive prompts from the decoder predictions.
        derive_prompts = prompt_function
        if derive_prompts is None:
            derive_prompts = _derive_point_prompts
        foreground = self._foreground
        prompts = derive_prompts(
            foreground=foreground,
            center_distances=self._center_distances,
            boundary_distances=self._boundary_distances,
            foreground_threshold=foreground_threshold,
            center_distance_threshold=center_distance_threshold,
            boundary_distance_threshold=boundary_distance_threshold,
        )

        # Without prompts there is nothing to segment, so an empty result is returned.
        if prompts is None:
            if output_mode == "instance_segmentation":
                return np.zeros(foreground.shape, dtype="uint32")
            return []

        # The settings that are shared by both rounds of prompt-based inference.
        inference_settings = dict(
            image=None,
            batch_size=batch_size,
            return_instance_segmentation=False,
            multimasking=multimasking,
            mask_threshold=mask_threshold,
            i=getattr(self, "_i", None),
        )

        # 2.) Apply the predictor to the prompts.
        predictions = batched_inference(self._predictor, **inference_settings, **prompts)

        # 3.) Refine the segmentation with box prompts.
        if refine_with_box_prompts:
            box_extension = 0.01  # expose as hyperparam?
            box_prompts = _derive_box_prompts(predictions, box_extension)
            predictions = batched_inference(self._predictor, **inference_settings, **box_prompts)

        # 4.) Apply non-max suppression to the masks.
        result = util.apply_nms(
            predictions, min_size=min_size, nms_thresh=nms_threshold, intersection_over_min=intersection_over_min
        )
        if output_mode == "instance_segmentation":
            return result
        return self._to_masks(result, output_mode)
1506
1507
class TiledAutomaticPromptGenerator(TiledInstanceSegmentationWithDecoder):
    """Same as `AutomaticPromptGenerator` but for tiled image embeddings.
    """
    def generate(
        self,
        min_size: int = 25,
        center_distance_threshold: float = 0.5,
        boundary_distance_threshold: float = 0.5,
        foreground_threshold: float = 0.5,
        multimasking: bool = False,
        batch_size: int = 32,
        nms_threshold: float = 0.9,
        intersection_over_min: bool = False,
        output_mode: str = "instance_segmentation",
        mask_threshold: Optional[Union[float, str]] = None,
        refine_with_box_prompts: bool = False,
        prompt_function: Optional[callable] = None,
        optimize_memory: bool = False,
    ) -> Union[List[Dict[str, Any]], np.ndarray]:
        """Generate tiling-based instance segmentation for the currently initialized image.

        Args:
            min_size: Minimal object size in the segmentation result. By default, set to '25'.
            center_distance_threshold: The threshold for the center distance predictions.
            boundary_distance_threshold: The threshold for the boundary distance predictions.
            foreground_threshold: The threshold for the foreground predictions.
            multimasking: Whether to use multi-mask prediction for turning the prompts into masks.
            batch_size: The batch size for parallelizing the prediction based on prompts.
            nms_threshold: The threshold for non-maximum suppression (NMS).
            intersection_over_min: Whether to use the minimum area of the two objects or the
                intersection over union area (default) in NMS.
            output_mode: The form masks are returned in. Possible values are:
                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
                By default, set to 'instance_segmentation'.
            mask_threshold: The threshold for turning logits into masks.
                NOTE(review): this value is accepted for signature compatibility with
                `AutomaticPromptGenerator.generate`, but it is currently not forwarded to
                `batched_tiled_inference` by this method - confirm whether that is intended.
            refine_with_box_prompts: Whether to refine the mask outputs with another round of box prompts
                derived from the segmentations after point prompts. Currently not supported for tiled segmentation.
            prompt_function: A custom function for deriving prompts from the segmentation decoder predictions.
                If given, the default prompt derivation procedure is not used. Must have the following signature:
                ```
                    def prompt_function(foreground, center_distances, boundary_distances, **kwargs)
                ```
                where `foreground`, `center_distances`, and `boundary_distances` are the respective
                predictions from the segmentation decoder. It must return a dictionary containing
                either point, box, or mask prompts in a format compatible with `micro_sam.inference.batched_inference`.
                It may return None to signal that no prompts could be derived.
            optimize_memory: Whether to optimize the memory consumption by merging the per-slice
                segmentation results immediately with NMS, rather than running a NMS for all results.
                This may lead to a slightly different segmentation result and is only compatible with
                `refine_with_box_prompts=False` and `output_mode="instance_segmentation"`.

        Returns:
            The instance segmentation masks.

        Raises:
            RuntimeError: If the generator has not been initialized.
            ValueError: If `optimize_memory` is combined with an incompatible setting.
            NotImplementedError: If `refine_with_box_prompts` is requested and prompts were derived.
        """
        if not self.is_initialized:
            raise RuntimeError("TiledAutomaticPromptGenerator has not been initialized. Call initialize first.")
        if optimize_memory and (output_mode != "instance_segmentation" or refine_with_box_prompts):
            raise ValueError(
                "Invalid settings: optimize_memory=True is only compatible with "
                "output_mode='instance_segmentation' and refine_with_box_prompts=False."
            )
        foreground, center_distances, boundary_distances =\
            self._foreground, self._center_distances, self._boundary_distances

        # 1.) Derive prompts from the decoder predictions.
        prompt_function = _derive_point_prompts if prompt_function is None else prompt_function
        prompts = prompt_function(
            foreground,
            center_distances,
            boundary_distances,
            foreground_threshold=foreground_threshold,
            center_distance_threshold=center_distance_threshold,
            boundary_distance_threshold=boundary_distance_threshold,
        )

        # 2.) Apply the predictor to the prompts.
        shape = foreground.shape
        if prompts is None:  # No prompts were derived, we can't do much further and return empty masks.
            return np.zeros(shape, dtype="uint32") if output_mode == "instance_segmentation" else []

        # Box refinement is not implemented for tiled segmentation yet. Fail here, before the
        # expensive tiled inference is run, instead of discarding its result afterwards.
        if refine_with_box_prompts:
            # TODO
            raise NotImplementedError(
                "refine_with_box_prompts is not supported for tiled segmentation yet."
            )

        # Copy the prompts so that the dictionary returned by a custom prompt function is not modified.
        inference_kwargs = dict(prompts)
        if optimize_memory:
            # With memory optimization the NMS settings are handed to the tiled inference itself.
            inference_kwargs.update(
                min_size=min_size, nms_thresh=nms_threshold, intersection_over_min=intersection_over_min
            )
        predictions = batched_tiled_inference(
            self._predictor,
            image=None,
            batch_size=batch_size,
            image_embeddings=self._image_embeddings,
            return_instance_segmentation=False,
            multimasking=multimasking,
            optimize_memory=optimize_memory,
            i=getattr(self, "_i", None),
            **inference_kwargs
        )
        # Optimize memory directly returns an instance segmentation and does not
        # allow for any further refinements.
        if optimize_memory:
            return predictions

        # 3.) Apply non-max suppression to the masks.
        segmentation = util.apply_nms(
            predictions, shape=shape, min_size=min_size, nms_thresh=nms_threshold,
            intersection_over_min=intersection_over_min,
        )
        if output_mode != "instance_segmentation":
            segmentation = self._to_masks(segmentation, output_mode)
        return segmentation

    # Set state and get state are not implemented yet, as this generator relies on having the image embeddings
    # in the state. However, they should not be serialized here and we have to address this a bit differently.
    def get_state(self):
        """@private
        """
        raise NotImplementedError

    def set_state(self, state):
        """@private
        """
        raise NotImplementedError
1629
1630
def get_instance_segmentation_generator(
    predictor: SamPredictor,
    is_tiled: bool,
    decoder: Optional[torch.nn.Module] = None,
    segmentation_mode: Optional[Literal["amg", "ais", "apg"]] = None,
    **kwargs,
) -> Union[AMGBase, InstanceSegmentationWithDecoder]:
    """Get the automatic mask generator.

    Args:
        predictor: The segment anything predictor.
        is_tiled: Whether tiled embeddings are used.
        decoder: Decoder to predict instance segmentation.
        segmentation_mode: The segmentation mode. One of 'amg', 'ais', or 'apg' (case-insensitive).
            By default, the value of `DEFAULT_SEGMENTATION_MODE_WITH_DECODER` is used if a decoder
            is passed, otherwise 'amg' is used.
        kwargs: The keyword arguments of the segmentation generator class.

    Returns:
        The segmentation generator instance.

    Raises:
        ValueError: If `segmentation_mode` is not one of the supported modes.
        AssertionError: If the mode 'ais' or 'apg' is selected without passing a decoder.
    """
    # NOTE: The docstring above is intentionally a plain string literal. A formatted string
    # literal is not a docstring, so it would leave `__doc__` empty.

    # Choose the segmentation decoder default depending on whether we have a decoder.
    if segmentation_mode is None:
        segmentation_mode = "amg" if decoder is None else DEFAULT_SEGMENTATION_MODE_WITH_DECODER

    # Normalize the mode once, so that all comparisons below are case-insensitive.
    mode = segmentation_mode.lower()
    if mode == "amg":
        segmenter_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator
        segmenter = segmenter_class(predictor, **kwargs)
    elif mode == "ais":
        assert decoder is not None, "The segmentation mode 'ais' requires a decoder."
        segmenter_class = TiledInstanceSegmentationWithDecoder if is_tiled else InstanceSegmentationWithDecoder
        segmenter = segmenter_class(predictor, decoder, **kwargs)
    elif mode == "apg":
        assert decoder is not None, "The segmentation mode 'apg' requires a decoder."
        segmenter_class = TiledAutomaticPromptGenerator if is_tiled else AutomaticPromptGenerator
        segmenter = segmenter_class(predictor, decoder, **kwargs)
    else:
        raise ValueError(f"Invalid segmentation_mode: {segmentation_mode}. Choose one of 'amg', 'ais', or 'apg'.")

    return segmenter
DEFAULT_SEGMENTATION_MODE_WITH_DECODER = 'ais'
class AMGBase(abc.ABC):
 66class AMGBase(ABC):
 67    """Base class for the automatic mask generators.
 68    """
 69    def __init__(self):
 70        # the state that has to be computed by the 'initialize' method of the child classes
 71        self._is_initialized = False
 72        self._crop_list = None
 73        self._crop_boxes = None
 74        self._original_size = None
 75
 76    @property
 77    def is_initialized(self):
 78        """Whether the mask generator has already been initialized.
 79        """
 80        return self._is_initialized
 81
 82    @property
 83    def crop_list(self):
 84        """The list of mask data after initialization.
 85        """
 86        return self._crop_list
 87
 88    @property
 89    def crop_boxes(self):
 90        """The list of crop boxes.
 91        """
 92        return self._crop_boxes
 93
 94    @property
 95    def original_size(self):
 96        """The original image size.
 97        """
 98        return self._original_size
 99
100    def _postprocess_batch(
101        self,
102        data,
103        crop_box,
104        original_size,
105        pred_iou_thresh,
106        stability_score_thresh,
107        box_nms_thresh,
108    ):
109        orig_h, orig_w = original_size
110
111        # filter by predicted IoU
112        if pred_iou_thresh > 0.0:
113            keep_mask = data["iou_preds"] > pred_iou_thresh
114            data.filter(keep_mask)
115
116        # filter by stability score
117        if stability_score_thresh > 0.0:
118            keep_mask = data["stability_score"] >= stability_score_thresh
119            data.filter(keep_mask)
120
121        # filter boxes that touch crop boundaries
122        keep_mask = ~amg_utils.is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
123        if not torch.all(keep_mask):
124            data.filter(keep_mask)
125
126        # remove duplicates within this crop.
127        keep_by_nms = batched_nms(
128            data["boxes"].float(),
129            data["iou_preds"],
130            torch.zeros_like(data["boxes"][:, 0]),  # categories
131            iou_threshold=box_nms_thresh,
132        )
133        data.filter(keep_by_nms)
134
135        # return to the original image frame
136        data["boxes"] = amg_utils.uncrop_boxes_xyxy(data["boxes"], crop_box)
137        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
138        # the data from embedding based segmentation doesn't have the points
139        # so we skip if the corresponding key can't be found
140        try:
141            data["points"] = amg_utils.uncrop_points(data["points"], crop_box)
142        except KeyError:
143            pass
144
145        return data
146
147    def _postprocess_small_regions(self, mask_data, min_area, nms_thresh):
148
149        if len(mask_data["rles"]) == 0:
150            return mask_data
151
152        # filter small disconnected regions and holes
153        new_masks = []
154        scores = []
155        for rle in mask_data["rles"]:
156            mask = amg_utils.rle_to_mask(rle)
157
158            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="holes")
159            unchanged = not changed
160            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="islands")
161            unchanged = unchanged and not changed
162
163            new_masks.append(torch.as_tensor(mask, dtype=torch.int).unsqueeze(0))
164            # give score=0 to changed masks and score=1 to unchanged masks
165            # so NMS will prefer ones that didn't need postprocessing
166            scores.append(float(unchanged))
167
168        # recalculate boxes and remove any new duplicates
169        masks = torch.cat(new_masks, dim=0)
170        boxes = batched_mask_to_box(masks.to(torch.bool))  # Casting this to boolean as we work with one-hot labels.
171        keep_by_nms = batched_nms(
172            boxes.float(),
173            torch.as_tensor(scores, dtype=torch.float),
174            torch.zeros_like(boxes[:, 0]),  # categories
175            iou_threshold=nms_thresh,
176        )
177
178        # only recalculate RLEs for masks that have changed
179        for i_mask in keep_by_nms:
180            if scores[i_mask] == 0.0:
181                mask_torch = masks[i_mask].unsqueeze(0)
182                # mask_data["rles"][i_mask] = amg_utils.mask_to_rle_pytorch(mask_torch)[0]
183                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
184                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
185        mask_data.filter(keep_by_nms)
186
187        return mask_data
188
189    def _postprocess_masks(self, mask_data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode):
190        # filter small disconnected regions and holes in masks
191        if min_mask_region_area > 0:
192            mask_data = self._postprocess_small_regions(
193                mask_data,
194                min_mask_region_area,
195                max(box_nms_thresh, crop_nms_thresh),
196            )
197
198        # encode masks
199        if output_mode == "coco_rle":
200            mask_data["segmentations"] = [amg_utils.coco_encode_rle(rle) for rle in mask_data["rles"]]
201        # We require the binary mask output for creating the instance-seg output.
202        elif output_mode in ("binary_mask", "instance_segmentation"):
203            mask_data["segmentations"] = [amg_utils.rle_to_mask(rle) for rle in mask_data["rles"]]
204        elif output_mode == "rle":
205            mask_data["segmentations"] = mask_data["rles"]
206        else:
207            raise ValueError(f"Invalid output mode {output_mode}.")
208
209        # write mask records
210        curr_anns = []
211        for idx in range(len(mask_data["segmentations"])):
212            ann = {
213                "segmentation": mask_data["segmentations"][idx],
214                "area": amg_utils.area_from_rle(mask_data["rles"][idx]),
215                "bbox": amg_utils.box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
216                "predicted_iou": mask_data["iou_preds"][idx].item(),
217                "stability_score": mask_data["stability_score"][idx].item(),
218                "crop_box": amg_utils.box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
219            }
220            # the data from embedding based segmentation doesn't have the points
221            # so we skip if the corresponding key can't be found
222            try:
223                ann["point_coords"] = [mask_data["points"][idx].tolist()]
224            except KeyError:
225                pass
226            curr_anns.append(ann)
227
228        return curr_anns
229
230    def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None):
231        orig_h, orig_w = original_size
232
233        # serialize predictions and store in MaskData
234        data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1))
235        if points is not None:
236            data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0), dtype=torch.float)
237
238        del masks
239
240        # calculate the stability scores
241        data["stability_score"] = amg_utils.calculate_stability_score(
242            data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset
243        )
244
245        # threshold masks and calculate boxes
246        data["masks"] = data["masks"] > self._predictor.model.mask_threshold
247        data["masks"] = data["masks"].type(torch.bool)
248        data["boxes"] = batched_mask_to_box(data["masks"])
249
250        # compress to RLE
251        data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
252        # data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"])
253        data["rles"] = mask_to_rle_pytorch(data["masks"])
254        del data["masks"]
255
256        return data
257
258    def get_state(self) -> Dict[str, Any]:
259        """Get the initialized state of the mask generator.
260
261        Returns:
262            State of the mask generator.
263        """
264        if not self.is_initialized:
265            raise RuntimeError("The state has not been computed yet. Call initialize first.")
266
267        return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}
268
269    def set_state(self, state: Dict[str, Any]) -> None:
270        """Set the state of the mask generator.
271
272        Args:
273            state: The state of the mask generator, e.g. from serialized state.
274        """
275        self._crop_list = state["crop_list"]
276        self._crop_boxes = state["crop_boxes"]
277        self._original_size = state["original_size"]
278        self._is_initialized = True
279
280    def clear_state(self):
281        """Clear the state of the mask generator.
282        """
283        self._crop_list = None
284        self._crop_boxes = None
285        self._original_size = None
286        self._is_initialized = False

Base class for the automatic mask generators.

is_initialized
76    @property
77    def is_initialized(self):
78        """Whether the mask generator has already been initialized.
79        """
80        return self._is_initialized

Whether the mask generator has already been initialized.

crop_list
82    @property
83    def crop_list(self):
84        """The list of mask data after initialization.
85        """
86        return self._crop_list

The list of mask data after initialization.

crop_boxes
88    @property
89    def crop_boxes(self):
90        """The list of crop boxes.
91        """
92        return self._crop_boxes

The list of crop boxes.

original_size
94    @property
95    def original_size(self):
96        """The original image size.
97        """
98        return self._original_size

The original image size.

def get_state(self) -> Dict[str, Any]:
258    def get_state(self) -> Dict[str, Any]:
259        """Get the initialized state of the mask generator.
260
261        Returns:
262            State of the mask generator.
263        """
264        if not self.is_initialized:
265            raise RuntimeError("The state has not been computed yet. Call initialize first.")
266
267        return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}

Get the initialized state of the mask generator.

Returns:

State of the mask generator.

def set_state(self, state: Dict[str, Any]) -> None:
269    def set_state(self, state: Dict[str, Any]) -> None:
270        """Set the state of the mask generator.
271
272        Args:
273            state: The state of the mask generator, e.g. from serialized state.
274        """
275        self._crop_list = state["crop_list"]
276        self._crop_boxes = state["crop_boxes"]
277        self._original_size = state["original_size"]
278        self._is_initialized = True

Set the state of the mask generator.

Arguments:
  • state: The state of the mask generator, e.g. from serialized state.
def clear_state(self):
280    def clear_state(self):
281        """Clear the state of the mask generator.
282        """
283        self._crop_list = None
284        self._crop_boxes = None
285        self._original_size = None
286        self._is_initialized = False

Clear the state of the mask generator.

class AutomaticMaskGenerator(AMGBase):
class AutomaticMaskGenerator(AMGBase):
    """Generates an instance segmentation without prompts, using a point grid.

    This class implements the same logic as
    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
    It decouples the computationally expensive steps of generating masks from the cheap post-processing operation
    to filter these masks to enable grid search and interactively changing the post-processing.

    Use this class as follows:
    ```python
    amg = AutomaticMaskGenerator(predictor)
    amg.initialize(image)  # Initialize the masks, this takes care of all expensive computations.
    masks = amg.generate(pred_iou_thresh=0.8)  # Generate the masks. This is fast and enables testing parameters
    ```

    Args:
        predictor: The segment anything predictor.
        points_per_side: The number of points to be sampled along one side of the image.
            If None, `point_grids` must provide explicit point sampling. By default, set to '32'.
        points_per_batch: The number of points run simultaneously by the model.
            Higher numbers may be faster but use more GPU memory.
            By default, selects '64' for all devices except 'mps' (selects '16' for performance reasons).
        crop_n_layers: If >0, the mask prediction will be run again on crops of the image.
            By default, set to '0'.
        crop_overlap_ratio: Sets the degree to which crops overlap. By default, set to '512 / 1500'.
        crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops.
            By default, set to '1'.
        point_grids: A list over explicit grids of points used for sampling masks.
            Normalized to [0, 1] with respect to the image coordinate system.
            Only used if `points_per_side` is None; otherwise it is ignored and a warning is issued.
        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
            By default, set to '1.0'.

    Raises:
        ValueError: If both `points_per_side` and `point_grids` are None.
    """
    def __init__(
        self,
        predictor: SamPredictor,
        points_per_side: Optional[int] = 32,
        points_per_batch: Optional[int] = None,
        crop_n_layers: int = 0,
        crop_overlap_ratio: float = 512 / 1500,
        crop_n_points_downscale_factor: int = 1,
        point_grids: Optional[List[np.ndarray]] = None,
        stability_score_offset: float = 1.0,
    ):
        super().__init__()

        if points_per_side is not None:
            # 'points_per_side' takes precedence. Because it defaults to 32, explicit point grids
            # would otherwise be dropped without any notice, so we tell the user about it.
            if point_grids is not None:
                warnings.warn(
                    "Both 'points_per_side' and 'point_grids' were passed. 'point_grids' is ignored; "
                    "pass 'points_per_side=None' to use explicit point grids.",
                    stacklevel=2,
                )
            self.point_grids = amg_utils.build_all_layer_point_grids(
                points_per_side, crop_n_layers, crop_n_points_downscale_factor,
            )
        elif point_grids is not None:
            self.point_grids = point_grids
        else:
            # This branch is only reached if both arguments are None.
            raise ValueError("Can't have both points_per_side and point_grid be None or not None.")

        self._predictor = predictor
        self._points_per_side = points_per_side

        # we set the points per batch to 16 for mps for performance reasons
        # and otherwise keep them at the default of 64
        if points_per_batch is None:
            points_per_batch = 16 if str(predictor.device) == "mps" else 64
        self._points_per_batch = points_per_batch

        self._crop_n_layers = crop_n_layers
        self._crop_overlap_ratio = crop_overlap_ratio
        self._crop_n_points_downscale_factor = crop_n_points_downscale_factor
        self._stability_score_offset = stability_score_offset

    def _process_batch(self, points, im_size, crop_box, original_size):
        """Run the predictor for one batch of point prompts and convert the result to mask data."""
        # run model on this batch
        transformed_points = self._predictor.transform.apply_coords(points, im_size)
        in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float)
        # Every point is passed with label 1.
        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
        masks, iou_preds, _ = self._predictor.predict_torch(
            point_coords=in_points[:, None, :],
            point_labels=in_labels[:, None],
            multimask_output=True,
            return_logits=True,
        )
        data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points)
        del masks
        return data

    def _process_crop(self, image, crop_box, crop_layer_idx, precomputed_embeddings, pbar_init=None, pbar_update=None):
        """Predict the mask data for all grid points of a single crop of the image."""
        # Crop the image and calculate embeddings.
        x0, y0, x1, y1 = crop_box
        cropped_im = image[y0:y1, x0:x1, :]
        cropped_im_size = cropped_im.shape[:2]

        if not precomputed_embeddings:
            self._predictor.set_image(cropped_im)

        # Get the points for this crop. The crop size is reversed to scale the grid points with it.
        points_scale = np.array(cropped_im_size)[None, ::-1]
        points_for_image = self.point_grids[crop_layer_idx] * points_scale

        # Generate masks for this crop in batches.
        data = amg_utils.MaskData()
        n_points = len(points_for_image)
        # The number of batches is the number of points divided by the batch size, rounded up.
        n_batches = n_points // self._points_per_batch + int(n_points % self._points_per_batch != 0)
        if pbar_init is not None:
            pbar_init(n_batches, "Predict masks for point grid prompts")

        for (points,) in amg_utils.batch_iterator(self._points_per_batch, points_for_image):
            batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size)
            data.cat(batch_data)
            del batch_data
            if pbar_update is not None:
                pbar_update(1)

        if not precomputed_embeddings:
            self._predictor.reset_image()

        return data

    @torch.no_grad()
    def initialize(
        self,
        image: np.ndarray,
        image_embeddings: Optional[util.ImageEmbeddings] = None,
        i: Optional[int] = None,
        verbose: bool = False,
        pbar_init: Optional[callable] = None,
        pbar_update: Optional[callable] = None,
    ) -> None:
        """Initialize image embeddings and masks for an image.

        Args:
            image: The input image, volume or timeseries.
            image_embeddings: Optional precomputed image embeddings.
                See `util.precompute_image_embeddings` for details.
                They are only used if a single crop box is generated (the default setting).
            i: Index for the image data. Required if `image` has three spatial dimensions
                or a time dimension and two spatial dimensions.
            verbose: Whether to print computation progress. By default, set to 'False'.
            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
                Can be used together with pbar_update to handle napari progress bar in other thread.
                To enable using this function within a threadworker.
            pbar_update: Callback to update an external progress bar.
        """
        original_size = image.shape[:2]
        # This has to be set before processing the crops, because '_process_crop' reads 'self.original_size'.
        self._original_size = original_size

        crop_boxes, layer_idxs = amg_utils.generate_crop_boxes(
            original_size, self._crop_n_layers, self._crop_overlap_ratio
        )

        # We can set fixed image embeddings if we only have a single crop box (the default setting).
        # Otherwise we have to recompute the embeddings for each crop and can't precompute.
        if len(crop_boxes) == 1:
            if image_embeddings is None:
                image_embeddings = util.precompute_image_embeddings(self._predictor, image)
            util.set_precomputed(self._predictor, image_embeddings, i=i)
            precomputed_embeddings = True
        else:
            precomputed_embeddings = False

        # we need to cast to the image representation that is compatible with SAM
        image = util._to_image(image)

        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)

        crop_list = []
        try:
            for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
                crop_data = self._process_crop(
                    image, crop_box, layer_idx,
                    precomputed_embeddings=precomputed_embeddings,
                    pbar_init=pbar_init, pbar_update=pbar_update,
                )
                crop_list.append(crop_data)
        finally:
            # Close the progress bar also if the prediction for one of the crops fails.
            pbar_close()

        # Only mark the generator as initialized once the full state is stored.
        self._crop_list = crop_list
        self._crop_boxes = crop_boxes
        self._is_initialized = True

    @torch.no_grad()
    def generate(
        self,
        pred_iou_thresh: float = 0.88,
        stability_score_thresh: float = 0.95,
        box_nms_thresh: float = 0.7,
        crop_nms_thresh: float = 0.7,
        min_mask_region_area: int = 0,
        output_mode: str = "instance_segmentation",
        with_background: bool = True,
    ) -> Union[List[Dict[str, Any]], np.ndarray]:
        """Generate instance segmentation for the currently initialized image.

        Args:
            pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
                By default, set to '0.88'.
            stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
                under changes to the cutoff used to binarize the model prediction. By default, set to '0.95'.
            box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
                By default, set to '0.7'.
            crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
                By default, set to '0.7'.
            min_mask_region_area: Minimal size for the predicted masks. By default, set to '0'.
            output_mode: The form masks are returned in. Possible values are:
                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
                - 'coco_rle': Return a list of dictionaries with run-length encoded masks in MS COCO format.
                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
                - 'rle': Return a list of dictionaries with run-length encoded masks.
                By default, set to 'instance_segmentation'.
            with_background: Whether to remove the largest object, which often covers the background.

        Returns:
            The segmentation masks.

        Raises:
            RuntimeError: If the mask generator has not been initialized.
        """
        if not self.is_initialized:
            raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.")

        data = amg_utils.MaskData()
        for data_, crop_box in zip(self.crop_list, self.crop_boxes):
            # The post-processing gets a copy, so that the initialized state can be re-used
            # for further calls with different parameters.
            crop_data = self._postprocess_batch(
                data=deepcopy(data_),
                crop_box=crop_box, original_size=self.original_size,
                pred_iou_thresh=pred_iou_thresh,
                stability_score_thresh=stability_score_thresh,
                box_nms_thresh=box_nms_thresh
            )
            data.cat(crop_data)

        if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0:
            # Prefer masks from smaller crops
            scores = 1 / box_area(data["crop_boxes"])
            scores = scores.to(data["boxes"].device)
            keep_by_nms = batched_nms(
                data["boxes"].float(),
                scores,
                torch.zeros_like(data["boxes"][:, 0]),  # categories
                iou_threshold=crop_nms_thresh,
            )
            data.filter(keep_by_nms)

        data.to_numpy()
        masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode)
        if output_mode == "instance_segmentation":
            # Fall back to the image size if no mask is left to read the shape from.
            shape = next(iter(masks))["segmentation"].shape if len(masks) > 0 else self.original_size
            masks = util.mask_data_to_segmentation(
                masks, shape=shape, with_background=with_background, merge_exclusively=False
            )
        return masks

Generates an instance segmentation without prompts, using a point grid.

This class implements the same logic as https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py. It decouples the computationally expensive step of generating masks from the cheap post-processing operation that filters these masks. This enables grid search and interactively changing the post-processing parameters.

Use this class as follows:

amg = AutomaticMaskGenerator(predictor)
amg.initialize(image)  # Initialize the masks, this takes care of all expensive computations.
masks = amg.generate(pred_iou_thresh=0.8)  # Generate the masks. This is fast and enables testing parameters
Arguments:
  • predictor: The segment anything predictor.
  • points_per_side: The number of points to be sampled along one side of the image. If None, point_grids must provide explicit point sampling. By default, set to '32'.
  • points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory. By default, selects '64' for all devices except 'mps' (selects '16' for performance reasons).
  • crop_n_layers: If >0, the mask prediction will be run again on crops of the image. By default, set to '0'.
  • crop_overlap_ratio: Sets the degree to which crops overlap. By default, set to '512 / 1500'.
  • crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops. By default, set to '1'.
  • point_grids: A list over explicit grids of points used for sampling masks. Normalized to [0, 1] with respect to the image coordinate system.
  • stability_score_offset: The amount to shift the cutoff when calculating the stability score. By default, set to '1.0'.
AutomaticMaskGenerator( predictor: segment_anything.predictor.SamPredictor, points_per_side: Optional[int] = 32, points_per_batch: Optional[int] = None, crop_n_layers: int = 0, crop_overlap_ratio: float = 0.3413333333333333, crop_n_points_downscale_factor: int = 1, point_grids: Optional[List[numpy.ndarray]] = None, stability_score_offset: float = 1.0)
321    def __init__(
322        self,
323        predictor: SamPredictor,
324        points_per_side: Optional[int] = 32,
325        points_per_batch: Optional[int] = None,
326        crop_n_layers: int = 0,
327        crop_overlap_ratio: float = 512 / 1500,
328        crop_n_points_downscale_factor: int = 1,
329        point_grids: Optional[List[np.ndarray]] = None,
330        stability_score_offset: float = 1.0,
331    ):
332        super().__init__()
333
334        if points_per_side is not None:
335            self.point_grids = amg_utils.build_all_layer_point_grids(
336                points_per_side, crop_n_layers, crop_n_points_downscale_factor,
337            )
338        elif point_grids is not None:
339            self.point_grids = point_grids
340        else:
341            raise ValueError("Can't have both points_per_side and point_grid be None or not None.")
342
343        self._predictor = predictor
344        self._points_per_side = points_per_side
345
346        # we set the points per batch to 16 for mps for performance reasons
347        # and otherwise keep them at the default of 64
348        if points_per_batch is None:
349            points_per_batch = 16 if str(predictor.device) == "mps" else 64
350        self._points_per_batch = points_per_batch
351
352        self._crop_n_layers = crop_n_layers
353        self._crop_overlap_ratio = crop_overlap_ratio
354        self._crop_n_points_downscale_factor = crop_n_points_downscale_factor
355        self._stability_score_offset = stability_score_offset
@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, verbose: bool = False, pbar_init: Optional[<built-in function callable>] = None, pbar_update: Optional[<built-in function callable>] = None) -> None:
404    @torch.no_grad()
405    def initialize(
406        self,
407        image: np.ndarray,
408        image_embeddings: Optional[util.ImageEmbeddings] = None,
409        i: Optional[int] = None,
410        verbose: bool = False,
411        pbar_init: Optional[callable] = None,
412        pbar_update: Optional[callable] = None,
413    ) -> None:
414        """Initialize image embeddings and masks for an image.
415
416        Args:
417            image: The input image, volume or timeseries.
418            image_embeddings: Optional precomputed image embeddings.
419                See `util.precompute_image_embeddings` for details.
420            i: Index for the image data. Required if `image` has three spatial dimensions
421                or a time dimension and two spatial dimensions.
422            verbose: Whether to print computation progress. By default, set to 'False'.
423            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
424                Can be used together with pbar_update to handle napari progress bar in other thread.
425                To enable using this function within a threadworker.
426            pbar_update: Callback to update an external progress bar.
427        """
428        original_size = image.shape[:2]
429        self._original_size = original_size
430
431        crop_boxes, layer_idxs = amg_utils.generate_crop_boxes(
432            original_size, self._crop_n_layers, self._crop_overlap_ratio
433        )
434
435        # We can set fixed image embeddings if we only have a single crop box (the default setting).
436        # Otherwise we have to recompute the embeddings for each crop and can't precompute.
437        if len(crop_boxes) == 1:
438            if image_embeddings is None:
439                image_embeddings = util.precompute_image_embeddings(self._predictor, image)
440            util.set_precomputed(self._predictor, image_embeddings, i=i)
441            precomputed_embeddings = True
442        else:
443            precomputed_embeddings = False
444
445        # we need to cast to the image representation that is compatible with SAM
446        image = util._to_image(image)
447
448        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
449
450        crop_list = []
451        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
452            crop_data = self._process_crop(
453                image, crop_box, layer_idx,
454                precomputed_embeddings=precomputed_embeddings,
455                pbar_init=pbar_init, pbar_update=pbar_update,
456            )
457            crop_list.append(crop_data)
458        pbar_close()
459
460        self._is_initialized = True
461        self._crop_list = crop_list
462        self._crop_boxes = crop_boxes

Initialize image embeddings and masks for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • verbose: Whether to print computation progress. By default, set to 'False'.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle napari progress bar in other thread. To enable using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
@torch.no_grad()
def generate( self, pred_iou_thresh: float = 0.88, stability_score_thresh: float = 0.95, box_nms_thresh: float = 0.7, crop_nms_thresh: float = 0.7, min_mask_region_area: int = 0, output_mode: str = 'instance_segmentation', with_background: bool = True) -> Union[List[Dict[str, Any]], numpy.ndarray]:
464    @torch.no_grad()
465    def generate(
466        self,
467        pred_iou_thresh: float = 0.88,
468        stability_score_thresh: float = 0.95,
469        box_nms_thresh: float = 0.7,
470        crop_nms_thresh: float = 0.7,
471        min_mask_region_area: int = 0,
472        output_mode: str = "instance_segmentation",
473        with_background: bool = True,
474    ) -> Union[List[Dict[str, Any]], np.ndarray]:
475        """Generate instance segmentation for the currently initialized image.
476
477        Args:
478            pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
479                By default, set to '0.88'.
480            stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
481                under changes to the cutoff used to binarize the model prediction. By default, set to '0.95'.
482            box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
483                By default, set to '0.7'.
484            crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
485                By default, set to '0.7'.
486            min_mask_region_area: Minimal size for the predicted masks. By default, set to '0'.
487            output_mode: The form masks are returned in. Possible values are:
488                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
489                - 'coco_rle': Return a list of dictionaries with run-length encoded masks in MS COCO format.
490                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
491                - 'rle': Return a list of dictionaries with run-length encoded masks.
492                By default, set to 'instance_segmentation'.
493            with_background: Whether to remove the largest object, which often covers the background.
494
495        Returns:
496            The segmentation masks.
497        """
498        if not self.is_initialized:
499            raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.")
500
501        data = amg_utils.MaskData()
502        for data_, crop_box in zip(self.crop_list, self.crop_boxes):
503            crop_data = self._postprocess_batch(
504                data=deepcopy(data_),
505                crop_box=crop_box, original_size=self.original_size,
506                pred_iou_thresh=pred_iou_thresh,
507                stability_score_thresh=stability_score_thresh,
508                box_nms_thresh=box_nms_thresh
509            )
510            data.cat(crop_data)
511
512        if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0:
513            # Prefer masks from smaller crops
514            scores = 1 / box_area(data["crop_boxes"])
515            scores = scores.to(data["boxes"].device)
516            keep_by_nms = batched_nms(
517                data["boxes"].float(),
518                scores,
519                torch.zeros_like(data["boxes"][:, 0]),  # categories
520                iou_threshold=crop_nms_thresh,
521            )
522            data.filter(keep_by_nms)
523
524        data.to_numpy()
525        masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode)
526        if output_mode == "instance_segmentation":
527            shape = next(iter(masks))["segmentation"].shape if len(masks) > 0 else self.original_size
528            masks = util.mask_data_to_segmentation(
529                masks, shape=shape, with_background=with_background, merge_exclusively=False
530            )
531        return masks

Generate instance segmentation for the currently initialized image.

Arguments:
  • pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model. By default, set to '0.88'.
  • stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask under changes to the cutoff used to binarize the model prediction. By default, set to '0.95'.
  • box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks. By default, set to '0.7'.
  • crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops. By default, set to '0.7'.
  • min_mask_region_area: Minimal size for the predicted masks. By default, set to '0'.
  • output_mode: The form masks are returned in. Possible values are:
    • 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
    • 'coco_rle': Return a list of dictionaries with run-length encoded masks in MS COCO format.
    • 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
    • 'rle': Return a list of dictionaries with run-length encoded masks. By default, set to 'instance_segmentation'.
  • with_background: Whether to remove the largest object, which often covers the background.
Returns:

The segmentation masks.

class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
    """Generates an instance segmentation without prompts, using a point grid.

    Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings.

    Args:
        predictor: The Segment Anything predictor.
        points_per_side: The number of points to be sampled along one side of the image.
            If None, `point_grids` must provide explicit point sampling. By default, set to '32'.
        points_per_batch: The number of points run simultaneously by the model.
            Higher numbers may be faster but use more GPU memory. By default, set to '64'.
        point_grids: A list over explicit grids of points used for sampling masks.
            Normalized to [0, 1] with respect to the image coordinate system.
        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
            By default, set to '1.0'.
    """

    # We only expose the arguments that make sense for the tiled mask generator.
    # Anything related to crops doesn't make sense, because we re-use that functionality
    # for tiling, so these parameters wouldn't have any effect.
    def __init__(
        self,
        predictor: SamPredictor,
        points_per_side: Optional[int] = 32,
        points_per_batch: int = 64,
        point_grids: Optional[List[np.ndarray]] = None,
        stability_score_offset: float = 1.0,
    ) -> None:
        super().__init__(
            predictor=predictor,
            points_per_side=points_per_side,
            points_per_batch=points_per_batch,
            point_grids=point_grids,
            stability_score_offset=stability_score_offset,
        )

    @torch.no_grad()
    def initialize(
        self,
        image: np.ndarray,
        image_embeddings: Optional[util.ImageEmbeddings] = None,
        i: Optional[int] = None,
        tile_shape: Optional[Tuple[int, int]] = None,
        halo: Optional[Tuple[int, int]] = None,
        verbose: bool = False,
        pbar_init: Optional[callable] = None,
        pbar_update: Optional[callable] = None,
        batch_size: int = 1,
        mask: Optional[np.typing.ArrayLike] = None,
    ) -> None:
        """Initialize image embeddings and masks for an image.

        Args:
            image: The input image, volume or timeseries.
            image_embeddings: Optional precomputed image embeddings.
                See `util.precompute_image_embeddings` for details.
            i: Index for the image data. Required if `image` has three spatial dimensions
                or a time dimension and two spatial dimensions.
            tile_shape: The tile shape for embedding prediction.
            halo: The overlap between tiles.
            verbose: Whether to print computation progress. By default, set to 'False'.
            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
                Can be used together with pbar_update to handle napari progress bar in other thread.
                To enable using this function within a threadworker.
            pbar_update: Callback to update an external progress bar.
            batch_size: The batch size for image embedding prediction. By default, set to '1'.
            mask: An optional mask to define areas that are ignored in the segmentation.
        """
        original_size = image.shape[:2]
        # This has to be set before processing the tiles, because '_process_crop' reads 'self.original_size'.
        self._original_size = original_size

        self._image_embeddings, tile_shape, halo, tiles_in_mask = _process_tiled_embeddings(
            self._predictor, image, image_embeddings, tile_shape, halo,
            verbose=verbose, batch_size=batch_size, mask=mask, i=i,
        )

        tiling = Blocking([0, 0], original_size, tile_shape)
        if tiles_in_mask is None:
            n_tiles = tiling.number_of_blocks
            tile_ids = range(n_tiles)
        else:
            n_tiles = len(tiles_in_mask)
            tile_ids = tiles_in_mask

        # The crop box is always the full local tile.
        tiles = [tiling.get_block_with_halo(tile_id, list(halo)).outer_block for tile_id in tile_ids]
        # The tile coordinates are swapped to build boxes in the order (x0, y0, x1, y1) used by '_process_crop'.
        crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles]

        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
        pbar_init(n_tiles, "Compute masks for tile")

        # We need to cast to the image representation that is compatible with SAM.
        image = util._to_image(image)

        # Read the tile features from the embeddings returned by '_process_tiled_embeddings'.
        # The 'image_embeddings' argument must not be used here, because it is None
        # if the embeddings were not precomputed by the caller.
        tile_features = self._image_embeddings["features"]

        mask_data = []
        for idx, tile_id in enumerate(tile_ids):
            # set the pre-computed embeddings for this tile
            features = tile_features[str(tile_id)]
            tile_embeddings = {
                "features": features,
                "input_size": features.attrs["input_size"],
                "original_size": features.attrs["original_size"],
            }
            util.set_precomputed(self._predictor, tile_embeddings, i)

            # Compute the mask data for this tile and append it
            this_mask_data = self._process_crop(
                image, crop_box=crop_boxes[idx], crop_layer_idx=0, precomputed_embeddings=True
            )
            mask_data.append(this_mask_data)
            pbar_update(1)
        pbar_close()

        # set the initialized data
        self._is_initialized = True
        self._crop_list = mask_data
        self._crop_boxes = crop_boxes

Generates an instance segmentation without prompts, using a point grid.

Implements the same functionality as AutomaticMaskGenerator but for tiled embeddings.

Arguments:
  • predictor: The Segment Anything predictor.
  • points_per_side: The number of points to be sampled along one side of the image. If None, point_grids must provide explicit point sampling. By default, set to '32'.
  • points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory. By default, set to '64'.
  • point_grids: A list over explicit grids of points used for sampling masks. Normalized to [0, 1] with respect to the image coordinate system.
  • stability_score_offset: The amount to shift the cutoff when calculating the stability score. By default, set to '1.0'.
TiledAutomaticMaskGenerator( predictor: segment_anything.predictor.SamPredictor, points_per_side: Optional[int] = 32, points_per_batch: int = 64, point_grids: Optional[List[numpy.ndarray]] = None, stability_score_offset: float = 1.0)
585    def __init__(
586        self,
587        predictor: SamPredictor,
588        points_per_side: Optional[int] = 32,
589        points_per_batch: int = 64,
590        point_grids: Optional[List[np.ndarray]] = None,
591        stability_score_offset: float = 1.0,
592    ) -> None:
593        super().__init__(
594            predictor=predictor,
595            points_per_side=points_per_side,
596            points_per_batch=points_per_batch,
597            point_grids=point_grids,
598            stability_score_offset=stability_score_offset,
599        )
@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = False, pbar_init: Optional[<built-in function callable>] = None, pbar_update: Optional[<built-in function callable>] = None, batch_size: int = 1, mask: Optional[ArrayLike] = None) -> None:
601    @torch.no_grad()
602    def initialize(
603        self,
604        image: np.ndarray,
605        image_embeddings: Optional[util.ImageEmbeddings] = None,
606        i: Optional[int] = None,
607        tile_shape: Optional[Tuple[int, int]] = None,
608        halo: Optional[Tuple[int, int]] = None,
609        verbose: bool = False,
610        pbar_init: Optional[callable] = None,
611        pbar_update: Optional[callable] = None,
612        batch_size: int = 1,
613        mask: Optional[np.typing.ArrayLike] = None,
614    ) -> None:
615        """Initialize image embeddings and masks for an image.
616
617        Args:
618            image: The input image, volume or timeseries.
619            image_embeddings: Optional precomputed image embeddings.
620                See `util.precompute_image_embeddings` for details.
621            i: Index for the image data. Required if `image` has three spatial dimensions
622                or a time dimension and two spatial dimensions.
623            tile_shape: The tile shape for embedding prediction.
624            halo: The overlap of between tiles.
625            verbose: Whether to print computation progress. By default, set to 'False'.
626            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
627                Can be used together with pbar_update to handle napari progress bar in other thread.
628                To enable using this function within a threadworker.
629            pbar_update: Callback to update an external progress bar.
630            batch_size: The batch size for image embedding prediction. By default, set to '1'.
631            mask: An optional mask to define areas that are ignored in the segmentation.
632        """
633        original_size = image.shape[:2]
634        self._original_size = original_size
635
636        self._image_embeddings, tile_shape, halo, tiles_in_mask = _process_tiled_embeddings(
637            self._predictor, image, image_embeddings, tile_shape, halo,
638            verbose=verbose, batch_size=batch_size, mask=mask, i=i,
639        )
640
641        tiling = Blocking([0, 0], original_size, tile_shape)
642        if tiles_in_mask is None:
643            n_tiles = tiling.number_of_blocks
644            tile_ids = range(n_tiles)
645        else:
646            n_tiles = len(tiles_in_mask)
647            tile_ids = tiles_in_mask
648
649        # The crop box is always the full local tile.
650        tiles = [tiling.get_block_with_halo(tile_id, list(halo)).outer_block for tile_id in tile_ids]
651        crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles]
652
653        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
654        pbar_init(n_tiles, "Compute masks for tile")
655
656        # We need to cast to the image representation that is compatible with SAM.
657        image = util._to_image(image)
658
659        mask_data = []
660        for idx, tile_id in enumerate(tile_ids):
661            # set the pre-computed embeddings for this tile
662            features = image_embeddings["features"][str(tile_id)]
663            tile_embeddings = {
664                "features": features,
665                "input_size": features.attrs["input_size"],
666                "original_size": features.attrs["original_size"],
667            }
668            util.set_precomputed(self._predictor, tile_embeddings, i)
669
670            # Compute the mask data for this tile and append it
671            this_mask_data = self._process_crop(
672                image, crop_box=crop_boxes[idx], crop_layer_idx=0, precomputed_embeddings=True
673            )
674            mask_data.append(this_mask_data)
675            pbar_update(1)
676        pbar_close()
677
678        # set the initialized data
679        self._is_initialized = True
680        self._crop_list = mask_data
681        self._crop_boxes = crop_boxes

Initialize image embeddings and masks for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • tile_shape: The tile shape for embedding prediction.
  • halo: The overlap between tiles.
  • verbose: Whether to print computation progress. By default, set to 'False'.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle a napari progress bar in another thread. This enables using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
  • batch_size: The batch size for image embedding prediction. By default, set to '1'.
  • mask: An optional mask to define areas that are ignored in the segmentation.
class DecoderAdapter(torch.nn.modules.module.Module):
class DecoderAdapter(torch.nn.Module):
    """Bundle the decoder-side components of a UNETR in one standalone module.

    This makes it possible to run the decoder on embeddings that were computed beforehand,
    as needed by the segmentation functionality.
    See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py
    """
    def __init__(self, unetr: torch.nn.Module):
        super().__init__()

        # The attributes are taken over one by one in a fixed order, so that the order
        # in which they are set on this module stays the same.
        shared_attributes = (
            "base",
            "out_conv",
            "deconv_out",
            "decoder_head",
            "final_activation",
            "postprocess_masks",
            "decoder",
            "deconv1",
            "deconv2",
            "deconv3",
            "deconv4",
        )
        for attribute in shared_attributes:
            setattr(self, attribute, getattr(unetr, attribute))

    def _forward_impl(self, input_):
        # Chain of the four deconvolution stages, each one applied to the previous output.
        stage1 = self.deconv1(input_)
        stage2 = self.deconv2(stage1)
        stage3 = self.deconv3(stage2)
        stage4 = self.deconv4(stage3)

        # The first three stages serve as the additional inputs of the decoder.
        hidden = self.base(input_)
        hidden = self.decoder(hidden, encoder_inputs=[stage1, stage2, stage3])
        hidden = self.deconv_out(hidden)

        # The last stage is concatenated along dimension 1 before the decoder head.
        hidden = self.decoder_head(torch.cat([hidden, stage4], dim=1))

        prediction = self.out_conv(hidden)
        if self.final_activation is None:
            return prediction
        return self.final_activation(prediction)

    def forward(self, input_, input_shape, original_shape):
        """Run the decoder on `input_` and post-process the result with the given shapes."""
        return self.postprocess_masks(self._forward_impl(input_), input_shape, original_shape)

Adapter to contain the UNETR decoder in a single module.

To apply the decoder on top of pre-computed embeddings for the segmentation functionality. See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py

DecoderAdapter(unetr: torch.nn.modules.module.Module)
695    def __init__(self, unetr: torch.nn.Module):
696        super().__init__()
697
698        self.base = unetr.base
699        self.out_conv = unetr.out_conv
700        self.deconv_out = unetr.deconv_out
701        self.decoder_head = unetr.decoder_head
702        self.final_activation = unetr.final_activation
703        self.postprocess_masks = unetr.postprocess_masks
704
705        self.decoder = unetr.decoder
706        self.deconv1 = unetr.deconv1
707        self.deconv2 = unetr.deconv2
708        self.deconv3 = unetr.deconv3
709        self.deconv4 = unetr.deconv4

Initialize internal Module state, shared by both nn.Module and ScriptModule.

base
out_conv
deconv_out
decoder_head
final_activation
postprocess_masks
decoder
deconv1
deconv2
deconv3
deconv4
def forward(self, input_, input_shape, original_shape):
733    def forward(self, input_, input_shape, original_shape):
734        x = self._forward_impl(input_)
735        x = self.postprocess_masks(x, input_shape, original_shape)
736        return x

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

def get_unetr( image_encoder: torch.nn.modules.module.Module, decoder_state: Optional[collections.OrderedDict[str, torch.Tensor]] = None, device: Union[str, torch.device, NoneType] = None, out_channels: int = 3, flexible_load_checkpoint: bool = False, final_activation: Optional[str] = 'Sigmoid') -> torch.nn.modules.module.Module:
def get_unetr(
    image_encoder: torch.nn.Module,
    decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None,
    device: Optional[Union[str, torch.device]] = None,
    out_channels: int = 3,
    flexible_load_checkpoint: bool = False,
    final_activation: Optional[str] = "Sigmoid",
) -> torch.nn.Module:
    """Build the UNETR model used for automatic instance segmentation.

    Args:
        image_encoder: The image encoder of the SAM model.
            The UNETR uses it as its encoder as well.
        decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
        device: The device. By default, automatically chooses the best available device.
        out_channels: The number of output channels. By default, set to '3'.
        flexible_load_checkpoint: Whether parameters that are missing in the decoder state,
            or whose shape does not match, keep their fresh initialization (with a warning).
            If False, a missing parameter raises an error. By default, set to 'False'.
        final_activation: The activation applied to the network output. By default uses a Sigmoid activation.

    Returns:
        The UNETR model.

    Raises:
        RuntimeError: If `flexible_load_checkpoint` is False and a non-encoder parameter
            is not contained in `decoder_state`.
    """
    device = util.get_device(device)

    # Without a decoder state interpolation is used for upsampling.
    # With a decoder state the upsampling type is deduced from its parameter names:
    # - names containing '.block.' within the "decoder.samplers" submodules indicate
    #   transposed convolutions, wrapped inside a custom block module.
    # - otherwise '.conv.' appears, which indicates a standard `Conv2d` applied after interpolation.
    use_conv_transpose = decoder_state is not None and any(
        ".block." in name for name in decoder_state if name.startswith("decoder.samplers")
    )

    unetr = UNETR(
        backbone="sam",
        encoder=image_encoder,
        out_channels=out_channels,
        use_sam_stats=True,
        final_activation=final_activation,
        use_skip_connection=False,
        resize_input=True,
        use_conv_transpose=use_conv_transpose,
    )

    if decoder_state is not None:
        merged_state = unetr.state_dict()
        for name, initial_value in merged_state.items():
            # The encoder weights are never taken from the decoder state.
            if name.startswith("encoder"):
                continue

            if not flexible_load_checkpoint:
                # Strict mode: every parameter has to be present in the decoder state.
                if name not in decoder_state:
                    raise RuntimeError(f"The parameters for '{name}' could not be found or has a size mismatch.")
                merged_state[name] = decoder_state[name]
            elif name not in decoder_state:
                # Flexible mode, parameter missing: the freshly initialized value stays in place.
                warnings.warn(f"Could not find '{name}' in the pretrained state dict. Hence, we reinitialize it.")
            elif initial_value.shape != decoder_state[name].shape:
                # Flexible mode, shape mismatch: the freshly initialized value stays in place.
                warnings.warn(f"Shape of '{name}' did not match. Hence, we reinitialize it.")
            else:
                merged_state[name] = decoder_state[name]

        unetr.load_state_dict(merged_state)

    unetr.to(device)
    return unetr

Get UNETR model for automatic instance segmentation.

Arguments:
  • image_encoder: The image encoder of the SAM model. This is used as encoder by the UNETR too.
  • decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
  • device: The device. By default, automatically chooses the best available device.
  • out_channels: The number of output channels. By default, set to '3'.
  • flexible_load_checkpoint: Whether to allow reinitialization of parameters which could not be found in the provided decoder state. By default, set to 'False'.
  • final_activation: The activation applied to the network output. By default uses a Sigmoid activation.
Returns:

The UNETR model.

def get_decoder( image_encoder: torch.nn.modules.module.Module, decoder_state: collections.OrderedDict[str, torch.Tensor], device: Union[str, torch.device, NoneType] = None) -> DecoderAdapter:
def get_decoder(
    image_encoder: torch.nn.Module,
    decoder_state: OrderedDict[str, torch.Tensor],
    device: Optional[Union[str, torch.device]] = None,
) -> DecoderAdapter:
    """Create the decoder that predicts the outputs for automatic instance segmentation.

    A UNETR is built via `get_unetr` and then wrapped in a `DecoderAdapter`.

    Args:
        image_encoder: The image encoder of the SAM model.
        decoder_state: State to initialize the weights of the UNETR decoder.
        device: The device. By default, automatically chooses the best available device.

    Returns:
        The decoder for instance segmentation.
    """
    full_model = get_unetr(image_encoder=image_encoder, decoder_state=decoder_state, device=device)
    return DecoderAdapter(unetr=full_model)

Get decoder to predict outputs for automatic instance segmentation.

Arguments:
  • image_encoder: The image encoder of the SAM model.
  • decoder_state: State to initialize the weights of the UNETR decoder.
  • device: The device. By default, automatically chooses the best available device.
Returns:

The decoder for instance segmentation.

def get_predictor_and_decoder( model_type: str, checkpoint_path: Union[str, os.PathLike, NoneType] = None, device: Union[str, torch.device, NoneType] = None, peft_kwargs: Optional[Dict] = None) -> Tuple[segment_anything.predictor.SamPredictor, DecoderAdapter]:
def get_predictor_and_decoder(
    model_type: str,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    device: Optional[Union[str, torch.device]] = None,
    peft_kwargs: Optional[Dict] = None,
) -> Tuple[SamPredictor, DecoderAdapter]:
    """Load the SAM model (predictor) together with the instance segmentation decoder.

    The checkpoint has to contain the state for both the predictor and the decoder.

    Args:
        model_type: The type of the image encoder used in the SAM model.
        checkpoint_path: Path to the checkpoint from which to load the data.
        device: The device. By default, automatically chooses the best available device.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.

    Returns:
        The SAM predictor.
        The decoder for instance segmentation.

    Raises:
        ValueError: If the loaded state does not contain a 'decoder_state' entry.
    """
    device = util.get_device(device)

    sam_settings = {
        "model_type": model_type,
        "checkpoint_path": checkpoint_path,
        "device": device,
        "return_state": True,
        "peft_kwargs": peft_kwargs,
    }
    predictor, state = util.get_sam_model(**sam_settings)

    if "decoder_state" in state:
        # The decoder is built on top of the image encoder of the loaded predictor.
        decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device)
        return predictor, decoder

    raise ValueError(
        f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state"
    )

Load the SAM model (predictor) and instance segmentation decoder.

This requires a checkpoint that contains the state for both predictor and decoder.

Arguments:
  • model_type: The type of the image encoder used in the SAM model.
  • checkpoint_path: Path to the checkpoint from which to load the data.
  • device: The device. By default, automatically chooses the best available device.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class.
Returns:

The SAM predictor. The decoder for instance segmentation.

class InstanceSegmentationWithDecoder:
 954class InstanceSegmentationWithDecoder:
 955    """Generates an instance segmentation without prompts, using a decoder.
 956
 957    Implements the same interface as `AutomaticMaskGenerator`.
 958
 959    Use this class as follows:
 960    ```python
 961    segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
 962    segmenter.initialize(image)  # Predict the image embeddings and decoder outputs.
 963    masks = segmenter.generate(center_distance_threshold=0.75)  # Generate the instance segmentation.
 964    ```
 965
 966    Args:
 967        predictor: The segment anything predictor.
 968        decoder: The decoder to predict intermediate representations for instance segmentation.
 969    """
 970    def __init__(self, predictor: SamPredictor, decoder: torch.nn.Module) -> None:
 971        self._predictor = predictor
 972        self._decoder = decoder
 973
 974        # The decoder outputs.
 975        self._foreground = None
 976        self._center_distances = None
 977        self._boundary_distances = None
 978
 979        self._is_initialized = False
 980
 981    @property
 982    def is_initialized(self):
 983        """Whether the mask generator has already been initialized.
 984        """
 985        return self._is_initialized
 986
 987    @torch.no_grad()
 988    def initialize(
 989        self,
 990        image: np.ndarray,
 991        image_embeddings: Optional[util.ImageEmbeddings] = None,
 992        i: Optional[int] = None,
 993        verbose: bool = False,
 994        pbar_init: Optional[callable] = None,
 995        pbar_update: Optional[callable] = None,
 996        ndim: int = 2,
 997    ) -> None:
 998        """Initialize image embeddings and decoder predictions for an image.
 999
1000        Args:
1001            image: The input image, volume or timeseries.
1002            image_embeddings: Optional precomputed image embeddings.
1003                See `util.precompute_image_embeddings` for details.
1004            i: Index for the image data. Required if `image` has three spatial dimensions
1005                or a time dimension and two spatial dimensions.
1006            verbose: Whether to be verbose. By default, set to 'False'.
1007            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
1008                Can be used together with pbar_update to handle napari progress bar in other thread.
1009                To enable using this function within a threadworker.
1010            pbar_update: Callback to update an external progress bar.
1011            ndim: The dimensionality of the data. If not given will be deduced from the input data. By default, 2.
1012        """
1013        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
1014        pbar_init(1, "Initialize instance segmentation with decoder")
1015
1016        if image_embeddings is None:
1017            image_embeddings = util.precompute_image_embeddings(
1018                predictor=self._predictor, input_=image, ndim=ndim, verbose=verbose
1019            )
1020
1021        # Get the image embeddings from the predictor.
1022        self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i)
1023        embeddings = self._predictor.features
1024        input_shape = tuple(self._predictor.input_size)
1025        original_shape = tuple(self._predictor.original_size)
1026
1027        # Run prediction with the UNETR decoder.
1028        output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
1029        assert output.shape[0] == 3, f"{output.shape}"
1030        pbar_update(1)
1031        pbar_close()
1032
1033        # Set the state.
1034        self._foreground = output[0]
1035        self._center_distances = output[1]
1036        self._boundary_distances = output[2]
1037        self._i = i
1038        self._is_initialized = True
1039
1040    def _to_masks(self, segmentation, output_mode):
1041        if output_mode != "binary_mask":
1042            raise ValueError(
1043                f"Output mode {output_mode} is not supported. Choose one of 'instance_segmentation', 'binary_masks'"
1044            )
1045
1046        props = regionprops(segmentation)
1047        ndim = segmentation.ndim
1048        assert ndim in (2, 3)
1049
1050        shape = segmentation.shape
1051        if ndim == 2:
1052            crop_box = [0, shape[1], 0, shape[0]]
1053        else:
1054            crop_box = [0, shape[2], 0, shape[1], 0, shape[0]]
1055
1056        # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h]
1057        def to_bbox_2d(bbox):
1058            y0, x0 = bbox[0], bbox[1]
1059            w = bbox[3] - x0
1060            h = bbox[2] - y0
1061            return [x0, w, y0, h]
1062
1063        def to_bbox_3d(bbox):
1064            z0, y0, x0 = bbox[0], bbox[1], bbox[2]
1065            w = bbox[5] - x0
1066            h = bbox[4] - y0
1067            d = bbox[3] - y0
1068            return [x0, w, y0, h, z0, d]
1069
1070        to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d
1071        masks = [
1072            {
1073                "segmentation": segmentation == prop.label,
1074                "area": prop.area,
1075                "bbox": to_bbox(prop.bbox),
1076                "crop_box": crop_box,
1077                "seg_id": prop.label,
1078            } for prop in props
1079        ]
1080        return masks
1081
1082    def generate(
1083        self,
1084        center_distance_threshold: float = 0.5,
1085        boundary_distance_threshold: float = 0.5,
1086        foreground_threshold: float = 0.5,
1087        foreground_smoothing: float = 1.0,
1088        distance_smoothing: float = 1.6,
1089        min_size: int = 0,
1090        output_mode: str = "instance_segmentation",
1091        tile_shape: Optional[Tuple[int, int]] = None,
1092        halo: Optional[Tuple[int, int]] = None,
1093        n_threads: Optional[int] = None,
1094        optimize_memory: bool = False,
1095        segmentation: Optional[np.ndarray] = None,
1096    ) -> Union[List[Dict[str, Any]], np.ndarray]:
1097        """Generate instance segmentation for the currently initialized image.
1098
1099        Args:
1100            center_distance_threshold: Center distance predictions below this value will be
1101                used to find seeds (intersected with thresholded boundary distance predictions).
1102                By default, set to '0.5'.
1103            boundary_distance_threshold: Boundary distance predictions below this value will be
1104                used to find seeds (intersected with thresholded center distance predictions).
1105                By default, set to '0.5'.
1106            foreground_threshold: Foreground predictions above this value will be used as foreground mask.
1107                By default, set to '0.5'.
1108            foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid
1109                checkerboard artifacts in the prediction. By default, set to '1.0'.
1110            distance_smoothing: Sigma value for smoothing the distance predictions.
1111            min_size: Minimal object size in the segmentation result. By default, set to '0'.
1112            output_mode: The form masks are returned in. Possible values are:
1113                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
1114                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
1115                By default, set to 'instance_segmentation'.
1116            tile_shape: Tile shape for parallelizing the instance segmentation post-processing.
1117                This parameter is independent from the tile shape for computing the embeddings.
1118                If not given then post-processing will not be parallelized.
1119            halo: Halo for parallel post-processing. See also `tile_shape`.
1120            n_threads: Number of threads for parallel post-processing. See also `tile_shape`.
1121            optimize_memory: Whether to optimize the memory consumption by allocating intermediate files.
1122            segmentation: Optional pre-allocated segmentation.
1123
1124        Returns:
1125            The segmentation masks.
1126        """
1127        if not self.is_initialized:
1128            raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.")
1129
1130        if foreground_smoothing > 0:
1131            foreground = _apply_smoothing(self._foreground, foreground_smoothing, tile_shape, n_threads)
1132        else:
1133            foreground = self._foreground
1134
1135        if tile_shape is None:
1136            segmentation = watershed_from_center_and_boundary_distances(
1137                center_distances=self._center_distances,
1138                boundary_distances=self._boundary_distances,
1139                foreground_map=foreground,
1140                center_distance_threshold=center_distance_threshold,
1141                boundary_distance_threshold=boundary_distance_threshold,
1142                foreground_threshold=foreground_threshold,
1143                distance_smoothing=distance_smoothing,
1144                min_size=min_size,
1145            )
1146        else:
1147            if halo is None:
1148                raise ValueError("You must pass a value for halo if tile_shape is given.")
1149
1150            # Shards are not thread-safe for parallel writing! So if we have shards we have to use them for tiling.
1151            # This is ok in terms efficiency as GPU tiles are small; shards should still be manegable for the watershed.
1152            if isinstance(segmentation, zarr.Array) and getattr(segmentation, "shards", None) is not None:
1153                tile_shape = segmentation.shards
1154
1155            segmentation = _watershed_from_center_and_boundary_distances_parallel(
1156                center_distances=self._center_distances,
1157                boundary_distances=self._boundary_distances,
1158                foreground_map=foreground,
1159                center_distance_threshold=center_distance_threshold,
1160                boundary_distance_threshold=boundary_distance_threshold,
1161                foreground_threshold=foreground_threshold,
1162                distance_smoothing=distance_smoothing,
1163                min_size=min_size,
1164                tile_shape=tile_shape,
1165                halo=halo,
1166                n_threads=n_threads,
1167                verbose=False,
1168                optimize_memory=optimize_memory,
1169                segmentation=segmentation,
1170            )
1171
1172        if output_mode != "instance_segmentation":
1173            segmentation = self._to_masks(segmentation, output_mode)
1174        return segmentation
1175
1176    def get_state(self) -> Dict[str, Any]:
1177        """Get the initialized state of the instance segmenter.
1178
1179        Returns:
1180            Instance segmentation state.
1181        """
1182        if not self.is_initialized:
1183            raise RuntimeError("The state has not been computed yet. Call initialize first.")
1184
1185        return {
1186            "foreground": self._foreground,
1187            "center_distances": self._center_distances,
1188            "boundary_distances": self._boundary_distances,
1189        }
1190
1191    def set_state(self, state: Dict[str, Any]) -> None:
1192        """Set the state of the instance segmenter.
1193
1194        Args:
1195            state: The instance segmentation state
1196        """
1197        self._foreground = state["foreground"]
1198        self._center_distances = state["center_distances"]
1199        self._boundary_distances = state["boundary_distances"]
1200        self._is_initialized = True
1201
1202    def clear_state(self):
1203        """Clear the state of the instance segmenter.
1204        """
1205        self._foreground = None
1206        self._center_distances = None
1207        self._boundary_distances = None
1208        self._is_initialized = False

Generates an instance segmentation without prompts, using a decoder.

Implements the same interface as AutomaticMaskGenerator.

Use this class as follows:

segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image)  # Predict the image embeddings and decoder outputs.
masks = segmenter.generate(center_distance_threshold=0.75)  # Generate the instance segmentation.
Arguments:
  • predictor: The segment anything predictor.
  • decoder: The decoder to predict intermediate representations for instance segmentation.
InstanceSegmentationWithDecoder( predictor: segment_anything.predictor.SamPredictor, decoder: torch.nn.modules.module.Module)
970    def __init__(self, predictor: SamPredictor, decoder: torch.nn.Module) -> None:
971        self._predictor = predictor
972        self._decoder = decoder
973
974        # The decoder outputs.
975        self._foreground = None
976        self._center_distances = None
977        self._boundary_distances = None
978
979        self._is_initialized = False
is_initialized
981    @property
982    def is_initialized(self):
983        """Whether the mask generator has already been initialized.
984        """
985        return self._is_initialized

Whether the mask generator has already been initialized.

@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, verbose: bool = False, pbar_init: Optional[callable] = None, pbar_update: Optional[callable] = None, ndim: int = 2) -> None:
 987    @torch.no_grad()
 988    def initialize(
 989        self,
 990        image: np.ndarray,
 991        image_embeddings: Optional[util.ImageEmbeddings] = None,
 992        i: Optional[int] = None,
 993        verbose: bool = False,
 994        pbar_init: Optional[callable] = None,
 995        pbar_update: Optional[callable] = None,
 996        ndim: int = 2,
 997    ) -> None:
 998        """Initialize image embeddings and decoder predictions for an image.
 999
1000        Args:
1001            image: The input image, volume or timeseries.
1002            image_embeddings: Optional precomputed image embeddings.
1003                See `util.precompute_image_embeddings` for details.
1004            i: Index for the image data. Required if `image` has three spatial dimensions
1005                or a time dimension and two spatial dimensions.
1006            verbose: Whether to be verbose. By default, set to 'False'.
1007            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
1008                Can be used together with `pbar_update` to drive a napari progress bar from another thread,
1009                which enables using this function within a thread worker.
1010            pbar_update: Callback to update an external progress bar.
1011            ndim: The dimensionality of the data. Only used if `image_embeddings` is not given. By default, 2.
1012        """
1013        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
1014        pbar_init(1, "Initialize instance segmentation with decoder")
1015
1016        if image_embeddings is None:
1017            image_embeddings = util.precompute_image_embeddings(
1018                predictor=self._predictor, input_=image, ndim=ndim, verbose=verbose
1019            )
1020
1021        # Get the image embeddings from the predictor.
1022        self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i)
1023        embeddings = self._predictor.features
1024        input_shape = tuple(self._predictor.input_size)
1025        original_shape = tuple(self._predictor.original_size)
1026
1027        # Run prediction with the UNETR decoder.
1028        output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
1029        assert output.shape[0] == 3, f"{output.shape}"  # One channel each: foreground, center and boundary distances.
1030        pbar_update(1)
1031        pbar_close()
1032
1033        # Set the state.
1034        self._foreground = output[0]
1035        self._center_distances = output[1]
1036        self._boundary_distances = output[2]
1037        self._i = i  # Remembered for later use, e.g. by `AutomaticPromptGenerator.generate`.
1038        self._is_initialized = True

Initialize image embeddings and decoder predictions for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • verbose: Whether to be verbose. By default, set to 'False'.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to drive a napari progress bar from another thread, which enables using this function within a thread worker.
  • pbar_update: Callback to update an external progress bar.
  • ndim: The dimensionality of the data. Only used if image_embeddings is not given. By default, 2.
def generate( self, center_distance_threshold: float = 0.5, boundary_distance_threshold: float = 0.5, foreground_threshold: float = 0.5, foreground_smoothing: float = 1.0, distance_smoothing: float = 1.6, min_size: int = 0, output_mode: str = 'instance_segmentation', tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, n_threads: Optional[int] = None, optimize_memory: bool = False, segmentation: Optional[numpy.ndarray] = None) -> Union[List[Dict[str, Any]], numpy.ndarray]:
1082    def generate(
1083        self,
1084        center_distance_threshold: float = 0.5,
1085        boundary_distance_threshold: float = 0.5,
1086        foreground_threshold: float = 0.5,
1087        foreground_smoothing: float = 1.0,
1088        distance_smoothing: float = 1.6,
1089        min_size: int = 0,
1090        output_mode: str = "instance_segmentation",
1091        tile_shape: Optional[Tuple[int, int]] = None,
1092        halo: Optional[Tuple[int, int]] = None,
1093        n_threads: Optional[int] = None,
1094        optimize_memory: bool = False,
1095        segmentation: Optional[np.ndarray] = None,
1096    ) -> Union[List[Dict[str, Any]], np.ndarray]:
1097        """Generate instance segmentation for the currently initialized image.
1098
1099        Args:
1100            center_distance_threshold: Center distance predictions below this value will be
1101                used to find seeds (intersected with thresholded boundary distance predictions).
1102                By default, set to '0.5'.
1103            boundary_distance_threshold: Boundary distance predictions below this value will be
1104                used to find seeds (intersected with thresholded center distance predictions).
1105                By default, set to '0.5'.
1106            foreground_threshold: Foreground predictions above this value will be used as foreground mask.
1107                By default, set to '0.5'.
1108            foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid
1109                checkerboard artifacts in the prediction. Values <= 0 disable the smoothing. By default, set to '1.0'.
1110            distance_smoothing: Sigma value for smoothing the distance predictions. By default, set to '1.6'.
1111            min_size: Minimal object size in the segmentation result. By default, set to '0'.
1112            output_mode: The form masks are returned in. Possible values are:
1113                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
1114                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
1115                By default, set to 'instance_segmentation'.
1116            tile_shape: Tile shape for parallelizing the instance segmentation post-processing.
1117                This parameter is independent from the tile shape for computing the embeddings.
1118                If not given then post-processing will not be parallelized.
1119            halo: Halo for parallel post-processing. Required if `tile_shape` is given. See also `tile_shape`.
1120            n_threads: Number of threads for parallel post-processing. See also `tile_shape`.
1121            optimize_memory: Whether to optimize the memory consumption by allocating intermediate files.
1122            segmentation: Optional pre-allocated segmentation. Only used if `tile_shape` is given.
1123
1124        Returns:
1125            The segmentation masks.
1126        """
1127        if not self.is_initialized:
1128            raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.")
1129
1130        if foreground_smoothing > 0:
1131            foreground = _apply_smoothing(self._foreground, foreground_smoothing, tile_shape, n_threads)
1132        else:
1133            foreground = self._foreground
1134
1135        if tile_shape is None:
1136            segmentation = watershed_from_center_and_boundary_distances(
1137                center_distances=self._center_distances,
1138                boundary_distances=self._boundary_distances,
1139                foreground_map=foreground,
1140                center_distance_threshold=center_distance_threshold,
1141                boundary_distance_threshold=boundary_distance_threshold,
1142                foreground_threshold=foreground_threshold,
1143                distance_smoothing=distance_smoothing,
1144                min_size=min_size,
1145            )
1146        else:
1147            if halo is None:
1148                raise ValueError("You must pass a value for halo if tile_shape is given.")
1149
1150            # Shards are not thread-safe for parallel writing! So if we have shards we have to use them for tiling.
1151            # This is ok in terms of efficiency as GPU tiles are small; shards should still be manageable for the watershed.
1152            if isinstance(segmentation, zarr.Array) and getattr(segmentation, "shards", None) is not None:
1153                tile_shape = segmentation.shards
1154
1155            segmentation = _watershed_from_center_and_boundary_distances_parallel(
1156                center_distances=self._center_distances,
1157                boundary_distances=self._boundary_distances,
1158                foreground_map=foreground,
1159                center_distance_threshold=center_distance_threshold,
1160                boundary_distance_threshold=boundary_distance_threshold,
1161                foreground_threshold=foreground_threshold,
1162                distance_smoothing=distance_smoothing,
1163                min_size=min_size,
1164                tile_shape=tile_shape,
1165                halo=halo,
1166                n_threads=n_threads,
1167                verbose=False,
1168                optimize_memory=optimize_memory,
1169                segmentation=segmentation,
1170            )
1171
1172        if output_mode != "instance_segmentation":
1173            segmentation = self._to_masks(segmentation, output_mode)
1174        return segmentation

Generate instance segmentation for the currently initialized image.

Arguments:
  • center_distance_threshold: Center distance predictions below this value will be used to find seeds (intersected with thresholded boundary distance predictions). By default, set to '0.5'.
  • boundary_distance_threshold: Boundary distance predictions below this value will be used to find seeds (intersected with thresholded center distance predictions). By default, set to '0.5'.
  • foreground_threshold: Foreground predictions above this value will be used as foreground mask. By default, set to '0.5'.
  • foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid checkerboard artifacts in the prediction. By default, set to '1.0'.
  • distance_smoothing: Sigma value for smoothing the distance predictions. By default, set to '1.6'.
  • min_size: Minimal object size in the segmentation result. By default, set to '0'.
  • output_mode: The form masks are returned in. Possible values are:
    • 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
    • 'instance_segmentation': Return masks merged into an instance segmentation in a single array. By default, set to 'instance_segmentation'.
  • tile_shape: Tile shape for parallelizing the instance segmentation post-processing. This parameter is independent from the tile shape for computing the embeddings. If not given then post-processing will not be parallelized.
  • halo: Halo for parallel post-processing. See also tile_shape.
  • n_threads: Number of threads for parallel post-processing. See also tile_shape.
  • optimize_memory: Whether to optimize the memory consumption by allocating intermediate files.
  • segmentation: Optional pre-allocated segmentation.
Returns:

The segmentation masks.

def get_state(self) -> Dict[str, Any]:
1176    def get_state(self) -> Dict[str, Any]:
1177        """Get the initialized state of the instance segmenter.
1178
1179        Returns:
1180            Instance segmentation state with the keys 'foreground', 'center_distances' and 'boundary_distances'.
1181        """
1182        if not self.is_initialized:
1183            raise RuntimeError("The state has not been computed yet. Call initialize first.")
1184
1185        return {
1186            "foreground": self._foreground,
1187            "center_distances": self._center_distances,
1188            "boundary_distances": self._boundary_distances,
1189        }

Get the initialized state of the instance segmenter.

Returns:

Instance segmentation state.

def set_state(self, state: Dict[str, Any]) -> None:
1191    def set_state(self, state: Dict[str, Any]) -> None:
1192        """Set the state of the instance segmenter.
1193
1194        Args:
1195            state: The instance segmentation state with the keys 'foreground', 'center_distances', 'boundary_distances'.
1196        """
1197        self._foreground = state["foreground"]
1198        self._center_distances = state["center_distances"]
1199        self._boundary_distances = state["boundary_distances"]
1200        self._is_initialized = True  # An externally provided state counts as an initialization.

Set the state of the instance segmenter.

Arguments:
  • state: The instance segmentation state
def clear_state(self):
1202    def clear_state(self) -> None:
1203        """Clear the state of the instance segmenter, i.e. drop the stored decoder outputs.
1204        """
1205        self._foreground = None
1206        self._center_distances = None
1207        self._boundary_distances = None
1208        self._is_initialized = False  # 'initialize' or 'set_state' must be called again before 'generate'.

Clear the state of the instance segmenter.

class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder):
1211class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder):
1212    """Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings.
1213    """
1214
1215    # Apply the decoder in a batched fashion, and then perform the resizing independently per output.
1216    # This is necessary, because the individual tiles may have different tile shapes due to border tiles.
1217    def _predict_decoder(self, batched_embeddings, input_shapes, original_shapes):
1218        batched_embeddings = torch.cat(batched_embeddings)  # Join the per-tile embeddings into a single batch.
1219        output = self._decoder._forward_impl(batched_embeddings)
1220
1221        batched_output = []
1222        for x, input_shape, original_shape in zip(output, input_shapes, original_shapes):
1223            x = self._decoder.postprocess_masks(x.unsqueeze(0), input_shape, original_shape).squeeze(0)
1224            batched_output.append(x.cpu().numpy())  # The predictions are returned as numpy arrays, one per tile.
1225        return batched_output
1226
1227    @torch.no_grad()
1228    def initialize(
1229        self,
1230        image: np.ndarray,
1231        image_embeddings: Optional[util.ImageEmbeddings] = None,
1232        i: Optional[int] = None,
1233        tile_shape: Optional[Tuple[int, int]] = None,
1234        halo: Optional[Tuple[int, int]] = None,
1235        verbose: bool = False,
1236        pbar_init: Optional[callable] = None,
1237        pbar_update: Optional[callable] = None,
1238        batch_size: int = 1,
1239        mask: Optional[np.typing.ArrayLike] = None,
1240    ) -> None:
1241        """Initialize image embeddings and decoder predictions for an image.
1242
1243        Args:
1244            image: The input image, volume or timeseries.
1245            image_embeddings: Optional precomputed image embeddings.
1246                See `util.precompute_image_embeddings` for details.
1247            i: Index for the image data. Required if `image` has three spatial dimensions
1248                or a time dimension and two spatial dimensions.
1249            tile_shape: Shape of the tiles for precomputing image embeddings.
1250            halo: Overlap of the tiles for tiled precomputation of image embeddings.
1251            verbose: Whether to be verbose. By default, set to 'False'.
1252            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
1253                Can be used together with `pbar_update` to drive a napari progress bar from another thread,
1254                which enables using this function within a thread worker.
1255            pbar_update: Callback to update an external progress bar.
1256            batch_size: The batch size for image embedding computation and segmentation decoder prediction.
1257            mask: An optional mask to define areas that are ignored in the segmentation.
1258        """
1259        original_size = image.shape[:2]  # The first two axes of the image are used as the spatial shape.
1260        self._image_embeddings, tile_shape, halo, tiles_in_mask = _process_tiled_embeddings(
1261            self._predictor, image, image_embeddings, tile_shape, halo,
1262            verbose=verbose, batch_size=batch_size, mask=mask, i=i,
1263        )
1264        tiling = Blocking([0, 0], original_size, tile_shape)
1265
1266        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
1267
1268        foreground = np.zeros(original_size, dtype="float32")
1269        center_distances = np.zeros(original_size, dtype="float32")
1270        boundary_distances = np.zeros(original_size, dtype="float32")
1271
1272        msg = "Initialize tiled instance segmentation with decoder"
1273        if tiles_in_mask is None:
1274            n_tiles = tiling.number_of_blocks
1275            all_tile_ids = list(range(n_tiles))
1276        else:
1277            n_tiles = len(tiles_in_mask)
1278            all_tile_ids = tiles_in_mask
1279            msg += " and mask"
1280
1281        n_batches = int(np.ceil(n_tiles / batch_size))  # NOTE(review): 0 if n_tiles == 0; np.array_split then raises.
1282        pbar_init(n_tiles, msg)
1283        tile_ids_for_batches = np.array_split(all_tile_ids, n_batches)
1284
1285        for tile_ids in tile_ids_for_batches:
1286            batched_embeddings, input_shapes, original_shapes = [], [], []
1287            for tile_id in tile_ids:
1288                # Get the image embeddings from the predictor for this tile.
1289                self._predictor = util.set_precomputed(self._predictor, self._image_embeddings, i=i, tile_id=tile_id)
1290
1291                batched_embeddings.append(self._predictor.features)
1292                input_shapes.append(tuple(self._predictor.input_size))
1293                original_shapes.append(tuple(self._predictor.original_size))
1294
1295            batched_output = self._predict_decoder(batched_embeddings, input_shapes, original_shapes)
1296
1297            for output_id, tile_id in enumerate(tile_ids):
1298                output = batched_output[output_id]
1299                assert output.shape[0] == 3  # One channel each: foreground, center and boundary distances.
1300
1301                # Set the predictions in the output for this tile.
1302                block = tiling.get_block_with_halo(tile_id, halo=list(halo))
1303                local_bb = tuple(
1304                    slice(beg, end) for beg, end in zip(block.inner_block_local.begin, block.inner_block_local.end)
1305                )
1306                inner_bb = tuple(slice(beg, end) for beg, end in zip(block.inner_block.begin, block.inner_block.end))
1307
1308                foreground[inner_bb] = output[0][local_bb]
1309                center_distances[inner_bb] = output[1][local_bb]
1310                boundary_distances[inner_bb] = output[2][local_bb]
1311                pbar_update(1)
1312
1313        pbar_close()
1314
1315        # Set the state.
1316        self._i = i
1317        self._foreground = foreground
1318        self._center_distances = center_distances
1319        self._boundary_distances = boundary_distances
1320        self._is_initialized = True

Same as InstanceSegmentationWithDecoder but for tiled image embeddings.

@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = False, pbar_init: Optional[callable] = None, pbar_update: Optional[callable] = None, batch_size: int = 1, mask: Optional[ArrayLike] = None) -> None:
1227    @torch.no_grad()
1228    def initialize(
1229        self,
1230        image: np.ndarray,
1231        image_embeddings: Optional[util.ImageEmbeddings] = None,
1232        i: Optional[int] = None,
1233        tile_shape: Optional[Tuple[int, int]] = None,
1234        halo: Optional[Tuple[int, int]] = None,
1235        verbose: bool = False,
1236        pbar_init: Optional[callable] = None,
1237        pbar_update: Optional[callable] = None,
1238        batch_size: int = 1,
1239        mask: Optional[np.typing.ArrayLike] = None,
1240    ) -> None:
1241        """Initialize image embeddings and decoder predictions for an image.
1242
1243        Args:
1244            image: The input image, volume or timeseries.
1245            image_embeddings: Optional precomputed image embeddings.
1246                See `util.precompute_image_embeddings` for details.
1247            i: Index for the image data. Required if `image` has three spatial dimensions
1248                or a time dimension and two spatial dimensions.
1249            tile_shape: Shape of the tiles for precomputing image embeddings.
1250            halo: Overlap of the tiles for tiled precomputation of image embeddings.
1251            verbose: Whether to be verbose. By default, set to 'False'.
1252            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
1253                Can be used together with `pbar_update` to drive a napari progress bar from another thread,
1254                which enables using this function within a thread worker.
1255            pbar_update: Callback to update an external progress bar.
1256            batch_size: The batch size for image embedding computation and segmentation decoder prediction.
1257            mask: An optional mask to define areas that are ignored in the segmentation.
1258        """
1259        original_size = image.shape[:2]  # The first two axes of the image are used as the spatial shape.
1260        self._image_embeddings, tile_shape, halo, tiles_in_mask = _process_tiled_embeddings(
1261            self._predictor, image, image_embeddings, tile_shape, halo,
1262            verbose=verbose, batch_size=batch_size, mask=mask, i=i,
1263        )
1264        tiling = Blocking([0, 0], original_size, tile_shape)
1265
1266        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
1267
1268        foreground = np.zeros(original_size, dtype="float32")
1269        center_distances = np.zeros(original_size, dtype="float32")
1270        boundary_distances = np.zeros(original_size, dtype="float32")
1271
1272        msg = "Initialize tiled instance segmentation with decoder"
1273        if tiles_in_mask is None:
1274            n_tiles = tiling.number_of_blocks
1275            all_tile_ids = list(range(n_tiles))
1276        else:
1277            n_tiles = len(tiles_in_mask)
1278            all_tile_ids = tiles_in_mask
1279            msg += " and mask"
1280
1281        n_batches = int(np.ceil(n_tiles / batch_size))  # NOTE(review): 0 if n_tiles == 0; np.array_split then raises.
1282        pbar_init(n_tiles, msg)
1283        tile_ids_for_batches = np.array_split(all_tile_ids, n_batches)
1284
1285        for tile_ids in tile_ids_for_batches:
1286            batched_embeddings, input_shapes, original_shapes = [], [], []
1287            for tile_id in tile_ids:
1288                # Get the image embeddings from the predictor for this tile.
1289                self._predictor = util.set_precomputed(self._predictor, self._image_embeddings, i=i, tile_id=tile_id)
1290
1291                batched_embeddings.append(self._predictor.features)
1292                input_shapes.append(tuple(self._predictor.input_size))
1293                original_shapes.append(tuple(self._predictor.original_size))
1294
1295            batched_output = self._predict_decoder(batched_embeddings, input_shapes, original_shapes)
1296
1297            for output_id, tile_id in enumerate(tile_ids):
1298                output = batched_output[output_id]
1299                assert output.shape[0] == 3  # One channel each: foreground, center and boundary distances.
1300
1301                # Set the predictions in the output for this tile.
1302                block = tiling.get_block_with_halo(tile_id, halo=list(halo))
1303                local_bb = tuple(
1304                    slice(beg, end) for beg, end in zip(block.inner_block_local.begin, block.inner_block_local.end)
1305                )
1306                inner_bb = tuple(slice(beg, end) for beg, end in zip(block.inner_block.begin, block.inner_block.end))
1307
1308                foreground[inner_bb] = output[0][local_bb]
1309                center_distances[inner_bb] = output[1][local_bb]
1310                boundary_distances[inner_bb] = output[2][local_bb]
1311                pbar_update(1)
1312
1313        pbar_close()
1314
1315        # Set the state.
1316        self._i = i
1317        self._foreground = foreground
1318        self._center_distances = center_distances
1319        self._boundary_distances = boundary_distances
1320        self._is_initialized = True

Initialize image embeddings and decoder predictions for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • tile_shape: Shape of the tiles for precomputing image embeddings.
  • halo: Overlap of the tiles for tiled precomputation of image embeddings.
  • verbose: Whether to be verbose. By default, set to 'False'.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to drive a napari progress bar from another thread, which enables using this function within a thread worker.
  • pbar_update: Callback to update an external progress bar.
  • batch_size: The batch size for image embedding computation and segmentation decoder prediction.
  • mask: An optional mask to define areas that are ignored in the segmentation.
class AutomaticPromptGenerator(InstanceSegmentationWithDecoder):
1395class AutomaticPromptGenerator(InstanceSegmentationWithDecoder):
1396    """Generates an instance segmentation automatically, using automatically generated prompts from a decoder.
1397
1398    This class is used in the same way as `InstanceSegmentationWithDecoder` and `AutomaticMaskGenerator`.
1399
1400    Args:
1401        predictor: The segment anything predictor.
1402        decoder: The decoder used to derive prompts for automatic instance segmentation.
1403    """
1404    def generate(
1405        self,
1406        min_size: int = 25,
1407        center_distance_threshold: float = 0.5,
1408        boundary_distance_threshold: float = 0.5,
1409        foreground_threshold: float = 0.5,
1410        multimasking: bool = False,
1411        batch_size: int = 32,
1412        nms_threshold: float = 0.9,
1413        intersection_over_min: bool = False,
1414        output_mode: str = "instance_segmentation",
1415        mask_threshold: Optional[Union[float, str]] = None,
1416        refine_with_box_prompts: bool = False,
1417        prompt_function: Optional[callable] = None,
1418    ) -> Union[List[Dict[str, Any]], np.ndarray]:
1419        """Generate the instance segmentation for the currently initialized image.
1420
1421        The instance segmentation is generated by deriving prompts from the foreground and
1422        distance predictions of the segmentation decoder by thresholding these predictions (see the
1423        three threshold arguments), intersecting them, computing connected components, and then using the
1424        component's centers as point prompts. The masks are then filtered via NMS and merged into a segmentation.
1425
1426        Args:
1427            min_size: Minimal object size in the segmentation result. By default, set to '25'.
1428            center_distance_threshold: The threshold for the center distance predictions.
1429            boundary_distance_threshold: The threshold for the boundary distance predictions.
1430            multimasking: Whether to use multi-mask prediction for turning the prompts into masks.
1431            batch_size: The batch size for parallelizing the prediction based on prompts.
1432            nms_threshold: The threshold for non-maximum suppression (NMS).
1433            intersection_over_min: Whether to use the minimum area of the two objects or the
1434                intersection over union area (default) in NMS.
1435            output_mode: The form masks are returned in. Possible values are:
1436                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
1437                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
1438                By default, set to 'instance_segmentation'.
1439            mask_threshold: The threshold for turning logits into masks in `micro_sam.inference.batched_inference`.
1440            refine_with_box_prompts: Whether to refine the mask outputs with another round of box prompts
1441                derived from the segmentations after point prompts.
1442            prompt_function: A custom function for deriving prompts from the segmentation decoder predictions.
1443                If given, the default prompt derivation procedure is not used. Must have the following signature:
1444                ```
1445                    def prompt_function(foreground, center_distances, boundary_distances, **kwargs)
1446                ```
1447                where `foreground`, `center_distances`, and `boundary_distances` are the respective
1448                predictions from the segmentation decoder. It must return a dictionary containing
1449                either point, box, or mask prompts in a format compatible with `micro_sam.inference.batched_inference`.
1450
1451        Returns:
1452            The instance segmentation masks. Empty if no prompts could be derived.
1453        """
1454        if not self.is_initialized:
1455            raise RuntimeError("AutomaticPromptGenerator has not been initialized. Call initialize first.")
1456        foreground, center_distances, boundary_distances =\
1457            self._foreground, self._center_distances, self._boundary_distances
1458
1459        # 1.) Derive prompts from the decoder predictions.
1460        prompt_function = _derive_point_prompts if prompt_function is None else prompt_function
1461        prompts = prompt_function(
1462            foreground=foreground,
1463            center_distances=center_distances,
1464            boundary_distances=boundary_distances,
1465            foreground_threshold=foreground_threshold,
1466            center_distance_threshold=center_distance_threshold,
1467            boundary_distance_threshold=boundary_distance_threshold,
1468        )
1469
1470        # 2.) Apply the predictor to the prompts.
1471        if prompts is None:  # No prompts were derived, we can't do much further and return empty masks.
1472            return np.zeros(foreground.shape, dtype="uint32") if output_mode == "instance_segmentation" else []
1473        else:
1474            predictions = batched_inference(
1475                self._predictor,
1476                image=None,
1477                batch_size=batch_size,
1478                return_instance_segmentation=False,
1479                multimasking=multimasking,
1480                mask_threshold=mask_threshold,
1481                i=getattr(self, "_i", None),
1482                **prompts,
1483            )
1484
1485        # 3.) Refine the segmentation with box prompts.
1486        if refine_with_box_prompts:
1487            box_extension = 0.01  # expose as hyperparam?
1488            prompts = _derive_box_prompts(predictions, box_extension)
1489            predictions = batched_inference(
1490                self._predictor,
1491                image=None,
1492                batch_size=batch_size,
1493                return_instance_segmentation=False,
1494                multimasking=multimasking,
1495                mask_threshold=mask_threshold,
1496                i=getattr(self, "_i", None),
1497                **prompts,
1498            )
1499
1500        # 4.) Apply non-max suppression to the masks.
1501        segmentation = util.apply_nms(
1502            predictions, min_size=min_size, nms_thresh=nms_threshold, intersection_over_min=intersection_over_min
1503        )
1504        if output_mode != "instance_segmentation":
1505            segmentation = self._to_masks(segmentation, output_mode)
1506        return segmentation

Generates an instance segmentation automatically, using automatically generated prompts from a decoder.

This class is used in the same way as InstanceSegmentationWithDecoder and AutomaticMaskGenerator.

Arguments:
  • predictor: The segment anything predictor.
  • decoder: The decoder used to derive prompts for automatic instance segmentation.
def generate( self, min_size: int = 25, center_distance_threshold: float = 0.5, boundary_distance_threshold: float = 0.5, foreground_threshold: float = 0.5, multimasking: bool = False, batch_size: int = 32, nms_threshold: float = 0.9, intersection_over_min: bool = False, output_mode: str = 'instance_segmentation', mask_threshold: Union[float, str, NoneType] = None, refine_with_box_prompts: bool = False, prompt_function: Optional[<built-in function callable>] = None) -> Union[List[Dict[str, Any]], numpy.ndarray]:
1404    def generate(
1405        self,
1406        min_size: int = 25,
1407        center_distance_threshold: float = 0.5,
1408        boundary_distance_threshold: float = 0.5,
1409        foreground_threshold: float = 0.5,
1410        multimasking: bool = False,
1411        batch_size: int = 32,
1412        nms_threshold: float = 0.9,
1413        intersection_over_min: bool = False,
1414        output_mode: str = "instance_segmentation",
1415        mask_threshold: Optional[Union[float, str]] = None,
1416        refine_with_box_prompts: bool = False,
1417        prompt_function: Optional[callable] = None,
1418    ) -> Union[List[Dict[str, Any]], np.ndarray]:
1419        """Generate the instance segmentation for the currently initialized image.
1420
1421        The instance segmentation is generated by deriving prompts from the foreground and
1422        distance predictions of the segmentation decoder by thresholding these predictions,
1423        intersecting them, computing connected components, and then using the component's
1424        centers as point prompts. The masks are then filtered via NMS and merged into a segmentation.
1425
1426        Args:
1427            min_size: Minimal object size in the segmentation result. By default, set to '25'.
1428            center_distance_threshold: The threshold for the center distance predictions.
1429            boundary_distance_threshold: The threshold for the boundary distance predictions.
1430            multimasking: Whether to use multi-mask prediction for turning the prompts into masks.
1431            batch_size: The batch size for parallelizing the prediction based on prompts.
1432            nms_threshold: The threshold for non-maximum suppression (NMS).
1433            intersection_over_min: Whether to use the minimum area of the two objects or the
1434                intersection over union area (default) in NMS.
1435            output_mode: The form masks are returned in. Possible values are:
1436                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
1437                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
1438                By default, set to 'instance_segmentation'.
1439            mask_threshold: The threshold for turning logits into masks in `micro_sam.inference.batched_inference`.`
1440            refine_with_box_prompts: Whether to refine the mask outputs with another round of box promtps
1441                derived from the segmentations after point prompts.
1442            prompt_function: A custom function for deriving prompts from the segmentation decoder predictions.
1443                If given, the default prompt derivation procedure is not used. Must have the following signature:
1444                ```
1445                    def prompt_function(foreground, center_distances, boundary_distances, **kwargs)
1446                ```
1447                where `foreground`, `center_distances`, and `boundary_distances` are the respective
1448                predictions from the segmentation decoder. It must returns a dictionary containing
1449                either point, box, or mask prompts in a format compattible with `micro_sam.inference.batched_inference`.
1450
1451        Returns:
1452            The instance segmentation masks.
1453        """
1454        if not self.is_initialized:
1455            raise RuntimeError("AutomaticPromptGenerator has not been initialized. Call initialize first.")
1456        foreground, center_distances, boundary_distances =\
1457            self._foreground, self._center_distances, self._boundary_distances
1458
1459        # 1.) Derive promtps from the decoder predictions.
1460        prompt_function = _derive_point_prompts if prompt_function is None else prompt_function
1461        prompts = prompt_function(
1462            foreground=foreground,
1463            center_distances=center_distances,
1464            boundary_distances=boundary_distances,
1465            foreground_threshold=foreground_threshold,
1466            center_distance_threshold=center_distance_threshold,
1467            boundary_distance_threshold=boundary_distance_threshold,
1468        )
1469
1470        # 2.) Apply the predictor to the prompts.
1471        if prompts is None:  # No prompts were derived, we can't do much further and return empty masks.
1472            return np.zeros(foreground.shape, dtype="uint32") if output_mode == "instance_segmentation" else []
1473        else:
1474            predictions = batched_inference(
1475                self._predictor,
1476                image=None,
1477                batch_size=batch_size,
1478                return_instance_segmentation=False,
1479                multimasking=multimasking,
1480                mask_threshold=mask_threshold,
1481                i=getattr(self, "_i", None),
1482                **prompts,
1483            )
1484
1485        # 3.) Refine the segmentation with box prompts.
1486        if refine_with_box_prompts:
1487            box_extension = 0.01  # expose as hyperparam?
1488            prompts = _derive_box_prompts(predictions, box_extension)
1489            predictions = batched_inference(
1490                self._predictor,
1491                image=None,
1492                batch_size=batch_size,
1493                return_instance_segmentation=False,
1494                multimasking=multimasking,
1495                mask_threshold=mask_threshold,
1496                i=getattr(self, "_i", None),
1497                **prompts,
1498            )
1499
1500        # 4.) Apply non-max suppression to the masks.
1501        segmentation = util.apply_nms(
1502            predictions, min_size=min_size, nms_thresh=nms_threshold, intersection_over_min=intersection_over_min
1503        )
1504        if output_mode != "instance_segmentation":
1505            segmentation = self._to_masks(segmentation, output_mode)
1506        return segmentation

Generate the instance segmentation for the currently initialized image.

The instance segmentation is generated by deriving prompts from the foreground and distance predictions of the segmentation decoder by thresholding these predictions, intersecting them, computing connected components, and then using the component's centers as point prompts. The masks are then filtered via NMS and merged into a segmentation.

Arguments:
  • min_size: Minimal object size in the segmentation result. By default, set to '25'.
  • center_distance_threshold: The threshold for the center distance predictions.
  • boundary_distance_threshold: The threshold for the boundary distance predictions.
  • foreground_threshold: The threshold for the foreground predictions.
  • multimasking: Whether to use multi-mask prediction for turning the prompts into masks.
  • batch_size: The batch size for parallelizing the prediction based on prompts.
  • nms_threshold: The threshold for non-maximum suppression (NMS).
  • intersection_over_min: Whether to use the minimum area of the two objects or the intersection over union area (default) in NMS.
  • output_mode: The form masks are returned in. Possible values are:
    • 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
    • 'instance_segmentation': Return masks merged into an instance segmentation in a single array. By default, set to 'instance_segmentation'.
  • mask_threshold: The threshold for turning logits into masks in micro_sam.inference.batched_inference.
  • refine_with_box_prompts: Whether to refine the mask outputs with another round of box prompts derived from the segmentations after point prompts.
  • prompt_function: A custom function for deriving prompts from the segmentation decoder predictions. If given, the default prompt derivation procedure is not used. Must have the following signature:

        def prompt_function(foreground, center_distances, boundary_distances, **kwargs)
    

    where foreground, center_distances, and boundary_distances are the respective predictions from the segmentation decoder. It must return a dictionary containing either point, box, or mask prompts in a format compatible with micro_sam.inference.batched_inference.

Returns:

The instance segmentation masks.

class TiledAutomaticPromptGenerator(TiledInstanceSegmentationWithDecoder):
class TiledAutomaticPromptGenerator(TiledInstanceSegmentationWithDecoder):
    """Same as `AutomaticPromptGenerator` but for tiled image embeddings.
    """
    def generate(
        self,
        min_size: int = 25,
        center_distance_threshold: float = 0.5,
        boundary_distance_threshold: float = 0.5,
        foreground_threshold: float = 0.5,
        multimasking: bool = False,
        batch_size: int = 32,
        nms_threshold: float = 0.9,
        intersection_over_min: bool = False,
        output_mode: str = "instance_segmentation",
        mask_threshold: Optional[Union[float, str]] = None,
        refine_with_box_prompts: bool = False,
        prompt_function: Optional[callable] = None,
        optimize_memory: bool = False,
    ) -> Union[List[Dict[str, Any]], np.ndarray]:
        """Generate tiling-based instance segmentation for the currently initialized image.

        Args:
            min_size: Minimal object size in the segmentation result. By default, set to '25'.
            center_distance_threshold: The threshold for the center distance predictions.
            boundary_distance_threshold: The threshold for the boundary distance predictions.
            foreground_threshold: The threshold for the foreground predictions.
            multimasking: Whether to use multi-mask prediction for turning the prompts into masks.
            batch_size: The batch size for parallelizing the prediction based on prompts.
            nms_threshold: The threshold for non-maximum suppression (NMS).
            intersection_over_min: Whether to use the minimum area of the two objects or the
                intersection over union area (default) in NMS.
            output_mode: The form masks are returned in. Possible values are:
                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
                By default, set to 'instance_segmentation'.
            mask_threshold: The threshold for turning logits into masks in `micro_sam.inference.batched_inference`.
                NOTE(review): This value is currently not forwarded to `batched_tiled_inference`;
                confirm whether that is intended.
            refine_with_box_prompts: Whether to refine the mask outputs with another round of box prompts
                derived from the segmentations after point prompts. Currently not supported for tiled segmentation.
            prompt_function: A custom function for deriving prompts from the segmentation decoder predictions.
                If given, the default prompt derivation procedure is not used. Must have the following signature:
                ```
                    def prompt_function(foreground, center_distances, boundary_distances, **kwargs)
                ```
                where `foreground`, `center_distances`, and `boundary_distances` are the respective
                predictions from the segmentation decoder. It must return a dictionary containing
                either point, box, or mask prompts in a format compatible with `micro_sam.inference.batched_inference`.
                Returning `None` or an empty dictionary signals that no prompts could be derived.
            optimize_memory: Whether to optimize the memory consumption by merging the per-slice
                segmentation results immediately with NMS, rather than running an NMS for all results.
                This may lead to a slightly different segmentation result and is only compatible with
                `refine_with_box_prompts=False` and `output_mode="instance_segmentation"`.

        Returns:
            The instance segmentation masks. For `output_mode="instance_segmentation"` this is a single array,
            for any other output mode the result of `self._to_masks`. If no prompts could be derived,
            an all-zero 'uint32' array with the shape of the foreground prediction is returned for
            `output_mode="instance_segmentation"`, otherwise an empty list.

        Raises:
            RuntimeError: If the generator has not been initialized.
            ValueError: If `optimize_memory` is combined with an incompatible setting.
            NotImplementedError: If `refine_with_box_prompts` is set to True.
        """
        if not self.is_initialized:
            raise RuntimeError("TiledAutomaticPromptGenerator has not been initialized. Call initialize first.")
        if optimize_memory and (output_mode != "instance_segmentation" or refine_with_box_prompts):
            raise ValueError(
                "optimize_memory=True is only compatible with output_mode='instance_segmentation' "
                f"and refine_with_box_prompts=False, got output_mode='{output_mode}' "
                f"and refine_with_box_prompts={refine_with_box_prompts}."
            )
        # TODO: Refinement with box prompts is not implemented for tiled segmentation.
        # We fail here, before running the (expensive) tiled inference, rather than afterwards.
        if refine_with_box_prompts:
            raise NotImplementedError("refine_with_box_prompts is not yet supported for tiled segmentation.")

        foreground = self._foreground
        center_distances = self._center_distances
        boundary_distances = self._boundary_distances

        # 1.) Derive prompts from the decoder predictions.
        if prompt_function is None:
            prompt_function = _derive_point_prompts
        prompts = prompt_function(
            foreground,
            center_distances,
            boundary_distances,
            foreground_threshold=foreground_threshold,
            center_distance_threshold=center_distance_threshold,
            boundary_distance_threshold=boundary_distance_threshold,
        )

        # 2.) Apply the predictor to the prompts.
        shape = foreground.shape
        # Without any prompts (None or an empty dictionary) there is nothing to segment,
        # so we return an empty result in the requested output format.
        if not prompts:
            if output_mode == "instance_segmentation":
                return np.zeros(shape, dtype="uint32")
            return []

        # Work on a copy, so that the dictionary returned by a custom prompt function is not modified.
        inference_kwargs = dict(prompts)
        if optimize_memory:
            # In this mode the NMS settings are handed to the tiled inference.
            inference_kwargs.update(
                min_size=min_size, nms_thresh=nms_threshold, intersection_over_min=intersection_over_min
            )
        # NOTE(review): mask_threshold is not passed on here - confirm whether batched_tiled_inference supports it.
        predictions = batched_tiled_inference(
            self._predictor,
            image=None,
            batch_size=batch_size,
            image_embeddings=self._image_embeddings,
            return_instance_segmentation=False,
            multimasking=multimasking,
            optimize_memory=optimize_memory,
            i=getattr(self, "_i", None),
            **inference_kwargs
        )
        # Optimize memory directly returns an instance segmentation and does not
        # allow for any further refinements.
        if optimize_memory:
            return predictions

        # 3.) Apply non-max suppression to the masks.
        segmentation = util.apply_nms(
            predictions, shape=shape, min_size=min_size, nms_thresh=nms_threshold,
            intersection_over_min=intersection_over_min,
        )
        if output_mode != "instance_segmentation":
            segmentation = self._to_masks(segmentation, output_mode)
        return segmentation

    # Set state and get state are not implemented yet, as this generator relies on having the image embeddings
    # in the state. However, they should not be serialized here and we have to address this a bit differently.
    def get_state(self):
        """@private
        """
        raise NotImplementedError

    def set_state(self, state):
        """@private
        """
        raise NotImplementedError

Same as AutomaticPromptGenerator but for tiled image embeddings.

def generate( self, min_size: int = 25, center_distance_threshold: float = 0.5, boundary_distance_threshold: float = 0.5, foreground_threshold: float = 0.5, multimasking: bool = False, batch_size: int = 32, nms_threshold: float = 0.9, intersection_over_min: bool = False, output_mode: str = 'instance_segmentation', mask_threshold: Union[float, str, NoneType] = None, refine_with_box_prompts: bool = False, prompt_function: Optional[<built-in function callable>] = None, optimize_memory: bool = False) -> List[Dict[str, Any]]:
1512    def generate(
1513        self,
1514        min_size: int = 25,
1515        center_distance_threshold: float = 0.5,
1516        boundary_distance_threshold: float = 0.5,
1517        foreground_threshold: float = 0.5,
1518        multimasking: bool = False,
1519        batch_size: int = 32,
1520        nms_threshold: float = 0.9,
1521        intersection_over_min: bool = False,
1522        output_mode: str = "instance_segmentation",
1523        mask_threshold: Optional[Union[float, str]] = None,
1524        refine_with_box_prompts: bool = False,
1525        prompt_function: Optional[callable] = None,
1526        optimize_memory: bool = False,
1527    ) -> List[Dict[str, Any]]:
1528        """Generate tiling-based instance segmentation for the currently initialized image.
1529
1530        Args:
1531            min_size: Minimal object size in the segmentation result. By default, set to '25'.
1532            center_distance_threshold: The threshold for the center distance predictions.
1533            boundary_distance_threshold: The threshold for the boundary distance predictions.
1534            multimasking: Whether to use multi-mask prediction for turning the prompts into masks.
1535            batch_size: The batch size for parallelizing the prediction based on prompts.
1536            nms_threshold: The threshold for non-maximum suppression (NMS).
1537            intersection_over_min: Whether to use the minimum area of the two objects or the
1538                intersection over union area (default) in NMS.
1539            output_mode: The form masks are returned in. Possible values are:
1540                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
1541                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
1542                By default, set to 'instance_segmentation'.
1543            mask_threshold: The threshold for turining logits into masks in `micro_sam.inference.batched_inference`.`
1544            refine_with_box_prompts: Whether to refine the mask outputs with another round of box promtps
1545                derived from the segmentations after point prompts. Currently not supported for tiled segmentation.
1546            prompt_function: A custom function for deriving prompts from the segmentation decoder predictions.
1547                If given, the default prompt derivation procedure is not used. Must have the following signature:
1548                ```
1549                    def prompt_function(foreground, center_distances, boundary_distances, **kwargs)
1550                ```
1551                where `foreground`, `center_distances`, and `boundary_distances` are the respective
1552                predictions from the segmentation decoder. It must returns a dictionary containing
1553                either point, box, or mask prompts in a format compattible with `micro_sam.inference.batched_inference`.
1554            optimize_memory: Whether to optimize the memory consumption by merging the per-slice
1555                segmentation results immediatly with NMS, rather than running a NMS for all results.
1556                This may lead to a slightly different segmentation result and is only compatible with
1557                `refine_with_box_prompts=False` and `output_mode="instance_segmentation"`.
1558
1559        Returns:
1560            The instance segmentation masks.
1561        """
1562        if not self.is_initialized:
1563            raise RuntimeError("TiledAutomaticPromptGenerator has not been initialized. Call initialize first.")
1564        if optimize_memory and (output_mode != "instance_segmentation" or refine_with_box_prompts):
1565            raise ValueError("Invalid settings")
1566        foreground, center_distances, boundary_distances =\
1567            self._foreground, self._center_distances, self._boundary_distances
1568
1569        # 1.) Derive promtps from the decoder predictions.
1570        prompt_function = _derive_point_prompts if prompt_function is None else prompt_function
1571        prompts = prompt_function(
1572            foreground,
1573            center_distances,
1574            boundary_distances,
1575            foreground_threshold=foreground_threshold,
1576            center_distance_threshold=center_distance_threshold,
1577            boundary_distance_threshold=boundary_distance_threshold,
1578        )
1579
1580        # 2.) Apply the predictor to the prompts.
1581        shape = foreground.shape
1582        if prompts is None:  # No prompts were derived, we can't do much further and return empty masks.
1583            return np.zeros(shape, dtype="uint32") if output_mode == "instance_segmentation" else []
1584        else:
1585            if optimize_memory:
1586                prompts.update(dict(
1587                    min_size=min_size, nms_thresh=nms_threshold, intersection_over_min=intersection_over_min
1588                ))
1589            predictions = batched_tiled_inference(
1590                self._predictor,
1591                image=None,
1592                batch_size=batch_size,
1593                image_embeddings=self._image_embeddings,
1594                return_instance_segmentation=False,
1595                multimasking=multimasking,
1596                optimize_memory=optimize_memory,
1597                i=getattr(self, "_i", None),
1598                **prompts
1599            )
1600        # Optimize memory directly returns an instance segmentation and does not
1601        # allow for any further refinements.
1602        if optimize_memory:
1603            return predictions
1604
1605        # 3.) Refine the segmentation with box prompts.
1606        if refine_with_box_prompts:
1607            # TODO
1608            raise NotImplementedError
1609
1610        # 4.) Apply non-max suppression to the masks.
1611        segmentation = util.apply_nms(
1612            predictions, shape=shape, min_size=min_size, nms_thresh=nms_threshold,
1613            intersection_over_min=intersection_over_min,
1614        )
1615        if output_mode != "instance_segmentation":
1616            segmentation = self._to_masks(segmentation, output_mode)
1617        return segmentation

Generate tiling-based instance segmentation for the currently initialized image.

Arguments:
  • min_size: Minimal object size in the segmentation result. By default, set to '25'.
  • center_distance_threshold: The threshold for the center distance predictions.
  • boundary_distance_threshold: The threshold for the boundary distance predictions.
  • foreground_threshold: The threshold for the foreground predictions.
  • multimasking: Whether to use multi-mask prediction for turning the prompts into masks.
  • batch_size: The batch size for parallelizing the prediction based on prompts.
  • nms_threshold: The threshold for non-maximum suppression (NMS).
  • intersection_over_min: Whether to use the minimum area of the two objects or the intersection over union area (default) in NMS.
  • output_mode: The form masks are returned in. Possible values are:
    • 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
    • 'instance_segmentation': Return masks merged into an instance segmentation in a single array. By default, set to 'instance_segmentation'.
  • mask_threshold: The threshold for turning logits into masks in micro_sam.inference.batched_inference.
  • refine_with_box_prompts: Whether to refine the mask outputs with another round of box prompts derived from the segmentations after point prompts. Currently not supported for tiled segmentation.
  • prompt_function: A custom function for deriving prompts from the segmentation decoder predictions. If given, the default prompt derivation procedure is not used. Must have the following signature:

        def prompt_function(foreground, center_distances, boundary_distances, **kwargs)
    

    where foreground, center_distances, and boundary_distances are the respective predictions from the segmentation decoder. It must return a dictionary containing either point, box, or mask prompts in a format compatible with micro_sam.inference.batched_inference.

  • optimize_memory: Whether to optimize the memory consumption by merging the per-slice segmentation results immediately with NMS, rather than running an NMS for all results. This may lead to a slightly different segmentation result and is only compatible with refine_with_box_prompts=False and output_mode="instance_segmentation".
Returns:

The instance segmentation masks.

def get_instance_segmentation_generator( predictor: segment_anything.predictor.SamPredictor, is_tiled: bool, decoder: Optional[torch.nn.modules.module.Module] = None, segmentation_mode: Optional[Literal['amg', 'ais', 'apg']] = None, **kwargs) -> Union[AMGBase, InstanceSegmentationWithDecoder]:
def get_instance_segmentation_generator(
    predictor: SamPredictor,
    is_tiled: bool,
    decoder: Optional[torch.nn.Module] = None,
    segmentation_mode: Optional[Literal["amg", "ais", "apg"]] = None,
    **kwargs,
) -> Union[AMGBase, InstanceSegmentationWithDecoder]:
    """Get the automatic mask generator.

    Args:
        predictor: The segment anything predictor.
        is_tiled: Whether tiled embeddings are used.
        decoder: Decoder to predict instance segmentation.
        segmentation_mode: The segmentation mode. One of 'amg', 'ais', or 'apg' (case-insensitive).
            By default, the value of `DEFAULT_SEGMENTATION_MODE_WITH_DECODER` is used if a decoder is passed,
            otherwise 'amg' is used.
        kwargs: The keyword arguments of the segmentation generator class.

    Returns:
        The segmentation generator instance.

    Raises:
        ValueError: If the segmentation mode is invalid, or if the mode 'ais' or 'apg'
            is selected without passing a decoder.
    """
    # NOTE: The docstring above is deliberately a plain string. An f-string is not stored
    # as the function's docstring and would be evaluated again on every call.

    # Choose the segmentation decoder default depending on whether we have a decoder.
    if segmentation_mode is None:
        segmentation_mode = "amg" if decoder is None else DEFAULT_SEGMENTATION_MODE_WITH_DECODER

    mode = segmentation_mode.lower()
    if mode not in ("amg", "ais", "apg"):
        raise ValueError(f"Invalid segmentation_mode: {segmentation_mode}. Choose one of 'amg', 'ais', or 'apg'.")

    # 'amg' is the only mode that does not need a decoder.
    if mode == "amg":
        segmenter_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator
        return segmenter_class(predictor, **kwargs)

    # Raise explicitly instead of asserting, so that the check is also active when running with `-O`.
    if decoder is None:
        raise ValueError(f"The segmentation_mode '{mode}' requires a decoder, but no decoder was passed.")

    if mode == "ais":
        segmenter_class = TiledInstanceSegmentationWithDecoder if is_tiled else InstanceSegmentationWithDecoder
    else:  # mode == "apg"
        segmenter_class = TiledAutomaticPromptGenerator if is_tiled else AutomaticPromptGenerator
    return segmenter_class(predictor, decoder, **kwargs)
1671    return segmenter