micro_sam.instance_segmentation

Automated instance segmentation functionality. The classes implemented here extend the automatic instance segmentation from Segment Anything: https://computational-cell-analytics.github.io/micro-sam/micro_sam.html
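
The module provides two automatic segmentation approaches: `AutomaticMaskGenerator` (and its tiled variant), which segments with a point grid as in Segment Anything, and `InstanceSegmentationWithDecoder`, which uses an additional instance segmentation decoder. A minimal sketch of the decoder-based workflow is shown below; the model name 'vit_b_lm' and the input array are placeholders and assume a checkpoint whose state contains a decoder (see `get_predictor_and_decoder` in the source).

```python
import numpy as np
from micro_sam.instance_segmentation import (
    get_predictor_and_decoder, InstanceSegmentationWithDecoder,
)

# Load the SAM predictor and the instance segmentation decoder from a checkpoint
# that contains both states ('vit_b_lm' is only an example model name).
predictor, decoder = get_predictor_and_decoder(model_type="vit_b_lm")

image = np.random.rand(512, 512).astype("float32")  # placeholder for a real 2D image

segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image)  # Compute embeddings and decoder predictions (expensive).
# Generate the segmentation; this is cheap and can be re-run with different parameters.
segmentation = segmenter.generate(center_distance_threshold=0.5, min_size=50)
```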

   1"""Automated instance segmentation functionality.
   2The classes implemented here extend the automatic instance segmentation from Segment Anything:
   3https://computational-cell-analytics.github.io/micro-sam/micro_sam.html
   4"""
   5
   6import os
   7import warnings
   8from abc import ABC
   9from copy import deepcopy
  10from collections import OrderedDict
  11from typing import Any, Dict, List, Optional, Tuple, Union
  12
  13import vigra
  14import numpy as np
  15from skimage.measure import regionprops
  16from skimage.segmentation import find_boundaries
  17
  18import torch
  19from torchvision.ops.boxes import batched_nms, box_area
  20
  21from torch_em.model import UNETR
  22from torch_em.util.segmentation import watershed_from_center_and_boundary_distances
  23
  24import elf.parallel as parallel_impl
  25from elf.parallel.filters import apply_filter
  26
  27from nifty.tools import blocking
  28
  29import segment_anything.utils.amg as amg_utils
  30from segment_anything.predictor import SamPredictor
  31
  32from . import util
  33from .inference import batched_inference, batched_tiled_inference
  34from ._vendored import batched_mask_to_box, mask_to_rle_pytorch
  35
  36# We may change this to 'apg' in version 1.8.
  37DEFAULT_SEGMENTATION_MODE_WITH_DECODER = "ais"
  38
  39#
  40# Utility Functionality
  41#
  42
  43
  44class _FakeInput:
  45    def __init__(self, shape):
  46        self.shape = shape
  47
  48    def __getitem__(self, index):
  49        block_shape = tuple(ind.stop - ind.start for ind in index)
  50        return np.zeros(block_shape, dtype="float32")
  51
  52
  53#
  54# Classes for automatic instance segmentation
  55#
  56
  57
  58class AMGBase(ABC):
  59    """Base class for the automatic mask generators.
  60    """
  61    def __init__(self):
  62        # the state that has to be computed by the 'initialize' method of the child classes
  63        self._is_initialized = False
  64        self._crop_list = None
  65        self._crop_boxes = None
  66        self._original_size = None
  67
  68    @property
  69    def is_initialized(self):
  70        """Whether the mask generator has already been initialized.
  71        """
  72        return self._is_initialized
  73
  74    @property
  75    def crop_list(self):
  76        """The list of mask data after initialization.
  77        """
  78        return self._crop_list
  79
  80    @property
  81    def crop_boxes(self):
  82        """The list of crop boxes.
  83        """
  84        return self._crop_boxes
  85
  86    @property
  87    def original_size(self):
  88        """The original image size.
  89        """
  90        return self._original_size
  91
  92    def _postprocess_batch(
  93        self,
  94        data,
  95        crop_box,
  96        original_size,
  97        pred_iou_thresh,
  98        stability_score_thresh,
  99        box_nms_thresh,
 100    ):
 101        orig_h, orig_w = original_size
 102
 103        # filter by predicted IoU
 104        if pred_iou_thresh > 0.0:
 105            keep_mask = data["iou_preds"] > pred_iou_thresh
 106            data.filter(keep_mask)
 107
 108        # filter by stability score
 109        if stability_score_thresh > 0.0:
 110            keep_mask = data["stability_score"] >= stability_score_thresh
 111            data.filter(keep_mask)
 112
 113        # filter boxes that touch crop boundaries
 114        keep_mask = ~amg_utils.is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
 115        if not torch.all(keep_mask):
 116            data.filter(keep_mask)
 117
 118        # remove duplicates within this crop.
 119        keep_by_nms = batched_nms(
 120            data["boxes"].float(),
 121            data["iou_preds"],
 122            torch.zeros_like(data["boxes"][:, 0]),  # categories
 123            iou_threshold=box_nms_thresh,
 124        )
 125        data.filter(keep_by_nms)
 126
 127        # return to the original image frame
 128        data["boxes"] = amg_utils.uncrop_boxes_xyxy(data["boxes"], crop_box)
 129        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
 130        # the data from embedding based segmentation doesn't have the points
 131        # so we skip if the corresponding key can't be found
 132        try:
 133            data["points"] = amg_utils.uncrop_points(data["points"], crop_box)
 134        except KeyError:
 135            pass
 136
 137        return data
 138
 139    def _postprocess_small_regions(self, mask_data, min_area, nms_thresh):
 140
 141        if len(mask_data["rles"]) == 0:
 142            return mask_data
 143
 144        # filter small disconnected regions and holes
 145        new_masks = []
 146        scores = []
 147        for rle in mask_data["rles"]:
 148            mask = amg_utils.rle_to_mask(rle)
 149
 150            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="holes")
 151            unchanged = not changed
 152            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="islands")
 153            unchanged = unchanged and not changed
 154
 155            new_masks.append(torch.as_tensor(mask, dtype=torch.int).unsqueeze(0))
 156            # give score=0 to changed masks and score=1 to unchanged masks
 157            # so NMS will prefer ones that didn't need postprocessing
 158            scores.append(float(unchanged))
 159
 160        # recalculate boxes and remove any new duplicates
 161        masks = torch.cat(new_masks, dim=0)
 162        boxes = batched_mask_to_box(masks.to(torch.bool))  # Casting this to boolean as we work with one-hot labels.
 163        keep_by_nms = batched_nms(
 164            boxes.float(),
 165            torch.as_tensor(scores, dtype=torch.float),
 166            torch.zeros_like(boxes[:, 0]),  # categories
 167            iou_threshold=nms_thresh,
 168        )
 169
 170        # only recalculate RLEs for masks that have changed
 171        for i_mask in keep_by_nms:
 172            if scores[i_mask] == 0.0:
 173                mask_torch = masks[i_mask].unsqueeze(0)
 174                # mask_data["rles"][i_mask] = amg_utils.mask_to_rle_pytorch(mask_torch)[0]
 175                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
 176                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
 177        mask_data.filter(keep_by_nms)
 178
 179        return mask_data
 180
 181    def _postprocess_masks(self, mask_data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode):
 182        # filter small disconnected regions and holes in masks
 183        if min_mask_region_area > 0:
 184            mask_data = self._postprocess_small_regions(
 185                mask_data,
 186                min_mask_region_area,
 187                max(box_nms_thresh, crop_nms_thresh),
 188            )
 189
 190        # encode masks
 191        if output_mode == "coco_rle":
 192            mask_data["segmentations"] = [amg_utils.coco_encode_rle(rle) for rle in mask_data["rles"]]
 193        # We require the binary mask output for creating the instance-seg output.
 194        elif output_mode in ("binary_mask", "instance_segmentation"):
 195            mask_data["segmentations"] = [amg_utils.rle_to_mask(rle) for rle in mask_data["rles"]]
 196        elif output_mode == "rle":
 197            mask_data["segmentations"] = mask_data["rles"]
 198        else:
 199            raise ValueError(f"Invalid output mode {output_mode}.")
 200
 201        # write mask records
 202        curr_anns = []
 203        for idx in range(len(mask_data["segmentations"])):
 204            ann = {
 205                "segmentation": mask_data["segmentations"][idx],
 206                "area": amg_utils.area_from_rle(mask_data["rles"][idx]),
 207                "bbox": amg_utils.box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
 208                "predicted_iou": mask_data["iou_preds"][idx].item(),
 209                "stability_score": mask_data["stability_score"][idx].item(),
 210                "crop_box": amg_utils.box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
 211            }
 212            # the data from embedding based segmentation doesn't have the points
 213            # so we skip if the corresponding key can't be found
 214            try:
 215                ann["point_coords"] = [mask_data["points"][idx].tolist()]
 216            except KeyError:
 217                pass
 218            curr_anns.append(ann)
 219
 220        return curr_anns
 221
 222    def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None):
 223        orig_h, orig_w = original_size
 224
 225        # serialize predictions and store in MaskData
 226        data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1))
 227        if points is not None:
 228            data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0), dtype=torch.float)
 229
 230        del masks
 231
 232        # calculate the stability scores
 233        data["stability_score"] = amg_utils.calculate_stability_score(
 234            data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset
 235        )
 236
 237        # threshold masks and calculate boxes
 238        data["masks"] = data["masks"] > self._predictor.model.mask_threshold
 239        data["masks"] = data["masks"].type(torch.bool)
 240        data["boxes"] = batched_mask_to_box(data["masks"])
 241
 242        # compress to RLE
 243        data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
 244        # data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"])
 245        data["rles"] = mask_to_rle_pytorch(data["masks"])
 246        del data["masks"]
 247
 248        return data
 249
 250    def get_state(self) -> Dict[str, Any]:
 251        """Get the initialized state of the mask generator.
 252
 253        Returns:
 254            State of the mask generator.
 255        """
 256        if not self.is_initialized:
 257            raise RuntimeError("The state has not been computed yet. Call initialize first.")
 258
 259        return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}
 260
 261    def set_state(self, state: Dict[str, Any]) -> None:
 262        """Set the state of the mask generator.
 263
 264        Args:
 265            state: The state of the mask generator, e.g. from serialized state.
 266        """
 267        self._crop_list = state["crop_list"]
 268        self._crop_boxes = state["crop_boxes"]
 269        self._original_size = state["original_size"]
 270        self._is_initialized = True
 271
 272    def clear_state(self):
 273        """Clear the state of the mask generator.
 274        """
 275        self._crop_list = None
 276        self._crop_boxes = None
 277        self._original_size = None
 278        self._is_initialized = False
 279
 280
 281class AutomaticMaskGenerator(AMGBase):
 282    """Generates an instance segmentation without prompts, using a point grid.
 283
 284    This class implements the same logic as
 285    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
 286    It decouples the computationally expensive steps of generating masks from the cheap post-processing operation
 287    to filter these masks to enable grid search and interactively changing the post-processing.
 288
 289    Use this class as follows:
 290    ```python
 291    amg = AutomaticMaskGenerator(predictor)
 292    amg.initialize(image)  # Initialize the masks, this takes care of all expensive computations.
 293    masks = amg.generate(pred_iou_thresh=0.8)  # Generate the masks. This is fast and enables testing parameters
 294    ```
 295
 296    Args:
 297        predictor: The segment anything predictor.
 298        points_per_side: The number of points to be sampled along one side of the image.
 299            If None, `point_grids` must provide explicit point sampling. By default, set to '32'.
 300        points_per_batch: The number of points run simultaneously by the model.
 301            Higher numbers may be faster but use more GPU memory.
 302            By default, selects '64' for all devices except 'mps' (selects '16' for performance reasons).
 303        crop_n_layers: If >0, the mask prediction will be run again on crops of the image.
 304            By default, set to '0'.
 305        crop_overlap_ratio: Sets the degree to which crops overlap. By default, set to '512 / 1500'.
 306        crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops.
 307            By default, set to '1'.
 308        point_grids: A list over explicit grids of points used for sampling masks.
 309            Normalized to [0, 1] with respect to the image coordinate system.
 310        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
 311            By default, set to '1.0'.
 312    """
 313    def __init__(
 314        self,
 315        predictor: SamPredictor,
 316        points_per_side: Optional[int] = 32,
 317        points_per_batch: Optional[int] = None,
 318        crop_n_layers: int = 0,
 319        crop_overlap_ratio: float = 512 / 1500,
 320        crop_n_points_downscale_factor: int = 1,
 321        point_grids: Optional[List[np.ndarray]] = None,
 322        stability_score_offset: float = 1.0,
 323    ):
 324        super().__init__()
 325
 326        if points_per_side is not None:
 327            self.point_grids = amg_utils.build_all_layer_point_grids(
 328                points_per_side, crop_n_layers, crop_n_points_downscale_factor,
 329            )
 330        elif point_grids is not None:
 331            self.point_grids = point_grids
 332        else:
 333            raise ValueError("Exactly one of 'points_per_side' or 'point_grids' must be provided.")
 334
 335        self._predictor = predictor
 336        self._points_per_side = points_per_side
 337
 338        # we set the points per batch to 16 for mps for performance reasons
 339        # and otherwise keep them at the default of 64
 340        if points_per_batch is None:
 341            points_per_batch = 16 if str(predictor.device) == "mps" else 64
 342        self._points_per_batch = points_per_batch
 343
 344        self._crop_n_layers = crop_n_layers
 345        self._crop_overlap_ratio = crop_overlap_ratio
 346        self._crop_n_points_downscale_factor = crop_n_points_downscale_factor
 347        self._stability_score_offset = stability_score_offset
 348
 349    def _process_batch(self, points, im_size, crop_box, original_size):
 350        # run model on this batch
 351        transformed_points = self._predictor.transform.apply_coords(points, im_size)
 352        in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float)
 353        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
 354        masks, iou_preds, _ = self._predictor.predict_torch(
 355            point_coords=in_points[:, None, :],
 356            point_labels=in_labels[:, None],
 357            multimask_output=True,
 358            return_logits=True,
 359        )
 360        data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points)
 361        del masks
 362        return data
 363
 364    def _process_crop(self, image, crop_box, crop_layer_idx, precomputed_embeddings, pbar_init=None, pbar_update=None):
 365        # Crop the image and calculate embeddings.
 366        x0, y0, x1, y1 = crop_box
 367        cropped_im = image[y0:y1, x0:x1, :]
 368        cropped_im_size = cropped_im.shape[:2]
 369
 370        if not precomputed_embeddings:
 371            self._predictor.set_image(cropped_im)
 372
 373        # Get the points for this crop.
 374        points_scale = np.array(cropped_im_size)[None, ::-1]
 375        points_for_image = self.point_grids[crop_layer_idx] * points_scale
 376
 377        # Generate masks for this crop in batches.
 378        data = amg_utils.MaskData()
 379        n_batches = len(points_for_image) // self._points_per_batch +\
 380            int(len(points_for_image) % self._points_per_batch != 0)
 381        if pbar_init is not None:
 382            pbar_init(n_batches, "Predict masks for point grid prompts")
 383
 384        for (points,) in amg_utils.batch_iterator(self._points_per_batch, points_for_image):
 385            batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size)
 386            data.cat(batch_data)
 387            del batch_data
 388            if pbar_update is not None:
 389                pbar_update(1)
 390
 391        if not precomputed_embeddings:
 392            self._predictor.reset_image()
 393
 394        return data
 395
 396    @torch.no_grad()
 397    def initialize(
 398        self,
 399        image: np.ndarray,
 400        image_embeddings: Optional[util.ImageEmbeddings] = None,
 401        i: Optional[int] = None,
 402        verbose: bool = False,
 403        pbar_init: Optional[callable] = None,
 404        pbar_update: Optional[callable] = None,
 405    ) -> None:
 406        """Initialize image embeddings and masks for an image.
 407
 408        Args:
 409            image: The input image, volume or timeseries.
 410            image_embeddings: Optional precomputed image embeddings.
 411                See `util.precompute_image_embeddings` for details.
 412            i: Index for the image data. Required if `image` has three spatial dimensions
 413                or a time dimension and two spatial dimensions.
 414            verbose: Whether to print computation progress. By default, set to 'False'.
 415            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 416                Can be used together with pbar_update to handle a napari progress bar in another thread.
 417                This enables using this function within a threadworker.
 418            pbar_update: Callback to update an external progress bar.
 419        """
 420        original_size = image.shape[:2]
 421        self._original_size = original_size
 422
 423        crop_boxes, layer_idxs = amg_utils.generate_crop_boxes(
 424            original_size, self._crop_n_layers, self._crop_overlap_ratio
 425        )
 426
 427        # We can set fixed image embeddings if we only have a single crop box (the default setting).
 428        # Otherwise we have to recompute the embeddings for each crop and can't precompute.
 429        if len(crop_boxes) == 1:
 430            if image_embeddings is None:
 431                image_embeddings = util.precompute_image_embeddings(self._predictor, image)
 432            util.set_precomputed(self._predictor, image_embeddings, i=i)
 433            precomputed_embeddings = True
 434        else:
 435            precomputed_embeddings = False
 436
 437        # we need to cast to the image representation that is compatible with SAM
 438        image = util._to_image(image)
 439
 440        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
 441
 442        crop_list = []
 443        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
 444            crop_data = self._process_crop(
 445                image, crop_box, layer_idx,
 446                precomputed_embeddings=precomputed_embeddings,
 447                pbar_init=pbar_init, pbar_update=pbar_update,
 448            )
 449            crop_list.append(crop_data)
 450        pbar_close()
 451
 452        self._is_initialized = True
 453        self._crop_list = crop_list
 454        self._crop_boxes = crop_boxes
 455
 456    @torch.no_grad()
 457    def generate(
 458        self,
 459        pred_iou_thresh: float = 0.88,
 460        stability_score_thresh: float = 0.95,
 461        box_nms_thresh: float = 0.7,
 462        crop_nms_thresh: float = 0.7,
 463        min_mask_region_area: int = 0,
 464        output_mode: str = "instance_segmentation",
 465        with_background: bool = True,
 466    ) -> Union[List[Dict[str, Any]], np.ndarray]:
 467        """Generate instance segmentation for the currently initialized image.
 468
 469        Args:
 470            pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
 471                By default, set to '0.88'.
 472            stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
 473                under changes to the cutoff used to binarize the model prediction. By default, set to '0.95'.
 474            box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
 475                By default, set to '0.7'.
 476            crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
 477                By default, set to '0.7'.
 478            min_mask_region_area: Minimal size for the predicted masks. By default, set to '0'.
 479            output_mode: The form masks are returned in. Possible values are:
 480                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
 481                - 'coco_rle': Return a list of dictionaries with run-length encoded masks in MS COCO format.
 482                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
 483                - 'rle': Return a list of dictionaries with run-length encoded masks.
 484                By default, set to 'instance_segmentation'.
 485            with_background: Whether to remove the largest object, which often covers the background.
 486
 487        Returns:
 488            The segmentation masks.
 489        """
 490        if not self.is_initialized:
 491            raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.")
 492
 493        data = amg_utils.MaskData()
 494        for data_, crop_box in zip(self.crop_list, self.crop_boxes):
 495            crop_data = self._postprocess_batch(
 496                data=deepcopy(data_),
 497                crop_box=crop_box, original_size=self.original_size,
 498                pred_iou_thresh=pred_iou_thresh,
 499                stability_score_thresh=stability_score_thresh,
 500                box_nms_thresh=box_nms_thresh
 501            )
 502            data.cat(crop_data)
 503
 504        if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0:
 505            # Prefer masks from smaller crops
 506            scores = 1 / box_area(data["crop_boxes"])
 507            scores = scores.to(data["boxes"].device)
 508            keep_by_nms = batched_nms(
 509                data["boxes"].float(),
 510                scores,
 511                torch.zeros_like(data["boxes"][:, 0]),  # categories
 512                iou_threshold=crop_nms_thresh,
 513            )
 514            data.filter(keep_by_nms)
 515
 516        data.to_numpy()
 517        masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode)
 518        if output_mode == "instance_segmentation":
 519            shape = next(iter(masks))["segmentation"].shape if len(masks) > 0 else self.original_size
 520            masks = util.mask_data_to_segmentation(
 521                masks, shape=shape, with_background=with_background, merge_exclusively=False
 522            )
 523        return masks
 524
 525
 526# Helper function for tiled embedding computation and checking consistent state.
 527def _process_tiled_embeddings(predictor, image, image_embeddings, tile_shape, halo, verbose, batch_size, mask, i):
 528    if image_embeddings is None:
 529        if tile_shape is None or halo is None:
 530            raise ValueError("To compute tiled embeddings the parameters tile_shape and halo have to be passed.")
 531        image_embeddings = util.precompute_image_embeddings(
 532            predictor, image, tile_shape=tile_shape, halo=halo, verbose=verbose, batch_size=batch_size, mask=mask,
 533        )
 534
 535    # Use tile shape and halo from the precomputed embeddings if not given.
 536    # Otherwise check that they are consistent.
 537    feats = image_embeddings["features"]
 538    tile_shape_, halo_ = tuple(feats.attrs["tile_shape"]), tuple(feats.attrs["halo"])
 539    if tile_shape is None:
 540        tile_shape = tile_shape_
 541    elif tile_shape != tile_shape_:
 542        raise ValueError(
 543            f"Inconsistent tile_shape parameter {tile_shape} with precomputed embeddings: {tile_shape_}."
 544        )
 545    if halo is None:
 546        halo = halo_
 547    elif halo != halo_:
 548        raise ValueError(f"Inconsistent halo parameter {halo} with precomputed embeddings: {halo_}.")
 549
 550    tiles_in_mask = feats.attrs.get("tiles_in_mask", None)
 551    if tiles_in_mask is not None and i is not None:
 552        tiles_in_mask = tiles_in_mask[str(i)]
 553
 554    return image_embeddings, tile_shape, halo, tiles_in_mask
 555
 556
 557class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
 558    """Generates an instance segmentation without prompts, using a point grid.
 559
 560    Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings.
 561
 562    Args:
 563        predictor: The Segment Anything predictor.
 564        points_per_side: The number of points to be sampled along one side of the image.
 565            If None, `point_grids` must provide explicit point sampling. By default, set to '32'.
 566        points_per_batch: The number of points run simultaneously by the model.
 567            Higher numbers may be faster but use more GPU memory. By default, set to '64'.
 568        point_grids: A list over explicit grids of points used for sampling masks.
 569            Normalized to [0, 1] with respect to the image coordinate system.
 570        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
 571            By default, set to '1.0'.
 572    """
 573
 574    # We only expose the arguments that make sense for the tiled mask generator.
 575    # Anything related to crops doesn't make sense, because we re-use that functionality
 576    # for tiling, so these parameters wouldn't have any effect.
 577    def __init__(
 578        self,
 579        predictor: SamPredictor,
 580        points_per_side: Optional[int] = 32,
 581        points_per_batch: int = 64,
 582        point_grids: Optional[List[np.ndarray]] = None,
 583        stability_score_offset: float = 1.0,
 584    ) -> None:
 585        super().__init__(
 586            predictor=predictor,
 587            points_per_side=points_per_side,
 588            points_per_batch=points_per_batch,
 589            point_grids=point_grids,
 590            stability_score_offset=stability_score_offset,
 591        )
 592
 593    @torch.no_grad()
 594    def initialize(
 595        self,
 596        image: np.ndarray,
 597        image_embeddings: Optional[util.ImageEmbeddings] = None,
 598        i: Optional[int] = None,
 599        tile_shape: Optional[Tuple[int, int]] = None,
 600        halo: Optional[Tuple[int, int]] = None,
 601        verbose: bool = False,
 602        pbar_init: Optional[callable] = None,
 603        pbar_update: Optional[callable] = None,
 604        batch_size: int = 1,
 605        mask: Optional[np.typing.ArrayLike] = None,
 606    ) -> None:
 607        """Initialize image embeddings and masks for an image.
 608
 609        Args:
 610            image: The input image, volume or timeseries.
 611            image_embeddings: Optional precomputed image embeddings.
 612                See `util.precompute_image_embeddings` for details.
 613            i: Index for the image data. Required if `image` has three spatial dimensions
 614                or a time dimension and two spatial dimensions.
 615            tile_shape: The tile shape for embedding prediction.
 616            halo: The overlap between tiles.
 617            verbose: Whether to print computation progress. By default, set to 'False'.
 618            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 619                Can be used together with pbar_update to handle a napari progress bar in another thread.
 620                This enables using this function within a threadworker.
 621            pbar_update: Callback to update an external progress bar.
 622            batch_size: The batch size for image embedding prediction. By default, set to '1'.
 623            mask: An optional mask to define areas that are ignored in the segmentation.
 624        """
 625        original_size = image.shape[:2]
 626        self._original_size = original_size
 627
 628        self._image_embeddings, tile_shape, halo, tiles_in_mask = _process_tiled_embeddings(
 629            self._predictor, image, image_embeddings, tile_shape, halo,
 630            verbose=verbose, batch_size=batch_size, mask=mask, i=i,
 631        )
 632
 633        tiling = blocking([0, 0], original_size, tile_shape)
 634        if tiles_in_mask is None:
 635            n_tiles = tiling.numberOfBlocks
 636            tile_ids = range(n_tiles)
 637        else:
 638            n_tiles = len(tiles_in_mask)
 639            tile_ids = tiles_in_mask
 640
 641        # The crop box is always the full local tile.
 642        tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in tile_ids]
 643        crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles]
 644
 645        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
 646        pbar_init(n_tiles, "Compute masks for tile")
 647
 648        # We need to cast to the image representation that is compatible with SAM.
 649        image = util._to_image(image)
 650
 651        mask_data = []
 652        for idx, tile_id in enumerate(tile_ids):
 653            # set the pre-computed embeddings for this tile
 654            features = image_embeddings["features"][str(tile_id)]
 655            tile_embeddings = {
 656                "features": features,
 657                "input_size": features.attrs["input_size"],
 658                "original_size": features.attrs["original_size"],
 659            }
 660            util.set_precomputed(self._predictor, tile_embeddings, i)
 661
 662            # Compute the mask data for this tile and append it
 663            this_mask_data = self._process_crop(
 664                image, crop_box=crop_boxes[idx], crop_layer_idx=0, precomputed_embeddings=True
 665            )
 666            mask_data.append(this_mask_data)
 667            pbar_update(1)
 668        pbar_close()
 669
 670        # set the initialized data
 671        self._is_initialized = True
 672        self._crop_list = mask_data
 673        self._crop_boxes = crop_boxes
 674
 675
 676#
 677# Instance segmentation functionality based on fine-tuned decoder
 678#
 679
 680
 681class DecoderAdapter(torch.nn.Module):
 682    """Adapter to contain the UNETR decoder in a single module.
 683
 684    To apply the decoder on top of pre-computed embeddings for the segmentation functionality.
 685    See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py
 686    """
 687    def __init__(self, unetr: torch.nn.Module):
 688        super().__init__()
 689
 690        self.base = unetr.base
 691        self.out_conv = unetr.out_conv
 692        self.deconv_out = unetr.deconv_out
 693        self.decoder_head = unetr.decoder_head
 694        self.final_activation = unetr.final_activation
 695        self.postprocess_masks = unetr.postprocess_masks
 696
 697        self.decoder = unetr.decoder
 698        self.deconv1 = unetr.deconv1
 699        self.deconv2 = unetr.deconv2
 700        self.deconv3 = unetr.deconv3
 701        self.deconv4 = unetr.deconv4
 702
 703    def _forward_impl(self, input_):
 704        z12 = input_
 705
 706        z9 = self.deconv1(z12)
 707        z6 = self.deconv2(z9)
 708        z3 = self.deconv3(z6)
 709        z0 = self.deconv4(z3)
 710
 711        updated_from_encoder = [z9, z6, z3]
 712
 713        x = self.base(z12)
 714        x = self.decoder(x, encoder_inputs=updated_from_encoder)
 715        x = self.deconv_out(x)
 716
 717        x = torch.cat([x, z0], dim=1)
 718        x = self.decoder_head(x)
 719
 720        x = self.out_conv(x)
 721        if self.final_activation is not None:
 722            x = self.final_activation(x)
 723        return x
 724
 725    def forward(self, input_, input_shape, original_shape):
 726        x = self._forward_impl(input_)
 727        x = self.postprocess_masks(x, input_shape, original_shape)
 728        return x
 729
 730
 731def get_unetr(
 732    image_encoder: torch.nn.Module,
 733    decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None,
 734    device: Optional[Union[str, torch.device]] = None,
 735    out_channels: int = 3,
 736    flexible_load_checkpoint: bool = False,
 737) -> torch.nn.Module:
 738    """Get UNETR model for automatic instance segmentation.
 739
 740    Args:
 741        image_encoder: The image encoder of the SAM model.
 742            This is also used as the encoder of the UNETR.
 743        decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
 744        device: The device. By default, automatically chooses the best available device.
 745        out_channels: The number of output channels. By default, set to '3'.
 746        flexible_load_checkpoint: Whether to allow reinitialization of parameters
 747            which could not be found in the provided decoder state. By default, set to 'False'.
 748
 749    Returns:
 750        The UNETR model.
 751    """
 752    device = util.get_device(device)
 753
 754    if decoder_state is None:
 755        use_conv_transpose = False  # By default, we use interpolation for upsampling.
 756    else:
 757        # From the provided pretrained 'decoder_state', we check whether it uses transposed convolutions.
 758        # NOTE: Explanation to the logic below -
 759        # - We do this by looking for parameter names that contain '.block.' within the "decoder.samplers"
 760        #   submodules. This naming convention indicates that transposed convolutions are used,
 761        #   wrapped inside a custom block module.
 762        # - Otherwise '.conv.' appears. It indicates a standard `Conv2d` applied after interpolation for upsampling.
 763        use_conv_transpose = any(".block." in k for k in decoder_state.keys() if k.startswith("decoder.samplers"))
 764
 765    unetr = UNETR(
 766        backbone="sam",
 767        encoder=image_encoder,
 768        out_channels=out_channels,
 769        use_sam_stats=True,
 770        final_activation="Sigmoid",
 771        use_skip_connection=False,
 772        resize_input=True,
 773        use_conv_transpose=use_conv_transpose,
 774    )
 775
 776    if decoder_state is not None:
 777        unetr_state_dict = unetr.state_dict()
 778        for k, v in unetr_state_dict.items():
 779            if not k.startswith("encoder"):
 780                if flexible_load_checkpoint:  # Whether to allow reinitialization of parameters that are not found.
 781                    if k in decoder_state:  # First check whether the key is available in the provided decoder state.
 782                        unetr_state_dict[k] = decoder_state[k]
 783                    else:  # Otherwise, keep the freshly initialized parameter.
 784                        warnings.warn(f"Could not find '{k}' in the pretrained state dict. Hence, we reinitialize it.")
 785                        unetr_state_dict[k] = v
 786
 787                else:  # Otherwise, be strict about finding the parameter in the decoder state.
 788                    if k not in decoder_state:
 789                        raise RuntimeError(f"The parameters for '{k}' could not be found.")
 790                    unetr_state_dict[k] = decoder_state[k]
 791
 792        unetr.load_state_dict(unetr_state_dict)
 793
 794    unetr.to(device)
 795    return unetr
 796
 797
 798def get_decoder(
 799    image_encoder: torch.nn.Module,
 800    decoder_state: OrderedDict[str, torch.Tensor],
 801    device: Optional[Union[str, torch.device]] = None,
 802) -> DecoderAdapter:
 803    """Get the decoder to predict outputs for automatic instance segmentation.
 804
 805    Args:
 806        image_encoder: The image encoder of the SAM model.
 807        decoder_state: State to initialize the weights of the UNETR decoder.
 808        device: The device. By default, automatically chooses the best available device.
 809
 810    Returns:
 811        The decoder for instance segmentation.
 812    """
 813    unetr = get_unetr(image_encoder, decoder_state, device)
 814    return DecoderAdapter(unetr)
 815
 816
 817def get_predictor_and_decoder(
 818    model_type: str,
 819    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
 820    device: Optional[Union[str, torch.device]] = None,
 821    peft_kwargs: Optional[Dict] = None,
 822) -> Tuple[SamPredictor, DecoderAdapter]:
 823    """Load the SAM model (predictor) and instance segmentation decoder.
 824
 825    This requires a checkpoint that contains the state for both predictor
 826    and decoder.
 827
 828    Args:
 829        model_type: The type of the image encoder used in the SAM model.
 830        checkpoint_path: Path to the checkpoint from which to load the data.
 831        device: The device. By default, automatically chooses the best available device.
 832        peft_kwargs: Keyword arguments for the PEFT wrapper class.
 833
 834    Returns:
 835        The SAM predictor.
 836        The decoder for instance segmentation.
 837    """
 838    device = util.get_device(device)
 839    predictor, state = util.get_sam_model(
 840        model_type=model_type,
 841        checkpoint_path=checkpoint_path,
 842        device=device,
 843        return_state=True,
 844        peft_kwargs=peft_kwargs,
 845    )
 846
 847    if "decoder_state" not in state:
 848        raise ValueError(
 849            f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state"
 850        )
 851
 852    decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device)
 853    return predictor, decoder
 854
 855
 856def _watershed_from_center_and_boundary_distances_parallel(
 857    center_distances,
 858    boundary_distances,
 859    foreground_map,
 860    center_distance_threshold,
 861    boundary_distance_threshold,
 862    foreground_threshold,
 863    distance_smoothing,
 864    min_size,
 865    tile_shape,
 866    halo,
 867    n_threads,
 868    verbose=False,
 869):
 870    center_distances = apply_filter(
 871        center_distances, "gaussianSmoothing", sigma=distance_smoothing,
 872        block_shape=tile_shape, n_threads=n_threads
 873    )
 874    boundary_distances = apply_filter(
 875        boundary_distances, "gaussianSmoothing", sigma=distance_smoothing,
 876        block_shape=tile_shape, n_threads=n_threads
 877    )
 878
 879    fg_mask = foreground_map > foreground_threshold
 880
 881    marker_map = np.logical_and(
 882        center_distances < center_distance_threshold, boundary_distances < boundary_distance_threshold
 883    )
 884    marker_map[~fg_mask] = 0
 885
 886    markers = np.zeros(marker_map.shape, dtype="uint64")
 887    markers = parallel_impl.label(
 888        marker_map, out=markers, block_shape=tile_shape, n_threads=n_threads, verbose=verbose,
 889    )
 890
 891    seg = np.zeros_like(markers, dtype="uint64")
 892    seg = parallel_impl.seeded_watershed(
 893        boundary_distances, seeds=markers, out=seg, block_shape=tile_shape,
 894        halo=halo, n_threads=n_threads, verbose=verbose, mask=fg_mask,
 895    )
 896
 897    out = np.zeros_like(seg, dtype="uint64")
 898    out = parallel_impl.size_filter(
 899        seg, out=out, min_size=min_size, block_shape=tile_shape, n_threads=n_threads, verbose=verbose
 900    )
 901
 902    return out
 903
 904
 905class InstanceSegmentationWithDecoder:
 906    """Generates an instance segmentation without prompts, using a decoder.
 907
 908    Implements the same interface as `AutomaticMaskGenerator`.
 909
 910    Use this class as follows:
 911    ```python
 912    segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
 913    segmenter.initialize(image)   # Predict the image embeddings and decoder outputs.
 914    masks = segmenter.generate(center_distance_threshold=0.75)  # Generate the instance segmentation.
 915    ```
 916
 917    Args:
 918        predictor: The segment anything predictor.
 919        decoder: The decoder to predict intermediate representations for instance segmentation.
 920    """
 921    def __init__(self, predictor: SamPredictor, decoder: torch.nn.Module) -> None:
 922        self._predictor = predictor
 923        self._decoder = decoder
 924
 925        # The decoder outputs.
 926        self._foreground = None
 927        self._center_distances = None
 928        self._boundary_distances = None
 929
 930        self._is_initialized = False
 931
 932    @property
 933    def is_initialized(self):
 934        """Whether the mask generator has already been initialized.
 935        """
 936        return self._is_initialized
 937
 938    @torch.no_grad()
 939    def initialize(
 940        self,
 941        image: np.ndarray,
 942        image_embeddings: Optional[util.ImageEmbeddings] = None,
 943        i: Optional[int] = None,
 944        verbose: bool = False,
 945        pbar_init: Optional[callable] = None,
 946        pbar_update: Optional[callable] = None,
 947        ndim: int = 2,
 948    ) -> None:
 949        """Initialize image embeddings and decoder predictions for an image.
 950
 951        Args:
 952            image: The input image, volume or timeseries.
 953            image_embeddings: Optional precomputed image embeddings.
 954                See `util.precompute_image_embeddings` for details.
 955            i: Index for the image data. Required if `image` has three spatial dimensions
 956                or a time dimension and two spatial dimensions.
 957            verbose: Whether to be verbose. By default, set to 'False'.
 958            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 959                Can be used together with pbar_update to handle a napari progress bar in another thread.
 960                This enables using this function within a threadworker.
 961            pbar_update: Callback to update an external progress bar.
 962            ndim: The dimensionality of the input data. By default, set to '2'.
 963        """
 964        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
 965        pbar_init(1, "Initialize instance segmentation with decoder")
 966
 967        if image_embeddings is None:
 968            image_embeddings = util.precompute_image_embeddings(
 969                predictor=self._predictor, input_=image, ndim=ndim, verbose=verbose
 970            )
 971
 972        # Get the image embeddings from the predictor.
 973        self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i)
 974        embeddings = self._predictor.features
 975        input_shape = tuple(self._predictor.input_size)
 976        original_shape = tuple(self._predictor.original_size)
 977
 978        # Run prediction with the UNETR decoder.
 979        output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
 980        assert output.shape[0] == 3, f"{output.shape}"
 981        pbar_update(1)
 982        pbar_close()
 983
 984        # Set the state.
 985        self._foreground = output[0]
 986        self._center_distances = output[1]
 987        self._boundary_distances = output[2]
 988        self._i = i
 989        self._is_initialized = True
 990
 991    def _to_masks(self, segmentation, output_mode):
 992        if output_mode != "binary_mask":
 993            raise ValueError(
 994                f"Output mode {output_mode} is not supported. Choose one of 'instance_segmentation', 'binary_mask'."
 995            )
 996
 997        props = regionprops(segmentation)
 998        ndim = segmentation.ndim
 999        assert ndim in (2, 3)
1000
1001        shape = segmentation.shape
1002        if ndim == 2:
1003            crop_box = [0, shape[1], 0, shape[0]]
1004        else:
1005            crop_box = [0, shape[2], 0, shape[1], 0, shape[0]]
1006
1007        # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h]
1008        def to_bbox_2d(bbox):
1009            y0, x0 = bbox[0], bbox[1]
1010            w = bbox[3] - x0
1011            h = bbox[2] - y0
1012            return [x0, w, y0, h]
1013
1014        def to_bbox_3d(bbox):
1015            z0, y0, x0 = bbox[0], bbox[1], bbox[2]
1016            w = bbox[5] - x0
1017            h = bbox[4] - y0
1018            d = bbox[3] - z0
1019            return [x0, w, y0, h, z0, d]
1020
1021        to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d
1022        masks = [
1023            {
1024                "segmentation": segmentation == prop.label,
1025                "area": prop.area,
1026                "bbox": to_bbox(prop.bbox),
1027                "crop_box": crop_box,
1028                "seg_id": prop.label,
1029            } for prop in props
1030        ]
1031        return masks
1032
1033    def generate(
1034        self,
1035        center_distance_threshold: float = 0.5,
1036        boundary_distance_threshold: float = 0.5,
1037        foreground_threshold: float = 0.5,
1038        foreground_smoothing: float = 1.0,
1039        distance_smoothing: float = 1.6,
1040        min_size: int = 0,
1041        output_mode: str = "instance_segmentation",
1042        tile_shape: Optional[Tuple[int, int]] = None,
1043        halo: Optional[Tuple[int, int]] = None,
1044        n_threads: Optional[int] = None,
1045    ) -> Union[List[Dict[str, Any]], np.ndarray]:
1046        """Generate instance segmentation for the currently initialized image.
1047
1048        Args:
1049            center_distance_threshold: Center distance predictions below this value will be
1050                used to find seeds (intersected with thresholded boundary distance predictions).
1051                By default, set to '0.5'.
1052            boundary_distance_threshold: Boundary distance predictions below this value will be
1053                used to find seeds (intersected with thresholded center distance predictions).
1054                By default, set to '0.5'.
1055            foreground_threshold: Foreground predictions above this value will be used as foreground mask.
1056                By default, set to '0.5'.
1057            foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid
1058                checkerboard artifacts in the prediction. By default, set to '1.0'.
1059            distance_smoothing: Sigma value for smoothing the distance predictions. By default, set to '1.6'.
1060            min_size: Minimal object size in the segmentation result. By default, set to '0'.
1061            output_mode: The form masks are returned in. Possible values are:
1062                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
1063                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
1064                By default, set to 'instance_segmentation'.
1065            tile_shape: Tile shape for parallelizing the instance segmentation post-processing.
1066                This parameter is independent from the tile shape for computing the embeddings.
1067                If not given then post-processing will not be parallelized.
1068            halo: Halo for parallel post-processing. See also `tile_shape`.
1069            n_threads: Number of threads for parallel post-processing. See also `tile_shape`.
1070
1071        Returns:
1072            The segmentation masks.
1073        """
1074        if not self.is_initialized:
1075            raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.")
1076
1077        if foreground_smoothing > 0:
1078            foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing)
1079        else:
1080            foreground = self._foreground
1081
1082        if tile_shape is None:
1083            segmentation = watershed_from_center_and_boundary_distances(
1084                center_distances=self._center_distances,
1085                boundary_distances=self._boundary_distances,
1086                foreground_map=foreground,
1087                center_distance_threshold=center_distance_threshold,
1088                boundary_distance_threshold=boundary_distance_threshold,
1089                foreground_threshold=foreground_threshold,
1090                distance_smoothing=distance_smoothing,
1091                min_size=min_size,
1092            )
1093        else:
1094            if halo is None:
1095                raise ValueError("You must pass a value for halo if tile_shape is given.")
1096            segmentation = _watershed_from_center_and_boundary_distances_parallel(
1097                center_distances=self._center_distances,
1098                boundary_distances=self._boundary_distances,
1099                foreground_map=foreground,
1100                center_distance_threshold=center_distance_threshold,
1101                boundary_distance_threshold=boundary_distance_threshold,
1102                foreground_threshold=foreground_threshold,
1103                distance_smoothing=distance_smoothing,
1104                min_size=min_size,
1105                tile_shape=tile_shape,
1106                halo=halo,
1107                n_threads=n_threads,
1108                verbose=False,
1109            )
1110
1111        if output_mode != "instance_segmentation":
1112            segmentation = self._to_masks(segmentation, output_mode)
1113        return segmentation
1114
1115    def get_state(self) -> Dict[str, Any]:
1116        """Get the initialized state of the instance segmenter.
1117
1118        Returns:
1119            Instance segmentation state.
1120        """
1121        if not self.is_initialized:
1122            raise RuntimeError("The state has not been computed yet. Call initialize first.")
1123
1124        return {
1125            "foreground": self._foreground,
1126            "center_distances": self._center_distances,
1127            "boundary_distances": self._boundary_distances,
1128        }
1129
1130    def set_state(self, state: Dict[str, Any]) -> None:
1131        """Set the state of the instance segmenter.
1132
1133        Args:
1134            state: The instance segmentation state.
1135        """
1136        self._foreground = state["foreground"]
1137        self._center_distances = state["center_distances"]
1138        self._boundary_distances = state["boundary_distances"]
1139        self._is_initialized = True
1140
1141    def clear_state(self):
1142        """Clear the state of the instance segmenter.
1143        """
1144        self._foreground = None
1145        self._center_distances = None
1146        self._boundary_distances = None
1147        self._is_initialized = False
1148
1149
1150class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder):
1151    """Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings.
1152    """
1153
1154    # Apply the decoder in a batched fashion, and then perform the resizing independently per output.
1155    # This is necessary, because the individual tiles may have different tile shapes due to border tiles.
1156    def _predict_decoder(self, batched_embeddings, input_shapes, original_shapes):
1157        batched_embeddings = torch.cat(batched_embeddings)
1158        output = self._decoder._forward_impl(batched_embeddings)
1159
1160        batched_output = []
1161        for x, input_shape, original_shape in zip(output, input_shapes, original_shapes):
1162            x = self._decoder.postprocess_masks(x.unsqueeze(0), input_shape, original_shape).squeeze(0)
1163            batched_output.append(x.cpu().numpy())
1164        return batched_output
1165
1166    @torch.no_grad()
1167    def initialize(
1168        self,
1169        image: np.ndarray,
1170        image_embeddings: Optional[util.ImageEmbeddings] = None,
1171        i: Optional[int] = None,
1172        tile_shape: Optional[Tuple[int, int]] = None,
1173        halo: Optional[Tuple[int, int]] = None,
1174        verbose: bool = False,
1175        pbar_init: Optional[callable] = None,
1176        pbar_update: Optional[callable] = None,
1177        batch_size: int = 1,
1178        mask: Optional[np.typing.ArrayLike] = None,
1179    ) -> None:
1180        """Initialize image embeddings and decoder predictions for an image.
1181
1182        Args:
1183            image: The input image, volume or timeseries.
1184            image_embeddings: Optional precomputed image embeddings.
1185                See `util.precompute_image_embeddings` for details.
1186            i: Index for the image data. Required if `image` has three spatial dimensions
1187                or a time dimension and two spatial dimensions.
1188            tile_shape: Shape of the tiles for precomputing image embeddings.
1189            halo: Overlap of the tiles for tiled precomputation of image embeddings.
1190            verbose: Whether to print computation progress. By default, set to 'False'.
1191            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
1192                Can be used together with pbar_update to handle a napari progress bar in another thread.
1193                This enables using this function within a threadworker.
1194            pbar_update: Callback to update an external progress bar.
1195            batch_size: The batch size for image embedding computation and segmentation decoder prediction.
1196            mask: An optional mask to define areas that are ignored in the segmentation.
1197        """
1198        original_size = image.shape[:2]
1199        self._image_embeddings, tile_shape, halo, tiles_in_mask = _process_tiled_embeddings(
1200            self._predictor, image, image_embeddings, tile_shape, halo,
1201            verbose=verbose, batch_size=batch_size, mask=mask, i=i,
1202        )
1203        tiling = blocking([0, 0], original_size, tile_shape)
1204
1205        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
1206
1207        foreground = np.zeros(original_size, dtype="float32")
1208        center_distances = np.zeros(original_size, dtype="float32")
1209        boundary_distances = np.zeros(original_size, dtype="float32")
1210
1211        msg = "Initialize tiled instance segmentation with decoder"
1212        if tiles_in_mask is None:
1213            n_tiles = tiling.numberOfBlocks
1214            all_tile_ids = list(range(n_tiles))
1215        else:
1216            n_tiles = len(tiles_in_mask)
1217            all_tile_ids = tiles_in_mask
1218            msg += " and mask"
1219
1220        n_batches = int(np.ceil(n_tiles / batch_size))
1221        pbar_init(n_tiles, msg)
1222        tile_ids_for_batches = np.array_split(all_tile_ids, n_batches)
1223
1224        for tile_ids in tile_ids_for_batches:
1225            batched_embeddings, input_shapes, original_shapes = [], [], []
1226            for tile_id in tile_ids:
1227                # Get the image embeddings from the predictor for this tile.
1228                self._predictor = util.set_precomputed(self._predictor, self._image_embeddings, i=i, tile_id=tile_id)
1229
1230                batched_embeddings.append(self._predictor.features)
1231                input_shapes.append(tuple(self._predictor.input_size))
1232                original_shapes.append(tuple(self._predictor.original_size))
1233
1234            batched_output = self._predict_decoder(batched_embeddings, input_shapes, original_shapes)
1235
1236            for output_id, tile_id in enumerate(tile_ids):
1237                output = batched_output[output_id]
1238                assert output.shape[0] == 3
1239
1240                # Set the predictions in the output for this tile.
1241                block = tiling.getBlockWithHalo(tile_id, halo=list(halo))
1242                local_bb = tuple(
1243                    slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end)
1244                )
1245                inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end))
1246
1247                foreground[inner_bb] = output[0][local_bb]
1248                center_distances[inner_bb] = output[1][local_bb]
1249                boundary_distances[inner_bb] = output[2][local_bb]
1250                pbar_update(1)
1251
1252        pbar_close()
1253
1254        # Set the state.
1255        self._i = i
1256        self._foreground = foreground
1257        self._center_distances = center_distances
1258        self._boundary_distances = boundary_distances
1259        self._is_initialized = True
1260
1261
1262def _get_centers(segmentation, avoid_image_border=True):
1263    # Not sure if 'find_boundaries' has to be parallelized.
1264    # Compute the distance transform to object boundaries.
1265    boundaries = find_boundaries(segmentation, mode="outer") == 0
1266    # Avoid centroids on the border.
1267    if avoid_image_border:
1268        boundaries[0, :] = False
1269        boundaries[:, 0] = False
1270        boundaries[-1, :] = False
1271        boundaries[:, -1] = False
1272    block_shape = (512, 512)
1273    halo = (16, 16)
1274    distances = parallel_impl.distance_transform(boundaries, halo=halo, block_shape=block_shape)
1275
1276    # Get the maximum coordinate of the distance transform for each id.
1277    props = regionprops(segmentation)
1278    centers = []
1279    for prop in props:
1280        seg_id = prop.label
1281        # Get the bounding box and mask for this segmentation id.
1282        bounding_box = np.s_[prop.bbox[0]:prop.bbox[2], prop.bbox[1]:prop.bbox[3]]
1283        mask = segmentation[bounding_box] == seg_id
1284        # Restrict the distances to the bounding box.
1285        dist = distances[bounding_box].copy()
1286        # Set the distances outside of the mask to 0.
1287        dist[~mask] = 0
1288        # Find the center coordinate (= distance maximum).
1289        center = np.argmax(dist)
1290        center = np.unravel_index(center, dist.shape)
1291        # Bring the center coordinate back to the full coordinate system.
1292        center = tuple(ce + bb.start for ce, bb in zip(center, bounding_box))
1293        centers.append(center)
1294
1295    return np.array(centers)
1296
1297
1298def _derive_point_prompts(
1299    foreground: np.ndarray,
1300    center_distances: np.ndarray,
1301    boundary_distances: np.ndarray,
1302    foreground_threshold: float = 0.5,
1303    center_distance_threshold: float = 0.5,
1304    boundary_distance_threshold: float = 0.5,
1305):
1306    bg_mask = foreground < foreground_threshold
1307    hmap_cc = np.logical_and(
1308        center_distances < center_distance_threshold, boundary_distances < boundary_distance_threshold,
1309    )
1310    hmap_cc[bg_mask] = 0
1311    cc = np.zeros_like(hmap_cc, dtype="uint32")
1312    cc = parallel_impl.label(hmap_cc, out=cc, block_shape=(512, 512))
1313    prompts = _get_centers(cc)
1314    if len(prompts) == 0:
1315        return None
1316
1317    points = prompts[:, None, ::-1]
1318    labels = np.ones((len(prompts), 1))
1319    return {"points": points, "point_labels": labels}
1320
1321
1322def _derive_box_prompts(predictions, box_extension):
1323    shape = predictions[0]["segmentation"].shape
1324    bboxes = [pred["bbox"] for pred in predictions]
1325    prompts = [[
1326        max(x - w * box_extension, 0),
1327        max(y - h * box_extension, 0),
1328        min(x + (1 + box_extension) * w, shape[0]),
1329        min(y + (1 + box_extension) * h, shape[1]),
1330    ] for (x, y, w, h) in bboxes]
1331    return {"boxes": np.array(prompts)}
1332
1333
1334class AutomaticPromptGenerator(InstanceSegmentationWithDecoder):
1335    """Generates an instance segmentation automatically, using automatically generated prompts from a decoder.
1336
1337    This class is used in the same way as `InstanceSegmentationWithDecoder` and `AutomaticMaskGenerator`.
1338
1339    Args:
1340        predictor: The segment anything predictor.
1341        decoder: The decoder used to derive prompts for automatic instance segmentation.
1342    """
1343    def generate(
1344        self,
1345        min_size: int = 25,
1346        center_distance_threshold: float = 0.5,
1347        boundary_distance_threshold: float = 0.5,
1348        foreground_threshold: float = 0.5,
1349        multimasking: bool = False,
1350        batch_size: int = 32,
1351        nms_threshold: float = 0.9,
1352        intersection_over_min: bool = False,
1353        output_mode: str = "instance_segmentation",
1354        mask_threshold: Optional[Union[float, str]] = None,
1355        refine_with_box_prompts: bool = False,
1356        prompt_function: Optional[callable] = None,
1357    ) -> Union[List[Dict[str, Any]], np.ndarray]:
1358        """Generate the instance segmentation for the currently initialized image.
1359
1360        The instance segmentation is generated by deriving prompts from the foreground and
1361        distance predictions of the segmentation decoder by thresholding these predictions,
1362        intersecting them, computing connected components, and then using the components'
1363        centers as point prompts. The masks are then filtered via NMS and merged into a segmentation.
1364
1365        Args:
1366            min_size: Minimal object size in the segmentation result. By default, set to '25'.
1367            center_distance_threshold: The threshold for the center distance predictions.
1368            boundary_distance_threshold: The threshold for the boundary distance predictions.
1369            multimasking: Whether to use multi-mask prediction for turning the prompts into masks.
1370            batch_size: The batch size for parallelizing the prediction based on prompts.
1371            nms_threshold: The threshold for non-maximum suppression (NMS).
1372            intersection_over_min: Whether to use the minimum area of the two objects or the
1373                intersection over union area (default) in NMS.
1374            output_mode: The form masks are returned in. Possible values are:
1375                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
1376                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
1377                By default, set to 'instance_segmentation'.
1378            mask_threshold: The threshold for turning logits into masks in `micro_sam.inference.batched_inference`.
1379            refine_with_box_prompts: Whether to refine the mask outputs with another round of box prompts
1380                derived from the segmentations after point prompts.
1381            prompt_function: A custom function for deriving prompts from the segmentation decoder predictions.
1382                If given, the default prompt derivation procedure is not used. Must have the following signature:
1383                ```
1384                    def prompt_function(foreground, center_distances, boundary_distances, **kwargs)
1385                ```
1386                where `foreground`, `center_distances`, and `boundary_distances` are the respective
1387                predictions from the segmentation decoder. It must return a dictionary containing
1388                either point, box, or mask prompts in a format compatible with `micro_sam.inference.batched_inference`.
1389
1390        Returns:
1391            The instance segmentation masks.
1392        """
1393        if not self.is_initialized:
1394            raise RuntimeError("AutomaticPromptGenerator has not been initialized. Call initialize first.")
1395        foreground, center_distances, boundary_distances =\
1396            self._foreground, self._center_distances, self._boundary_distances
1397
1398        # 1.) Derive prompts from the decoder predictions.
1399        prompt_function = _derive_point_prompts if prompt_function is None else prompt_function
1400        prompts = prompt_function(
1401            foreground=foreground,
1402            center_distances=center_distances,
1403            boundary_distances=boundary_distances,
1404            foreground_threshold=foreground_threshold,
1405            center_distance_threshold=center_distance_threshold,
1406            boundary_distance_threshold=boundary_distance_threshold,
1407        )
1408
1409        # 2.) Apply the predictor to the prompts.
1410        if prompts is None:  # No prompts were derived, so there is nothing to segment; return an empty result.
1411            return np.zeros(foreground.shape, dtype="uint32") if output_mode == "instance_segmentation" else []
1412        else:
1413            predictions = batched_inference(
1414                self._predictor,
1415                image=None,
1416                batch_size=batch_size,
1417                return_instance_segmentation=False,
1418                multimasking=multimasking,
1419                mask_threshold=mask_threshold,
1420                i=getattr(self, "_i", None),
1421                **prompts,
1422            )
1423
1424        # 3.) Refine the segmentation with box prompts.
1425        if refine_with_box_prompts:
1426            box_extension = 0.01  # expose as hyperparam?
1427            prompts = _derive_box_prompts(predictions, box_extension)
1428            predictions = batched_inference(
1429                self._predictor,
1430                image=None,
1431                batch_size=batch_size,
1432                return_instance_segmentation=False,
1433                multimasking=multimasking,
1434                mask_threshold=mask_threshold,
1435                i=getattr(self, "_i", None),
1436                **prompts,
1437            )
1438
1439        # 4.) Apply non-max suppression to the masks.
1440        segmentation = util.apply_nms(
1441            predictions, min_size=min_size, nms_thresh=nms_threshold, intersection_over_min=intersection_over_min
1442        )
1443        if output_mode != "instance_segmentation":
1444            segmentation = self._to_masks(segmentation, output_mode)
1445        return segmentation
1446
1447
1448class TiledAutomaticPromptGenerator(TiledInstanceSegmentationWithDecoder):
1449    """Same as `AutomaticPromptGenerator` but for tiled image embeddings.
1450    """
1451    def generate(
1452        self,
1453        min_size: int = 25,
1454        center_distance_threshold: float = 0.5,
1455        boundary_distance_threshold: float = 0.5,
1456        foreground_threshold: float = 0.5,
1457        multimasking: bool = False,
1458        batch_size: int = 32,
1459        nms_threshold: float = 0.9,
1460        intersection_over_min: bool = False,
1461        output_mode: str = "instance_segmentation",
1462        mask_threshold: Optional[Union[float, str]] = None,
1463        refine_with_box_prompts: bool = False,
1464        prompt_function: Optional[callable] = None,
1465        optimize_memory: bool = False,
1466    ) -> List[Dict[str, Any]]:
1467        """Generate tiling-based instance segmentation for the currently initialized image.
1468
1469        Args:
1470            min_size: Minimal object size in the segmentation result. By default, set to '25'.
1471            center_distance_threshold: The threshold for the center distance predictions.
1472            boundary_distance_threshold: The threshold for the boundary distance predictions.
1473            multimasking: Whether to use multi-mask prediction for turning the prompts into masks.
1474            batch_size: The batch size for parallelizing the prediction based on prompts.
1475            nms_threshold: The threshold for non-maximum suppression (NMS).
1476            intersection_over_min: Whether to use the minimum area of the two objects or the
1477                intersection over union area (default) in NMS.
1478            output_mode: The form masks are returned in. Possible values are:
1479                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
1480                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
1481                By default, set to 'instance_segmentation'.
1482            mask_threshold: The threshold for turning logits into masks in `micro_sam.inference.batched_inference`.
1483            refine_with_box_prompts: Whether to refine the mask outputs with another round of box prompts
1484                derived from the segmentations after point prompts. Currently not supported for tiled segmentation.
1485            prompt_function: A custom function for deriving prompts from the segmentation decoder predictions.
1486                If given, the default prompt derivation procedure is not used. Must have the following signature:
1487                ```
1488                    def prompt_function(foreground, center_distances, boundary_distances, **kwargs)
1489                ```
1490                where `foreground`, `center_distances`, and `boundary_distances` are the respective
1491                predictions from the segmentation decoder. It must return a dictionary containing
1492                either point, box, or mask prompts in a format compatible with `micro_sam.inference.batched_inference`.
1493            optimize_memory: Whether to optimize the memory consumption by merging the per-slice
1494                segmentation results immediately with NMS, rather than running NMS over all results.
1495                This may lead to a slightly different segmentation result and is only compatible with
1496                `refine_with_box_prompts=False` and `output_mode="instance_segmentation"`.
1497
1498        Returns:
1499            The instance segmentation masks.
1500        """
1501        if not self.is_initialized:
1502            raise RuntimeError("TiledAutomaticPromptGenerator has not been initialized. Call initialize first.")
1503        if optimize_memory and (output_mode != "instance_segmentation" or refine_with_box_prompts):
1504            raise ValueError("optimize_memory requires output_mode='instance_segmentation' and refine_with_box_prompts=False.")
1505        foreground, center_distances, boundary_distances =\
1506            self._foreground, self._center_distances, self._boundary_distances
1507
1508        # 1.) Derive prompts from the decoder predictions.
1509        prompt_function = _derive_point_prompts if prompt_function is None else prompt_function
1510        prompts = prompt_function(
1511            foreground,
1512            center_distances,
1513            boundary_distances,
1514            foreground_threshold=foreground_threshold,
1515            center_distance_threshold=center_distance_threshold,
1516            boundary_distance_threshold=boundary_distance_threshold,
1517        )
1518
1519        # 2.) Apply the predictor to the prompts.
1520        shape = foreground.shape
1521        if prompts is None:  # No prompts were derived, so there is nothing to segment; return an empty result.
1522            return np.zeros(shape, dtype="uint32") if output_mode == "instance_segmentation" else []
1523        else:
1524            if optimize_memory:
1525                prompts.update(dict(
1526                    min_size=min_size, nms_thresh=nms_threshold, intersection_over_min=intersection_over_min
1527                ))
1528            predictions = batched_tiled_inference(
1529                self._predictor,
1530                image=None,
1531                batch_size=batch_size,
1532                image_embeddings=self._image_embeddings,
1533                return_instance_segmentation=False,
1534                multimasking=multimasking,
1535                optimize_memory=optimize_memory,
1536                i=getattr(self, "_i", None),
1537                **prompts
1538            )
1539        # With optimize_memory, the tiled inference directly returns an instance segmentation
1540        # and does not allow for any further refinements.
1541        if optimize_memory:
1542            return predictions
1543
1544        # 3.) Refine the segmentation with box prompts.
1545        if refine_with_box_prompts:
1546            # TODO
1547            raise NotImplementedError
1548
1549        # 4.) Apply non-max suppression to the masks.
1550        segmentation = util.apply_nms(
1551            predictions, shape=shape, min_size=min_size, nms_thresh=nms_threshold,
1552            intersection_over_min=intersection_over_min,
1553        )
1554        if output_mode != "instance_segmentation":
1555            segmentation = self._to_masks(segmentation, output_mode)
1556        return segmentation
1557
1558    # Set state and get state are not implemented yet, as this generator relies on having the image embeddings
1559    # in the state. However, they should not be serialized here and we have to address this a bit differently.
1560    def get_state(self):
1561        """@private
1562        """
1563        raise NotImplementedError
1564
1565    def set_state(self, state):
1566        """@private
1567        """
1568        raise NotImplementedError
1569
1570
1571def get_instance_segmentation_generator(
1572    predictor: SamPredictor,
1573    is_tiled: bool,
1574    decoder: Optional[torch.nn.Module] = None,
1575    segmentation_mode: Optional[str] = None,
1576    **kwargs,
1577) -> Union[AMGBase, InstanceSegmentationWithDecoder]:
1578    f"""Get the automatic instance segmentation generator.
1579
1580    Args:
1581        predictor: The segment anything predictor.
1582        is_tiled: Whether tiled embeddings are used.
1583        decoder: The decoder to predict the instance segmentation.
1584        segmentation_mode: The segmentation mode. One of 'amg', 'ais', or 'apg'.
1585            By default, '{DEFAULT_SEGMENTATION_MODE_WITH_DECODER}' is used if a decoder is passed,
1586            otherwise 'amg' is used.
1587        kwargs: The keyword arguments of the segmentation generator class.
1588
1589    Returns:
1590        The segmentation generator instance.
1591    """
1592    # Choose the default segmentation mode depending on whether we have a decoder.
1593    if segmentation_mode is None:
1594        segmentation_mode = "amg" if decoder is None else DEFAULT_SEGMENTATION_MODE_WITH_DECODER
1595
1596    if segmentation_mode.lower() == "amg":
1597        segmenter_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator
1598        segmenter = segmenter_class(predictor, **kwargs)
1599    elif segmentation_mode.lower() == "ais":
1600        assert decoder is not None
1601        segmenter_class = TiledInstanceSegmentationWithDecoder if is_tiled else InstanceSegmentationWithDecoder
1602        segmenter = segmenter_class(predictor, decoder, **kwargs)
1603    elif segmentation_mode.lower() == "apg":
1604        assert decoder is not None
1605        segmenter_class = TiledAutomaticPromptGenerator if is_tiled else AutomaticPromptGenerator
1606        segmenter = segmenter_class(predictor, decoder, **kwargs)
1607    else:
1608        raise ValueError(f"Invalid segmentation_mode: {segmentation_mode}. Choose one of 'amg', 'ais', or 'apg'.")
1609
1610    return segmenter
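
A minimal usage sketch for `get_instance_segmentation_generator`. The model type and the random placeholder image are assumptions made for illustration; for the 'ais' and 'apg' modes a matching decoder would additionally have to be passed.

```python
import numpy as np

from micro_sam.util import get_sam_model
from micro_sam.instance_segmentation import get_instance_segmentation_generator

# Assumption: load a SAM predictor; replace the model type and image with your own data.
predictor = get_sam_model(model_type="vit_b")
image = np.random.randint(0, 255, size=(512, 512), dtype="uint8")  # placeholder image

# Without a decoder the factory falls back to the 'amg' mode.
segmenter = get_instance_segmentation_generator(predictor, is_tiled=False)
segmenter.initialize(image)
segmentation = segmenter.generate(output_mode="instance_segmentation")
```
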
DEFAULT_SEGMENTATION_MODE_WITH_DECODER = 'ais'
class AMGBase(abc.ABC):
 59class AMGBase(ABC):
 60    """Base class for the automatic mask generators.
 61    """
 62    def __init__(self):
 63        # the state that has to be computed by the 'initialize' method of the child classes
 64        self._is_initialized = False
 65        self._crop_list = None
 66        self._crop_boxes = None
 67        self._original_size = None
 68
 69    @property
 70    def is_initialized(self):
 71        """Whether the mask generator has already been initialized.
 72        """
 73        return self._is_initialized
 74
 75    @property
 76    def crop_list(self):
 77        """The list of mask data after initialization.
 78        """
 79        return self._crop_list
 80
 81    @property
 82    def crop_boxes(self):
 83        """The list of crop boxes.
 84        """
 85        return self._crop_boxes
 86
 87    @property
 88    def original_size(self):
 89        """The original image size.
 90        """
 91        return self._original_size
 92
 93    def _postprocess_batch(
 94        self,
 95        data,
 96        crop_box,
 97        original_size,
 98        pred_iou_thresh,
 99        stability_score_thresh,
100        box_nms_thresh,
101    ):
102        orig_h, orig_w = original_size
103
104        # filter by predicted IoU
105        if pred_iou_thresh > 0.0:
106            keep_mask = data["iou_preds"] > pred_iou_thresh
107            data.filter(keep_mask)
108
109        # filter by stability score
110        if stability_score_thresh > 0.0:
111            keep_mask = data["stability_score"] >= stability_score_thresh
112            data.filter(keep_mask)
113
114        # filter boxes that touch crop boundaries
115        keep_mask = ~amg_utils.is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
116        if not torch.all(keep_mask):
117            data.filter(keep_mask)
118
119        # remove duplicates within this crop.
120        keep_by_nms = batched_nms(
121            data["boxes"].float(),
122            data["iou_preds"],
123            torch.zeros_like(data["boxes"][:, 0]),  # categories
124            iou_threshold=box_nms_thresh,
125        )
126        data.filter(keep_by_nms)
127
128        # return to the original image frame
129        data["boxes"] = amg_utils.uncrop_boxes_xyxy(data["boxes"], crop_box)
130        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
131        # the data from embedding based segmentation doesn't have the points
132        # so we skip if the corresponding key can't be found
133        try:
134            data["points"] = amg_utils.uncrop_points(data["points"], crop_box)
135        except KeyError:
136            pass
137
138        return data
139
140    def _postprocess_small_regions(self, mask_data, min_area, nms_thresh):
141
142        if len(mask_data["rles"]) == 0:
143            return mask_data
144
145        # filter small disconnected regions and holes
146        new_masks = []
147        scores = []
148        for rle in mask_data["rles"]:
149            mask = amg_utils.rle_to_mask(rle)
150
151            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="holes")
152            unchanged = not changed
153            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="islands")
154            unchanged = unchanged and not changed
155
156            new_masks.append(torch.as_tensor(mask, dtype=torch.int).unsqueeze(0))
157            # give score=0 to changed masks and score=1 to unchanged masks
158            # so NMS will prefer ones that didn't need postprocessing
159            scores.append(float(unchanged))
160
161        # recalculate boxes and remove any new duplicates
162        masks = torch.cat(new_masks, dim=0)
163        boxes = batched_mask_to_box(masks.to(torch.bool))  # Casting this to boolean as we work with one-hot labels.
164        keep_by_nms = batched_nms(
165            boxes.float(),
166            torch.as_tensor(scores, dtype=torch.float),
167            torch.zeros_like(boxes[:, 0]),  # categories
168            iou_threshold=nms_thresh,
169        )
170
171        # only recalculate RLEs for masks that have changed
172        for i_mask in keep_by_nms:
173            if scores[i_mask] == 0.0:
174                mask_torch = masks[i_mask].unsqueeze(0)
175                # mask_data["rles"][i_mask] = amg_utils.mask_to_rle_pytorch(mask_torch)[0]
176                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
177                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
178        mask_data.filter(keep_by_nms)
179
180        return mask_data
181
182    def _postprocess_masks(self, mask_data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode):
183        # filter small disconnected regions and holes in masks
184        if min_mask_region_area > 0:
185            mask_data = self._postprocess_small_regions(
186                mask_data,
187                min_mask_region_area,
188                max(box_nms_thresh, crop_nms_thresh),
189            )
190
191        # encode masks
192        if output_mode == "coco_rle":
193            mask_data["segmentations"] = [amg_utils.coco_encode_rle(rle) for rle in mask_data["rles"]]
194        # We require the binary mask output for creating the instance-seg output.
195        elif output_mode in ("binary_mask", "instance_segmentation"):
196            mask_data["segmentations"] = [amg_utils.rle_to_mask(rle) for rle in mask_data["rles"]]
197        elif output_mode == "rle":
198            mask_data["segmentations"] = mask_data["rles"]
199        else:
200            raise ValueError(f"Invalid output mode {output_mode}.")
201
202        # write mask records
203        curr_anns = []
204        for idx in range(len(mask_data["segmentations"])):
205            ann = {
206                "segmentation": mask_data["segmentations"][idx],
207                "area": amg_utils.area_from_rle(mask_data["rles"][idx]),
208                "bbox": amg_utils.box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
209                "predicted_iou": mask_data["iou_preds"][idx].item(),
210                "stability_score": mask_data["stability_score"][idx].item(),
211                "crop_box": amg_utils.box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
212            }
213            # the data from embedding based segmentation doesn't have the points
214            # so we skip if the corresponding key can't be found
215            try:
216                ann["point_coords"] = [mask_data["points"][idx].tolist()]
217            except KeyError:
218                pass
219            curr_anns.append(ann)
220
221        return curr_anns
222
223    def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None):
224        orig_h, orig_w = original_size
225
226        # serialize predictions and store in MaskData
227        data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1))
228        if points is not None:
229            data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0), dtype=torch.float)
230
231        del masks
232
233        # calculate the stability scores
234        data["stability_score"] = amg_utils.calculate_stability_score(
235            data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset
236        )
237
238        # threshold masks and calculate boxes
239        data["masks"] = data["masks"] > self._predictor.model.mask_threshold
240        data["masks"] = data["masks"].type(torch.bool)
241        data["boxes"] = batched_mask_to_box(data["masks"])
242
243        # compress to RLE
244        data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
245        # data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"])
246        data["rles"] = mask_to_rle_pytorch(data["masks"])
247        del data["masks"]
248
249        return data
250
251    def get_state(self) -> Dict[str, Any]:
252        """Get the initialized state of the mask generator.
253
254        Returns:
255            State of the mask generator.
256        """
257        if not self.is_initialized:
258            raise RuntimeError("The state has not been computed yet. Call initialize first.")
259
260        return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}
261
262    def set_state(self, state: Dict[str, Any]) -> None:
263        """Set the state of the mask generator.
264
265        Args:
266            state: The state of the mask generator, e.g. from serialized state.
267        """
268        self._crop_list = state["crop_list"]
269        self._crop_boxes = state["crop_boxes"]
270        self._original_size = state["original_size"]
271        self._is_initialized = True
272
273    def clear_state(self):
274        """Clear the state of the mask generator.
275        """
276        self._crop_list = None
277        self._crop_boxes = None
278        self._original_size = None
279        self._is_initialized = False

Base class for the automatic mask generators.

is_initialized
69    @property
70    def is_initialized(self):
71        """Whether the mask generator has already been initialized.
72        """
73        return self._is_initialized

Whether the mask generator has already been initialized.

crop_list
75    @property
76    def crop_list(self):
77        """The list of mask data after initialization.
78        """
79        return self._crop_list

The list of mask data after initialization.

crop_boxes
81    @property
82    def crop_boxes(self):
83        """The list of crop boxes.
84        """
85        return self._crop_boxes

The list of crop boxes.

original_size
87    @property
88    def original_size(self):
89        """The original image size.
90        """
91        return self._original_size

The original image size.

def get_state(self) -> Dict[str, Any]:
251    def get_state(self) -> Dict[str, Any]:
252        """Get the initialized state of the mask generator.
253
254        Returns:
255            State of the mask generator.
256        """
257        if not self.is_initialized:
258            raise RuntimeError("The state has not been computed yet. Call initialize first.")
259
260        return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}

Get the initialized state of the mask generator.

Returns:

State of the mask generator.

def set_state(self, state: Dict[str, Any]) -> None:
262    def set_state(self, state: Dict[str, Any]) -> None:
263        """Set the state of the mask generator.
264
265        Args:
266            state: The state of the mask generator, e.g. from serialized state.
267        """
268        self._crop_list = state["crop_list"]
269        self._crop_boxes = state["crop_boxes"]
270        self._original_size = state["original_size"]
271        self._is_initialized = True

Set the state of the mask generator.

Arguments:
  • state: The state of the mask generator, e.g. from serialized state.
def clear_state(self):
273    def clear_state(self):
274        """Clear the state of the mask generator.
275        """
276        self._crop_list = None
277        self._crop_boxes = None
278        self._original_size = None
279        self._is_initialized = False

Clear the state of the mask generator.
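
A short sketch of caching the expensive `initialize` step via `get_state` / `set_state`. Serializing the state with `pickle` and the file path are assumptions for illustration; whether the cached mask data pickles cleanly depends on the environment.

```python
import pickle

import numpy as np
from micro_sam.util import get_sam_model
from micro_sam.instance_segmentation import AutomaticMaskGenerator

predictor = get_sam_model(model_type="vit_b")  # assumption: model choice
image = np.random.randint(0, 255, size=(512, 512), dtype="uint8")  # placeholder image

# Run the expensive initialization once and store the resulting state.
amg = AutomaticMaskGenerator(predictor)
amg.initialize(image)
with open("amg_state.pkl", "wb") as f:
    pickle.dump(amg.get_state(), f)

# Later: restore the state into a fresh generator and run the cheap post-processing again.
new_amg = AutomaticMaskGenerator(predictor)
with open("amg_state.pkl", "rb") as f:
    new_amg.set_state(pickle.load(f))
masks = new_amg.generate(pred_iou_thresh=0.88, output_mode="instance_segmentation")
```
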

class AutomaticMaskGenerator(AMGBase):
282class AutomaticMaskGenerator(AMGBase):
283    """Generates an instance segmentation without prompts, using a point grid.
284
285    This class implements the same logic as
286    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
287    It decouples the computationally expensive steps of generating masks from the cheap post-processing operation
288    to filter these masks to enable grid search and interactively changing the post-processing.
289
290    Use this class as follows:
291    ```python
292    amg = AutomaticMaskGenerator(predictor)
293    amg.initialize(image)  # Initialize the masks, this takes care of all expensive computations.
294    masks = amg.generate(pred_iou_thresh=0.8)  # Generate the masks. This is fast and enables testing different parameters.
295    ```
296
297    Args:
298        predictor: The segment anything predictor.
299        points_per_side: The number of points to be sampled along one side of the image.
300            If None, `point_grids` must provide explicit point sampling. By default, set to '32'.
301        points_per_batch: The number of points run simultaneously by the model.
302            Higher numbers may be faster but use more GPU memory.
303            By default, selects '64' for all devices except 'mps' (selects '16' for performance reasons).
304        crop_n_layers: If >0, the mask prediction will be run again on crops of the image.
305            By default, set to '0'.
306        crop_overlap_ratio: Sets the degree to which crops overlap. By default, set to '512 / 1500'.
307        crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops.
308            By default, set to '1'.
309        point_grids: A list over explicit grids of points used for sampling masks.
310            Normalized to [0, 1] with respect to the image coordinate system.
311        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
312            By default, set to '1.0'.
313    """
314    def __init__(
315        self,
316        predictor: SamPredictor,
317        points_per_side: Optional[int] = 32,
318        points_per_batch: Optional[int] = None,
319        crop_n_layers: int = 0,
320        crop_overlap_ratio: float = 512 / 1500,
321        crop_n_points_downscale_factor: int = 1,
322        point_grids: Optional[List[np.ndarray]] = None,
323        stability_score_offset: float = 1.0,
324    ):
325        super().__init__()
326
327        if points_per_side is not None:
328            self.point_grids = amg_utils.build_all_layer_point_grids(
329                points_per_side, crop_n_layers, crop_n_points_downscale_factor,
330            )
331        elif point_grids is not None:
332            self.point_grids = point_grids
333        else:
334            raise ValueError("Exactly one of 'points_per_side' or 'point_grids' must be provided.")
335
336        self._predictor = predictor
337        self._points_per_side = points_per_side
338
339        # we set the points per batch to 16 for mps for performance reasons
340        # and otherwise keep them at the default of 64
341        if points_per_batch is None:
342            points_per_batch = 16 if str(predictor.device) == "mps" else 64
343        self._points_per_batch = points_per_batch
344
345        self._crop_n_layers = crop_n_layers
346        self._crop_overlap_ratio = crop_overlap_ratio
347        self._crop_n_points_downscale_factor = crop_n_points_downscale_factor
348        self._stability_score_offset = stability_score_offset
349
350    def _process_batch(self, points, im_size, crop_box, original_size):
351        # run model on this batch
352        transformed_points = self._predictor.transform.apply_coords(points, im_size)
353        in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float)
354        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
355        masks, iou_preds, _ = self._predictor.predict_torch(
356            point_coords=in_points[:, None, :],
357            point_labels=in_labels[:, None],
358            multimask_output=True,
359            return_logits=True,
360        )
361        data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points)
362        del masks
363        return data
364
365    def _process_crop(self, image, crop_box, crop_layer_idx, precomputed_embeddings, pbar_init=None, pbar_update=None):
366        # Crop the image and calculate embeddings.
367        x0, y0, x1, y1 = crop_box
368        cropped_im = image[y0:y1, x0:x1, :]
369        cropped_im_size = cropped_im.shape[:2]
370
371        if not precomputed_embeddings:
372            self._predictor.set_image(cropped_im)
373
374        # Get the points for this crop.
375        points_scale = np.array(cropped_im_size)[None, ::-1]
376        points_for_image = self.point_grids[crop_layer_idx] * points_scale
377
378        # Generate masks for this crop in batches.
379        data = amg_utils.MaskData()
380        n_batches = len(points_for_image) // self._points_per_batch +\
381            int(len(points_for_image) % self._points_per_batch != 0)
382        if pbar_init is not None:
383            pbar_init(n_batches, "Predict masks for point grid prompts")
384
385        for (points,) in amg_utils.batch_iterator(self._points_per_batch, points_for_image):
386            batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size)
387            data.cat(batch_data)
388            del batch_data
389            if pbar_update is not None:
390                pbar_update(1)
391
392        if not precomputed_embeddings:
393            self._predictor.reset_image()
394
395        return data
396
397    @torch.no_grad()
398    def initialize(
399        self,
400        image: np.ndarray,
401        image_embeddings: Optional[util.ImageEmbeddings] = None,
402        i: Optional[int] = None,
403        verbose: bool = False,
404        pbar_init: Optional[callable] = None,
405        pbar_update: Optional[callable] = None,
406    ) -> None:
407        """Initialize image embeddings and masks for an image.
408
409        Args:
410            image: The input image, volume or timeseries.
411            image_embeddings: Optional precomputed image embeddings.
412                See `util.precompute_image_embeddings` for details.
413            i: Index for the image data. Required if `image` has three spatial dimensions
414                or a time dimension and two spatial dimensions.
415            verbose: Whether to print computation progress. By default, set to 'False'.
416            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
417                Can be used together with pbar_update to handle the napari progress bar in another thread.
418                This enables using this function within a thread worker.
419            pbar_update: Callback to update an external progress bar.
420        """
421        original_size = image.shape[:2]
422        self._original_size = original_size
423
424        crop_boxes, layer_idxs = amg_utils.generate_crop_boxes(
425            original_size, self._crop_n_layers, self._crop_overlap_ratio
426        )
427
428        # We can set fixed image embeddings if we only have a single crop box (the default setting).
429        # Otherwise we have to recompute the embeddings for each crop and can't precompute.
430        if len(crop_boxes) == 1:
431            if image_embeddings is None:
432                image_embeddings = util.precompute_image_embeddings(self._predictor, image)
433            util.set_precomputed(self._predictor, image_embeddings, i=i)
434            precomputed_embeddings = True
435        else:
436            precomputed_embeddings = False
437
438        # we need to cast to the image representation that is compatible with SAM
439        image = util._to_image(image)
440
441        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
442
443        crop_list = []
444        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
445            crop_data = self._process_crop(
446                image, crop_box, layer_idx,
447                precomputed_embeddings=precomputed_embeddings,
448                pbar_init=pbar_init, pbar_update=pbar_update,
449            )
450            crop_list.append(crop_data)
451        pbar_close()
452
453        self._is_initialized = True
454        self._crop_list = crop_list
455        self._crop_boxes = crop_boxes
456
457    @torch.no_grad()
458    def generate(
459        self,
460        pred_iou_thresh: float = 0.88,
461        stability_score_thresh: float = 0.95,
462        box_nms_thresh: float = 0.7,
463        crop_nms_thresh: float = 0.7,
464        min_mask_region_area: int = 0,
465        output_mode: str = "instance_segmentation",
466        with_background: bool = True,
467    ) -> Union[List[Dict[str, Any]], np.ndarray]:
468        """Generate instance segmentation for the currently initialized image.
469
470        Args:
471            pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
472                By default, set to '0.88'.
473            stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
474                under changes to the cutoff used to binarize the model prediction. By default, set to '0.95'.
475            box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
476                By default, set to '0.7'.
477            crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
478                By default, set to '0.7'.
479            min_mask_region_area: Minimal size for the predicted masks. By default, set to '0'.
480            output_mode: The form masks are returned in. Possible values are:
481                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
482                - 'coco_rle': Return a list of dictionaries with run-length encoded masks in MS COCO format.
483                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
484                - 'rle': Return a list of dictionaries with run-length encoded masks.
485                By default, set to 'instance_segmentation'.
486            with_background: Whether to remove the largest object, which often covers the background.
487
488        Returns:
489            The segmentation masks.
490        """
491        if not self.is_initialized:
492            raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.")
493
494        data = amg_utils.MaskData()
495        for data_, crop_box in zip(self.crop_list, self.crop_boxes):
496            crop_data = self._postprocess_batch(
497                data=deepcopy(data_),
498                crop_box=crop_box, original_size=self.original_size,
499                pred_iou_thresh=pred_iou_thresh,
500                stability_score_thresh=stability_score_thresh,
501                box_nms_thresh=box_nms_thresh
502            )
503            data.cat(crop_data)
504
505        if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0:
506            # Prefer masks from smaller crops
507            scores = 1 / box_area(data["crop_boxes"])
508            scores = scores.to(data["boxes"].device)
509            keep_by_nms = batched_nms(
510                data["boxes"].float(),
511                scores,
512                torch.zeros_like(data["boxes"][:, 0]),  # categories
513                iou_threshold=crop_nms_thresh,
514            )
515            data.filter(keep_by_nms)
516
517        data.to_numpy()
518        masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode)
519        if output_mode == "instance_segmentation":
520            shape = next(iter(masks))["segmentation"].shape if len(masks) > 0 else self.original_size
521            masks = util.mask_data_to_segmentation(
522                masks, shape=shape, with_background=with_background, merge_exclusively=False
523            )
524        return masks

Generates an instance segmentation without prompts, using a point grid.

This class implements the same logic as https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py It decouples the computationally expensive steps of generating masks from the cheap post-processing operation to filter these masks to enable grid search and interactively changing the post-processing.

Use this class as follows:

amg = AutomaticMaskGenerator(predictor)
amg.initialize(image)  # Initialize the masks, this takes care of all expensive computations.
masks = amg.generate(pred_iou_thresh=0.8)  # Generate the masks. This is fast and enables testing different parameters.
Arguments:
  • predictor: The segment anything predictor.
  • points_per_side: The number of points to be sampled along one side of the image. If None, point_grids must provide explicit point sampling. By default, set to '32'.
  • points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory. By default, selects '64' for all devices except 'mps' (selects '16' for performance reasons).
  • crop_n_layers: If >0, the mask prediction will be run again on crops of the image. By default, set to '0'.
  • crop_overlap_ratio: Sets the degree to which crops overlap. By default, set to '512 / 1500'.
  • crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops. By default, set to '1'.
  • point_grids: A list over explicit grids of points used for sampling masks. Normalized to [0, 1] with respect to the image coordinate system.
  • stability_score_offset: The amount to shift the cutoff when calculating the stability score. By default, set to '1.0'.
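
The `point_grids` argument above expects one grid of normalized coordinates per crop layer. A short sketch of passing an explicit grid instead of `points_per_side`; the helper from `segment_anything.utils.amg` and the model choice are assumptions of convenience, and any `(N, 2)` array with values in `[0, 1]` works.

```python
import segment_anything.utils.amg as amg_utils

from micro_sam.util import get_sam_model
from micro_sam.instance_segmentation import AutomaticMaskGenerator

predictor = get_sam_model(model_type="vit_b")  # assumption: model choice

# One grid per crop layer; with the default crop_n_layers=0 a single grid suffices.
custom_grid = amg_utils.build_point_grid(16)  # (256, 2) array, normalized to [0, 1]
amg = AutomaticMaskGenerator(predictor, points_per_side=None, point_grids=[custom_grid])
```
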
AutomaticMaskGenerator( predictor: segment_anything.predictor.SamPredictor, points_per_side: Optional[int] = 32, points_per_batch: Optional[int] = None, crop_n_layers: int = 0, crop_overlap_ratio: float = 0.3413333333333333, crop_n_points_downscale_factor: int = 1, point_grids: Optional[List[numpy.ndarray]] = None, stability_score_offset: float = 1.0)
314    def __init__(
315        self,
316        predictor: SamPredictor,
317        points_per_side: Optional[int] = 32,
318        points_per_batch: Optional[int] = None,
319        crop_n_layers: int = 0,
320        crop_overlap_ratio: float = 512 / 1500,
321        crop_n_points_downscale_factor: int = 1,
322        point_grids: Optional[List[np.ndarray]] = None,
323        stability_score_offset: float = 1.0,
324    ):
325        super().__init__()
326
327        if points_per_side is not None:
328            self.point_grids = amg_utils.build_all_layer_point_grids(
329                points_per_side, crop_n_layers, crop_n_points_downscale_factor,
330            )
331        elif point_grids is not None:
332            self.point_grids = point_grids
333        else:
334            raise ValueError("Exactly one of 'points_per_side' or 'point_grids' must be provided.")
335
336        self._predictor = predictor
337        self._points_per_side = points_per_side
338
339        # we set the points per batch to 16 for mps for performance reasons
340        # and otherwise keep them at the default of 64
341        if points_per_batch is None:
342            points_per_batch = 16 if str(predictor.device) == "mps" else 64
343        self._points_per_batch = points_per_batch
344
345        self._crop_n_layers = crop_n_layers
346        self._crop_overlap_ratio = crop_overlap_ratio
347        self._crop_n_points_downscale_factor = crop_n_points_downscale_factor
348        self._stability_score_offset = stability_score_offset
@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, verbose: bool = False, pbar_init: Optional[callable] = None, pbar_update: Optional[callable] = None) -> None:
397    @torch.no_grad()
398    def initialize(
399        self,
400        image: np.ndarray,
401        image_embeddings: Optional[util.ImageEmbeddings] = None,
402        i: Optional[int] = None,
403        verbose: bool = False,
404        pbar_init: Optional[callable] = None,
405        pbar_update: Optional[callable] = None,
406    ) -> None:
407        """Initialize image embeddings and masks for an image.
408
409        Args:
410            image: The input image, volume or timeseries.
411            image_embeddings: Optional precomputed image embeddings.
412                See `util.precompute_image_embeddings` for details.
413            i: Index for the image data. Required if `image` has three spatial dimensions
414                or a time dimension and two spatial dimensions.
415            verbose: Whether to print computation progress. By default, set to 'False'.
416            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
417                Can be used together with pbar_update to handle the napari progress bar in another thread.
418                This enables using this function within a thread worker.
419            pbar_update: Callback to update an external progress bar.
420        """
421        original_size = image.shape[:2]
422        self._original_size = original_size
423
424        crop_boxes, layer_idxs = amg_utils.generate_crop_boxes(
425            original_size, self._crop_n_layers, self._crop_overlap_ratio
426        )
427
428        # We can set fixed image embeddings if we only have a single crop box (the default setting).
429        # Otherwise we have to recompute the embeddings for each crop and can't precompute.
430        if len(crop_boxes) == 1:
431            if image_embeddings is None:
432                image_embeddings = util.precompute_image_embeddings(self._predictor, image)
433            util.set_precomputed(self._predictor, image_embeddings, i=i)
434            precomputed_embeddings = True
435        else:
436            precomputed_embeddings = False
437
438        # we need to cast to the image representation that is compatible with SAM
439        image = util._to_image(image)
440
441        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
442
443        crop_list = []
444        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
445            crop_data = self._process_crop(
446                image, crop_box, layer_idx,
447                precomputed_embeddings=precomputed_embeddings,
448                pbar_init=pbar_init, pbar_update=pbar_update,
449            )
450            crop_list.append(crop_data)
451        pbar_close()
452
453        self._is_initialized = True
454        self._crop_list = crop_list
455        self._crop_boxes = crop_boxes

Initialize image embeddings and masks for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • verbose: Whether to print computation progress. By default, set to 'False'.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a thread worker.
  • pbar_update: Callback to update an external progress bar.
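
A sketch of wiring `pbar_init` / `pbar_update` to a `tqdm` progress bar. Using `tqdm`, the model choice, and the placeholder image are assumptions; the exact interplay with `verbose` is handled by `util.handle_pbar`.

```python
import numpy as np
from tqdm import tqdm

from micro_sam.util import get_sam_model
from micro_sam.instance_segmentation import AutomaticMaskGenerator

predictor = get_sam_model(model_type="vit_b")  # assumption: model choice
image = np.random.randint(0, 255, size=(512, 512), dtype="uint8")  # placeholder image
amg = AutomaticMaskGenerator(predictor)

pbar = {}

def pbar_init(n_steps, description):
    # Called once with the total number of steps and a description.
    pbar["bar"] = tqdm(total=n_steps, desc=description)

def pbar_update(n):
    # Called after each processed batch of point prompts.
    pbar["bar"].update(n)

amg.initialize(image, pbar_init=pbar_init, pbar_update=pbar_update)
pbar["bar"].close()
```
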
@torch.no_grad()
def generate( self, pred_iou_thresh: float = 0.88, stability_score_thresh: float = 0.95, box_nms_thresh: float = 0.7, crop_nms_thresh: float = 0.7, min_mask_region_area: int = 0, output_mode: str = 'instance_segmentation', with_background: bool = True) -> Union[List[Dict[str, Any]], numpy.ndarray]:
457    @torch.no_grad()
458    def generate(
459        self,
460        pred_iou_thresh: float = 0.88,
461        stability_score_thresh: float = 0.95,
462        box_nms_thresh: float = 0.7,
463        crop_nms_thresh: float = 0.7,
464        min_mask_region_area: int = 0,
465        output_mode: str = "instance_segmentation",
466        with_background: bool = True,
467    ) -> Union[List[Dict[str, Any]], np.ndarray]:
468        """Generate instance segmentation for the currently initialized image.
469
470        Args:
471            pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
472                By default, set to '0.88'.
473            stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
474                under changes to the cutoff used to binarize the model prediction. By default, set to '0.95'.
475            box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
476                By default, set to '0.7'.
477            crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
478                By default, set to '0.7'.
479            min_mask_region_area: Minimal size for the predicted masks. By default, set to '0'.
480            output_mode: The form masks are returned in. Possible values are:
481                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
482                - 'coco_rle': Return a list of dictionaries with run-length encoded masks in MS COCO format.
483                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
484                - 'rle': Return a list of dictionaries with run-length encoded masks.
485                By default, set to 'instance_segmentation'.
486            with_background: Whether to remove the largest object, which often covers the background.
487
488        Returns:
489            The segmentation masks.
490        """
491        if not self.is_initialized:
492            raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.")
493
494        data = amg_utils.MaskData()
495        for data_, crop_box in zip(self.crop_list, self.crop_boxes):
496            crop_data = self._postprocess_batch(
497                data=deepcopy(data_),
498                crop_box=crop_box, original_size=self.original_size,
499                pred_iou_thresh=pred_iou_thresh,
500                stability_score_thresh=stability_score_thresh,
501                box_nms_thresh=box_nms_thresh
502            )
503            data.cat(crop_data)
504
505        if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0:
506            # Prefer masks from smaller crops
507            scores = 1 / box_area(data["crop_boxes"])
508            scores = scores.to(data["boxes"].device)
509            keep_by_nms = batched_nms(
510                data["boxes"].float(),
511                scores,
512                torch.zeros_like(data["boxes"][:, 0]),  # categories
513                iou_threshold=crop_nms_thresh,
514            )
515            data.filter(keep_by_nms)
516
517        data.to_numpy()
518        masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode)
519        if output_mode == "instance_segmentation":
520            shape = next(iter(masks))["segmentation"].shape if len(masks) > 0 else self.original_size
521            masks = util.mask_data_to_segmentation(
522                masks, shape=shape, with_background=with_background, merge_exclusively=False
523            )
524        return masks

Generate instance segmentation for the currently initialized image.

Arguments:
  • pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model. By default, set to '0.88'.
  • stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask under changes to the cutoff used to binarize the model prediction. By default, set to '0.95'.
  • box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks. By default, set to '0.7'.
  • crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops. By default, set to '0.7'.
  • min_mask_region_area: Minimal size for the predicted masks. By default, set to '0'.
  • output_mode: The form masks are returned in. Possible values are:
    • 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
    • 'coco_rle': Return a list of dictionaries with run-length encoded masks in MS COCO format.
    • 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
    • 'rle': Return a list of dictionaries with run-length encoded masks. By default, set to 'instance_segmentation'.
  • with_background: Whether to remove the largest object, which often covers the background.
Returns:

The segmentation masks.
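
Since `initialize` caches all expensive computation, `generate` can be called repeatedly to explore the post-processing parameters. A small sketch of such a sweep, continuing from an `amg` that has already been initialized as in the sketches above; the threshold values are arbitrary examples.

```python
# Sweep IoU and stability thresholds on an already initialized generator;
# embeddings and masks are not recomputed between calls.
results = {}
for pred_iou_thresh in (0.7, 0.8, 0.88):
    for stability_score_thresh in (0.9, 0.95):
        segmentation = amg.generate(
            pred_iou_thresh=pred_iou_thresh,
            stability_score_thresh=stability_score_thresh,
            output_mode="instance_segmentation",
        )
        results[(pred_iou_thresh, stability_score_thresh)] = segmentation
```
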

class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
558class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
559    """Generates an instance segmentation without prompts, using a point grid.
560
561    Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings.
562
563    Args:
564        predictor: The Segment Anything predictor.
565        points_per_side: The number of points to be sampled along one side of the image.
566            If None, `point_grids` must provide explicit point sampling. By default, set to '32'.
567        points_per_batch: The number of points run simultaneously by the model.
568            Higher numbers may be faster but use more GPU memory. By default, set to '64'.
569        point_grids: A list over explicit grids of points used for sampling masks.
570            Normalized to [0, 1] with respect to the image coordinate system.
571        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
572            By default, set to '1.0'.
573    """
574
575    # We only expose the arguments that make sense for the tiled mask generator.
576    # Anything related to crops doesn't make sense, because we re-use that functionality
577    # for tiling, so these parameters wouldn't have any effect.
578    def __init__(
579        self,
580        predictor: SamPredictor,
581        points_per_side: Optional[int] = 32,
582        points_per_batch: int = 64,
583        point_grids: Optional[List[np.ndarray]] = None,
584        stability_score_offset: float = 1.0,
585    ) -> None:
586        super().__init__(
587            predictor=predictor,
588            points_per_side=points_per_side,
589            points_per_batch=points_per_batch,
590            point_grids=point_grids,
591            stability_score_offset=stability_score_offset,
592        )
593
594    @torch.no_grad()
595    def initialize(
596        self,
597        image: np.ndarray,
598        image_embeddings: Optional[util.ImageEmbeddings] = None,
599        i: Optional[int] = None,
600        tile_shape: Optional[Tuple[int, int]] = None,
601        halo: Optional[Tuple[int, int]] = None,
602        verbose: bool = False,
603        pbar_init: Optional[callable] = None,
604        pbar_update: Optional[callable] = None,
605        batch_size: int = 1,
606        mask: Optional[np.typing.ArrayLike] = None,
607    ) -> None:
608        """Initialize image embeddings and masks for an image.
609
610        Args:
611            image: The input image, volume or timeseries.
612            image_embeddings: Optional precomputed image embeddings.
613                See `util.precompute_image_embeddings` for details.
614            i: Index for the image data. Required if `image` has three spatial dimensions
615                or a time dimension and two spatial dimensions.
616            tile_shape: The tile shape for embedding prediction.
617        halo: The overlap between tiles.
618        verbose: Whether to print computation progress. By default, set to 'False'.
619        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
620            Can be used together with pbar_update to handle a napari progress bar in another thread.
621            This enables using this function within a threadworker.
622            pbar_update: Callback to update an external progress bar.
623            batch_size: The batch size for image embedding prediction. By default, set to '1'.
624            mask: An optional mask to define areas that are ignored in the segmentation.
625        """
626        original_size = image.shape[:2]
627        self._original_size = original_size
628
629        self._image_embeddings, tile_shape, halo, tiles_in_mask = _process_tiled_embeddings(
630            self._predictor, image, image_embeddings, tile_shape, halo,
631            verbose=verbose, batch_size=batch_size, mask=mask, i=i,
632        )
633
634        tiling = blocking([0, 0], original_size, tile_shape)
635        if tiles_in_mask is None:
636            n_tiles = tiling.numberOfBlocks
637            tile_ids = range(n_tiles)
638        else:
639            n_tiles = len(tiles_in_mask)
640            tile_ids = tiles_in_mask
641
642        # The crop box is always the full local tile.
643        tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in tile_ids]
644        crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles]
645
646        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
647        pbar_init(n_tiles, "Compute masks for tile")
648
649        # We need to cast to the image representation that is compatible with SAM.
650        image = util._to_image(image)
651
652        mask_data = []
653        for idx, tile_id in enumerate(tile_ids):
654            # set the pre-computed embeddings for this tile
655            features = self._image_embeddings["features"][str(tile_id)]
656            tile_embeddings = {
657                "features": features,
658                "input_size": features.attrs["input_size"],
659                "original_size": features.attrs["original_size"],
660            }
661            util.set_precomputed(self._predictor, tile_embeddings, i)
662
663            # Compute the mask data for this tile and append it
664            this_mask_data = self._process_crop(
665                image, crop_box=crop_boxes[idx], crop_layer_idx=0, precomputed_embeddings=True
666            )
667            mask_data.append(this_mask_data)
668            pbar_update(1)
669        pbar_close()
670
671        # set the initialized data
672        self._is_initialized = True
673        self._crop_list = mask_data
674        self._crop_boxes = crop_boxes

Generates an instance segmentation without prompts, using a point grid.

Implements the same functionality as AutomaticMaskGenerator but for tiled embeddings.

Arguments:
  • predictor: The Segment Anything predictor.
  • points_per_side: The number of points to be sampled along one side of the image. If None, point_grids must provide explicit point sampling. By default, set to '32'.
  • points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory. By default, set to '64'.
  • point_grids: A list over explicit grids of points used for sampling masks. Normalized to [0, 1] with respect to the image coordinate system.
  • stability_score_offset: The amount to shift the cutoff when calculating the stability score. By default, set to '1.0'.
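As an illustration of the `point_grids` argument, the sketch below builds an explicit grid with `build_point_grid` from `segment_anything.utils.amg` instead of relying on `points_per_side`; the grid size and model name are example choices:

```python
from segment_anything.utils.amg import build_point_grid
from micro_sam import util
from micro_sam.instance_segmentation import TiledAutomaticMaskGenerator

predictor = util.get_sam_model(model_type="vit_b")  # example model

# A 16 x 16 grid of prompt points, normalized to [0, 1].
point_grid = build_point_grid(16)

# points_per_side is set to None so that the explicit grid is used.
amg = TiledAutomaticMaskGenerator(
    predictor, points_per_side=None, point_grids=[point_grid], points_per_batch=32,
)
```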
TiledAutomaticMaskGenerator( predictor: segment_anything.predictor.SamPredictor, points_per_side: Optional[int] = 32, points_per_batch: int = 64, point_grids: Optional[List[numpy.ndarray]] = None, stability_score_offset: float = 1.0)
578    def __init__(
579        self,
580        predictor: SamPredictor,
581        points_per_side: Optional[int] = 32,
582        points_per_batch: int = 64,
583        point_grids: Optional[List[np.ndarray]] = None,
584        stability_score_offset: float = 1.0,
585    ) -> None:
586        super().__init__(
587            predictor=predictor,
588            points_per_side=points_per_side,
589            points_per_batch=points_per_batch,
590            point_grids=point_grids,
591            stability_score_offset=stability_score_offset,
592        )
@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = False, pbar_init: Optional[callable] = None, pbar_update: Optional[callable] = None, batch_size: int = 1, mask: Optional[numpy.typing.ArrayLike] = None) -> None:
594    @torch.no_grad()
595    def initialize(
596        self,
597        image: np.ndarray,
598        image_embeddings: Optional[util.ImageEmbeddings] = None,
599        i: Optional[int] = None,
600        tile_shape: Optional[Tuple[int, int]] = None,
601        halo: Optional[Tuple[int, int]] = None,
602        verbose: bool = False,
603        pbar_init: Optional[callable] = None,
604        pbar_update: Optional[callable] = None,
605        batch_size: int = 1,
606        mask: Optional[np.typing.ArrayLike] = None,
607    ) -> None:
608        """Initialize image embeddings and masks for an image.
609
610        Args:
611            image: The input image, volume or timeseries.
612            image_embeddings: Optional precomputed image embeddings.
613                See `util.precompute_image_embeddings` for details.
614            i: Index for the image data. Required if `image` has three spatial dimensions
615                or a time dimension and two spatial dimensions.
616            tile_shape: The tile shape for embedding prediction.
617        halo: The overlap between tiles.
618        verbose: Whether to print computation progress. By default, set to 'False'.
619        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
620            Can be used together with pbar_update to handle a napari progress bar in another thread.
621            This enables using this function within a threadworker.
622            pbar_update: Callback to update an external progress bar.
623            batch_size: The batch size for image embedding prediction. By default, set to '1'.
624            mask: An optional mask to define areas that are ignored in the segmentation.
625        """
626        original_size = image.shape[:2]
627        self._original_size = original_size
628
629        self._image_embeddings, tile_shape, halo, tiles_in_mask = _process_tiled_embeddings(
630            self._predictor, image, image_embeddings, tile_shape, halo,
631            verbose=verbose, batch_size=batch_size, mask=mask, i=i,
632        )
633
634        tiling = blocking([0, 0], original_size, tile_shape)
635        if tiles_in_mask is None:
636            n_tiles = tiling.numberOfBlocks
637            tile_ids = range(n_tiles)
638        else:
639            n_tiles = len(tiles_in_mask)
640            tile_ids = tiles_in_mask
641
642        # The crop box is always the full local tile.
643        tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in tile_ids]
644        crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles]
645
646        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
647        pbar_init(n_tiles, "Compute masks for tile")
648
649        # We need to cast to the image representation that is compatible with SAM.
650        image = util._to_image(image)
651
652        mask_data = []
653        for idx, tile_id in enumerate(tile_ids):
654            # set the pre-computed embeddings for this tile
655            features = self._image_embeddings["features"][str(tile_id)]
656            tile_embeddings = {
657                "features": features,
658                "input_size": features.attrs["input_size"],
659                "original_size": features.attrs["original_size"],
660            }
661            util.set_precomputed(self._predictor, tile_embeddings, i)
662
663            # Compute the mask data for this tile and append it
664            this_mask_data = self._process_crop(
665                image, crop_box=crop_boxes[idx], crop_layer_idx=0, precomputed_embeddings=True
666            )
667            mask_data.append(this_mask_data)
668            pbar_update(1)
669        pbar_close()
670
671        # set the initialized data
672        self._is_initialized = True
673        self._crop_list = mask_data
674        self._crop_boxes = crop_boxes

Initialize image embeddings and masks for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • tile_shape: The tile shape for embedding prediction.
  • halo: The overlap between tiles.
  • verbose: Whether to print computation progress. By default, set to 'False'.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle a napari progress bar in another thread. This enables using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
  • batch_size: The batch size for image embedding prediction. By default, set to '1'.
  • mask: An optional mask to define areas that are ignored in the segmentation.
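A minimal sketch of tiled mask generation; the image size, tile shape and halo are arbitrary example values:

```python
import numpy as np
from micro_sam import util
from micro_sam.instance_segmentation import TiledAutomaticMaskGenerator

image = np.random.randint(0, 255, size=(2048, 2048)).astype("uint8")  # example large image
predictor = util.get_sam_model(model_type="vit_b")  # example model

amg = TiledAutomaticMaskGenerator(predictor)
# Embeddings are computed per tile; the halo is the overlap that avoids artifacts at tile borders.
amg.initialize(image, tile_shape=(512, 512), halo=(64, 64), verbose=True)
segmentation = amg.generate()
```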
class DecoderAdapter(torch.nn.modules.module.Module):
682class DecoderAdapter(torch.nn.Module):
683    """Adapter to contain the UNETR decoder in a single module.
684
685    To apply the decoder on top of pre-computed embeddings for the segmentation functionality.
686    See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py
687    """
688    def __init__(self, unetr: torch.nn.Module):
689        super().__init__()
690
691        self.base = unetr.base
692        self.out_conv = unetr.out_conv
693        self.deconv_out = unetr.deconv_out
694        self.decoder_head = unetr.decoder_head
695        self.final_activation = unetr.final_activation
696        self.postprocess_masks = unetr.postprocess_masks
697
698        self.decoder = unetr.decoder
699        self.deconv1 = unetr.deconv1
700        self.deconv2 = unetr.deconv2
701        self.deconv3 = unetr.deconv3
702        self.deconv4 = unetr.deconv4
703
704    def _forward_impl(self, input_):
705        z12 = input_
706
707        z9 = self.deconv1(z12)
708        z6 = self.deconv2(z9)
709        z3 = self.deconv3(z6)
710        z0 = self.deconv4(z3)
711
712        updated_from_encoder = [z9, z6, z3]
713
714        x = self.base(z12)
715        x = self.decoder(x, encoder_inputs=updated_from_encoder)
716        x = self.deconv_out(x)
717
718        x = torch.cat([x, z0], dim=1)
719        x = self.decoder_head(x)
720
721        x = self.out_conv(x)
722        if self.final_activation is not None:
723            x = self.final_activation(x)
724        return x
725
726    def forward(self, input_, input_shape, original_shape):
727        x = self._forward_impl(input_)
728        x = self.postprocess_masks(x, input_shape, original_shape)
729        return x

Adapter to contain the UNETR decoder in a single module.

To apply the decoder on top of pre-computed embeddings for the segmentation functionality. See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py

DecoderAdapter(unetr: torch.nn.modules.module.Module)
688    def __init__(self, unetr: torch.nn.Module):
689        super().__init__()
690
691        self.base = unetr.base
692        self.out_conv = unetr.out_conv
693        self.deconv_out = unetr.deconv_out
694        self.decoder_head = unetr.decoder_head
695        self.final_activation = unetr.final_activation
696        self.postprocess_masks = unetr.postprocess_masks
697
698        self.decoder = unetr.decoder
699        self.deconv1 = unetr.deconv1
700        self.deconv2 = unetr.deconv2
701        self.deconv3 = unetr.deconv3
702        self.deconv4 = unetr.deconv4

def forward(self, input_, input_shape, original_shape):
726    def forward(self, input_, input_shape, original_shape):
727        x = self._forward_impl(input_)
728        x = self.postprocess_masks(x, input_shape, original_shape)
729        return x

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

def get_unetr( image_encoder: torch.nn.modules.module.Module, decoder_state: Optional[collections.OrderedDict[str, torch.Tensor]] = None, device: Union[str, torch.device, NoneType] = None, out_channels: int = 3, flexible_load_checkpoint: bool = False) -> torch.nn.modules.module.Module:
732def get_unetr(
733    image_encoder: torch.nn.Module,
734    decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None,
735    device: Optional[Union[str, torch.device]] = None,
736    out_channels: int = 3,
737    flexible_load_checkpoint: bool = False,
738) -> torch.nn.Module:
739    """Get UNETR model for automatic instance segmentation.
740
741    Args:
742        image_encoder: The image encoder of the SAM model.
743            This is used as encoder by the UNETR too.
744        decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
745        device: The device. By default, automatically chooses the best available device.
746        out_channels: The number of output channels. By default, set to '3'.
747        flexible_load_checkpoint: Whether to allow reinitialization of parameters
748            which could not be found in the provided decoder state. By default, set to 'False'.
749
750    Returns:
751        The UNETR model.
752    """
753    device = util.get_device(device)
754
755    if decoder_state is None:
756        use_conv_transpose = False  # By default, we use interpolation for upsampling.
757    else:
758        # From the provided pretrained 'decoder_state', we check whether it uses transposed convolutions.
759        # NOTE: Explanation to the logic below -
760        # - We do this by looking for parameter names that contain '.block.' within the "decoder.samplers"
761        #   submodules. This naming convention indicates that transposed convolutions are used,
762        #   wrapped inside a custom block module.
763        # - Otherwise '.conv.' appears. It indicates a standard `Conv2d` applied after interpolation for upsampling.
764        use_conv_transpose = any(".block." in k for k in decoder_state.keys() if k.startswith("decoder.samplers"))
765
766    unetr = UNETR(
767        backbone="sam",
768        encoder=image_encoder,
769        out_channels=out_channels,
770        use_sam_stats=True,
771        final_activation="Sigmoid",
772        use_skip_connection=False,
773        resize_input=True,
774        use_conv_transpose=use_conv_transpose,
775    )
776
777    if decoder_state is not None:
778        unetr_state_dict = unetr.state_dict()
779        for k, v in unetr_state_dict.items():
780            if not k.startswith("encoder"):
781                if flexible_load_checkpoint:  # Whether to allow reinitialization of parameters that are not found.
782                    if k in decoder_state:  # First check whether the key is available in the provided decoder state.
783                        unetr_state_dict[k] = decoder_state[k]
784                    else:  # Otherwise, keep the freshly initialized parameter.
785                        warnings.warn(f"Could not find '{k}' in the pretrained state dict. Hence, we reinitialize it.")
786                        unetr_state_dict[k] = v
787
788                else:  # Otherwise, be strict about finding the parameter in the decoder state.
789                    if k not in decoder_state:
790                        raise RuntimeError(f"The parameters for '{k}' could not be found.")
791                    unetr_state_dict[k] = decoder_state[k]
792
793        unetr.load_state_dict(unetr_state_dict)
794
795    unetr.to(device)
796    return unetr

Get UNETR model for automatic instance segmentation.

Arguments:
  • image_encoder: The image encoder of the SAM model. This is used as encoder by the UNETR too.
  • decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
  • device: The device. By default, automatically chooses the best available device.
  • out_channels: The number of output channels. By default, set to '3'.
  • flexible_load_checkpoint: Whether to allow reinitialization of parameters which could not be found in the provided decoder state. By default, set to 'False'.
Returns:

The UNETR model.
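A sketch of how `get_unetr` fits together with a SAM image encoder; the checkpoint file name is hypothetical and only illustrates loading a `decoder_state` with `torch.load`:

```python
import torch
from micro_sam import util
from micro_sam.instance_segmentation import get_unetr

predictor = util.get_sam_model(model_type="vit_b")  # example model

# Hypothetical file that stores the decoder weights as an OrderedDict.
decoder_state = torch.load("decoder_state.pt", map_location="cpu")

unetr = get_unetr(
    image_encoder=predictor.model.image_encoder,
    decoder_state=decoder_state,
    out_channels=3,  # foreground, center distances and boundary distances
    flexible_load_checkpoint=True,  # reinitialize parameters missing from the state
)
```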

def get_decoder( image_encoder: torch.nn.modules.module.Module, decoder_state: collections.OrderedDict[str, torch.Tensor], device: Union[str, torch.device, NoneType] = None) -> DecoderAdapter:
799def get_decoder(
800    image_encoder: torch.nn.Module,
801    decoder_state: OrderedDict[str, torch.Tensor],
802    device: Optional[Union[str, torch.device]] = None,
803) -> DecoderAdapter:
804    """Get decoder to predict outputs for automatic instance segmentation.
805
806    Args:
807        image_encoder: The image encoder of the SAM model.
808        decoder_state: State to initialize the weights of the UNETR decoder.
809        device: The device. By default, automatically chooses the best available device.
810
811    Returns:
812        The decoder for instance segmentation.
813    """
814    unetr = get_unetr(image_encoder, decoder_state, device)
815    return DecoderAdapter(unetr)

Get decoder to predict outputs for automatic instance segmentation.

Arguments:
  • image_encoder: The image encoder of the SAM model.
  • decoder_state: State to initialize the weights of the UNETR decoder.
  • device: The device. By default, automatically chooses the best available device.
Returns:

The decoder for instance segmentation.

def get_predictor_and_decoder( model_type: str, checkpoint_path: Union[str, os.PathLike, NoneType] = None, device: Union[str, torch.device, NoneType] = None, peft_kwargs: Optional[Dict] = None) -> Tuple[segment_anything.predictor.SamPredictor, DecoderAdapter]:
818def get_predictor_and_decoder(
819    model_type: str,
820    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
821    device: Optional[Union[str, torch.device]] = None,
822    peft_kwargs: Optional[Dict] = None,
823) -> Tuple[SamPredictor, DecoderAdapter]:
824    """Load the SAM model (predictor) and instance segmentation decoder.
825
826    This requires a checkpoint that contains the state for both predictor
827    and decoder.
828
829    Args:
830        model_type: The type of the image encoder used in the SAM model.
831        checkpoint_path: Path to the checkpoint from which to load the data.
832        device: The device. By default, automatically chooses the best available device.
833        peft_kwargs: Keyword arguments for the PEFT wrapper class.
834
835    Returns:
836        The SAM predictor.
837        The decoder for instance segmentation.
838    """
839    device = util.get_device(device)
840    predictor, state = util.get_sam_model(
841        model_type=model_type,
842        checkpoint_path=checkpoint_path,
843        device=device,
844        return_state=True,
845        peft_kwargs=peft_kwargs,
846    )
847
848    if "decoder_state" not in state:
849        raise ValueError(
850            f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state"
851        )
852
853    decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device)
854    return predictor, decoder

Load the SAM model (predictor) and instance segmentation decoder.

This requires a checkpoint that contains the state for both predictor and decoder.

Arguments:
  • model_type: The type of the image encoder used in the SAM model.
  • checkpoint_path: Path to the checkpoint from which to load the data.
  • device: The device. By default, automatically chooses the best available device.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class.
Returns:

The SAM predictor. The decoder for instance segmentation.
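A minimal sketch for loading a predictor together with its decoder and passing both to `InstanceSegmentationWithDecoder`; 'vit_b_lm' is an example of a model that ships with a decoder state:

```python
from micro_sam.instance_segmentation import (
    InstanceSegmentationWithDecoder, get_predictor_and_decoder
)

predictor, decoder = get_predictor_and_decoder(model_type="vit_b_lm")  # example model
segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
```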

class InstanceSegmentationWithDecoder:
 906class InstanceSegmentationWithDecoder:
 907    """Generates an instance segmentation without prompts, using a decoder.
 908
 909    Implements the same interface as `AutomaticMaskGenerator`.
 910
 911    Use this class as follows:
 912    ```python
 913    segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
 914    segmenter.initialize(image)   # Predict the image embeddings and decoder outputs.
 915    masks = segmenter.generate(center_distance_threshold=0.75)  # Generate the instance segmentation.
 916    ```
 917
 918    Args:
 919        predictor: The segment anything predictor.
 920        decoder: The decoder to predict intermediate representations for instance segmentation.
 921    """
 922    def __init__(self, predictor: SamPredictor, decoder: torch.nn.Module) -> None:
 923        self._predictor = predictor
 924        self._decoder = decoder
 925
 926        # The decoder outputs.
 927        self._foreground = None
 928        self._center_distances = None
 929        self._boundary_distances = None
 930
 931        self._is_initialized = False
 932
 933    @property
 934    def is_initialized(self):
 935        """Whether the mask generator has already been initialized.
 936        """
 937        return self._is_initialized
 938
 939    @torch.no_grad()
 940    def initialize(
 941        self,
 942        image: np.ndarray,
 943        image_embeddings: Optional[util.ImageEmbeddings] = None,
 944        i: Optional[int] = None,
 945        verbose: bool = False,
 946        pbar_init: Optional[callable] = None,
 947        pbar_update: Optional[callable] = None,
 948        ndim: int = 2,
 949    ) -> None:
 950        """Initialize image embeddings and decoder predictions for an image.
 951
 952        Args:
 953            image: The input image, volume or timeseries.
 954            image_embeddings: Optional precomputed image embeddings.
 955                See `util.precompute_image_embeddings` for details.
 956            i: Index for the image data. Required if `image` has three spatial dimensions
 957                or a time dimension and two spatial dimensions.
 958            verbose: Whether to be verbose. By default, set to 'False'.
 959            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 960                Can be used together with pbar_update to handle a napari progress bar in another thread.
 961                This enables using this function within a threadworker.
 962            pbar_update: Callback to update an external progress bar.
 963            ndim: The dimensionality of the data. If not given, it will be deduced from the input data. By default, 2.
 964        """
 965        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
 966        pbar_init(1, "Initialize instance segmentation with decoder")
 967
 968        if image_embeddings is None:
 969            image_embeddings = util.precompute_image_embeddings(
 970                predictor=self._predictor, input_=image, ndim=ndim, verbose=verbose
 971            )
 972
 973        # Get the image embeddings from the predictor.
 974        self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i)
 975        embeddings = self._predictor.features
 976        input_shape = tuple(self._predictor.input_size)
 977        original_shape = tuple(self._predictor.original_size)
 978
 979        # Run prediction with the UNETR decoder.
 980        output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
 981        assert output.shape[0] == 3, f"{output.shape}"
 982        pbar_update(1)
 983        pbar_close()
 984
 985        # Set the state.
 986        self._foreground = output[0]
 987        self._center_distances = output[1]
 988        self._boundary_distances = output[2]
 989        self._i = i
 990        self._is_initialized = True
 991
 992    def _to_masks(self, segmentation, output_mode):
 993        if output_mode != "binary_mask":
 994            raise ValueError(
 995                f"Output mode {output_mode} is not supported. Choose one of 'instance_segmentation', 'binary_mask'"
 996            )
 997
 998        props = regionprops(segmentation)
 999        ndim = segmentation.ndim
1000        assert ndim in (2, 3)
1001
1002        shape = segmentation.shape
1003        if ndim == 2:
1004            crop_box = [0, shape[1], 0, shape[0]]
1005        else:
1006            crop_box = [0, shape[2], 0, shape[1], 0, shape[0]]
1007
1008        # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h]
1009        def to_bbox_2d(bbox):
1010            y0, x0 = bbox[0], bbox[1]
1011            w = bbox[3] - x0
1012            h = bbox[2] - y0
1013            return [x0, w, y0, h]
1014
1015        def to_bbox_3d(bbox):
1016            z0, y0, x0 = bbox[0], bbox[1], bbox[2]
1017            w = bbox[5] - x0
1018            h = bbox[4] - y0
1019            d = bbox[3] - z0
1020            return [x0, w, y0, h, z0, d]
1021
1022        to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d
1023        masks = [
1024            {
1025                "segmentation": segmentation == prop.label,
1026                "area": prop.area,
1027                "bbox": to_bbox(prop.bbox),
1028                "crop_box": crop_box,
1029                "seg_id": prop.label,
1030            } for prop in props
1031        ]
1032        return masks
1033
1034    def generate(
1035        self,
1036        center_distance_threshold: float = 0.5,
1037        boundary_distance_threshold: float = 0.5,
1038        foreground_threshold: float = 0.5,
1039        foreground_smoothing: float = 1.0,
1040        distance_smoothing: float = 1.6,
1041        min_size: int = 0,
1042        output_mode: str = "instance_segmentation",
1043        tile_shape: Optional[Tuple[int, int]] = None,
1044        halo: Optional[Tuple[int, int]] = None,
1045        n_threads: Optional[int] = None,
1046    ) -> Union[List[Dict[str, Any]], np.ndarray]:
1047        """Generate instance segmentation for the currently initialized image.
1048
1049        Args:
1050            center_distance_threshold: Center distance predictions below this value will be
1051                used to find seeds (intersected with thresholded boundary distance predictions).
1052                By default, set to '0.5'.
1053            boundary_distance_threshold: Boundary distance predictions below this value will be
1054                used to find seeds (intersected with thresholded center distance predictions).
1055                By default, set to '0.5'.
1056            foreground_threshold: Foreground predictions above this value will be used as foreground mask.
1057                By default, set to '0.5'.
1058            foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid
1059                checkerboard artifacts in the prediction. By default, set to '1.0'.
1060            distance_smoothing: Sigma value for smoothing the distance predictions.
1061            min_size: Minimal object size in the segmentation result. By default, set to '0'.
1062            output_mode: The form masks are returned in. Possible values are:
1063                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
1064                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
1065                By default, set to 'instance_segmentation'.
1066            tile_shape: Tile shape for parallelizing the instance segmentation post-processing.
1067                This parameter is independent from the tile shape for computing the embeddings.
1068                If not given then post-processing will not be parallelized.
1069            halo: Halo for parallel post-processing. See also `tile_shape`.
1070            n_threads: Number of threads for parallel post-processing. See also `tile_shape`.
1071
1072        Returns:
1073            The segmentation masks.
1074        """
1075        if not self.is_initialized:
1076            raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.")
1077
1078        if foreground_smoothing > 0:
1079            foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing)
1080        else:
1081            foreground = self._foreground
1082
1083        if tile_shape is None:
1084            segmentation = watershed_from_center_and_boundary_distances(
1085                center_distances=self._center_distances,
1086                boundary_distances=self._boundary_distances,
1087                foreground_map=foreground,
1088                center_distance_threshold=center_distance_threshold,
1089                boundary_distance_threshold=boundary_distance_threshold,
1090                foreground_threshold=foreground_threshold,
1091                distance_smoothing=distance_smoothing,
1092                min_size=min_size,
1093            )
1094        else:
1095            if halo is None:
1096                raise ValueError("You must pass a value for halo if tile_shape is given.")
1097            segmentation = _watershed_from_center_and_boundary_distances_parallel(
1098                center_distances=self._center_distances,
1099                boundary_distances=self._boundary_distances,
1100                foreground_map=foreground,
1101                center_distance_threshold=center_distance_threshold,
1102                boundary_distance_threshold=boundary_distance_threshold,
1103                foreground_threshold=foreground_threshold,
1104                distance_smoothing=distance_smoothing,
1105                min_size=min_size,
1106                tile_shape=tile_shape,
1107                halo=halo,
1108                n_threads=n_threads,
1109                verbose=False,
1110            )
1111
1112        if output_mode != "instance_segmentation":
1113            segmentation = self._to_masks(segmentation, output_mode)
1114        return segmentation
1115
1116    def get_state(self) -> Dict[str, Any]:
1117        """Get the initialized state of the instance segmenter.
1118
1119        Returns:
1120            Instance segmentation state.
1121        """
1122        if not self.is_initialized:
1123            raise RuntimeError("The state has not been computed yet. Call initialize first.")
1124
1125        return {
1126            "foreground": self._foreground,
1127            "center_distances": self._center_distances,
1128            "boundary_distances": self._boundary_distances,
1129        }
1130
1131    def set_state(self, state: Dict[str, Any]) -> None:
1132        """Set the state of the instance segmenter.
1133
1134        Args:
1135            state: The instance segmentation state
1136        """
1137        self._foreground = state["foreground"]
1138        self._center_distances = state["center_distances"]
1139        self._boundary_distances = state["boundary_distances"]
1140        self._is_initialized = True
1141
1142    def clear_state(self):
1143        """Clear the state of the instance segmenter.
1144        """
1145        self._foreground = None
1146        self._center_distances = None
1147        self._boundary_distances = None
1148        self._is_initialized = False

Generates an instance segmentation without prompts, using a decoder.

Implements the same interface as AutomaticMaskGenerator.

Use this class as follows:

segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image)   # Predict the image embeddings and decoder outputs.
masks = segmenter.generate(center_distance_threshold=0.75)  # Generate the instance segmentation.
Arguments:
  • predictor: The segment anything predictor.
  • decoder: The decoder to predict intermediate representations for instance segmentation.
InstanceSegmentationWithDecoder( predictor: segment_anything.predictor.SamPredictor, decoder: torch.nn.modules.module.Module)
922    def __init__(self, predictor: SamPredictor, decoder: torch.nn.Module) -> None:
923        self._predictor = predictor
924        self._decoder = decoder
925
926        # The decoder outputs.
927        self._foreground = None
928        self._center_distances = None
929        self._boundary_distances = None
930
931        self._is_initialized = False
is_initialized
933    @property
934    def is_initialized(self):
935        """Whether the mask generator has already been initialized.
936        """
937        return self._is_initialized

Whether the mask generator has already been initialized.

@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, verbose: bool = False, pbar_init: Optional[callable] = None, pbar_update: Optional[callable] = None, ndim: int = 2) -> None:
939    @torch.no_grad()
940    def initialize(
941        self,
942        image: np.ndarray,
943        image_embeddings: Optional[util.ImageEmbeddings] = None,
944        i: Optional[int] = None,
945        verbose: bool = False,
946        pbar_init: Optional[callable] = None,
947        pbar_update: Optional[callable] = None,
948        ndim: int = 2,
949    ) -> None:
950        """Initialize image embeddings and decoder predictions for an image.
951
952        Args:
953            image: The input image, volume or timeseries.
954            image_embeddings: Optional precomputed image embeddings.
955                See `util.precompute_image_embeddings` for details.
956            i: Index for the image data. Required if `image` has three spatial dimensions
957                or a time dimension and two spatial dimensions.
958            verbose: Whether to be verbose. By default, set to 'False'.
959            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
960                Can be used together with pbar_update to handle a napari progress bar in another thread.
961                This enables using this function within a threadworker.
962            pbar_update: Callback to update an external progress bar.
963            ndim: The dimensionality of the data. If not given, it will be deduced from the input data. By default, 2.
964        """
965        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
966        pbar_init(1, "Initialize instance segmentation with decoder")
967
968        if image_embeddings is None:
969            image_embeddings = util.precompute_image_embeddings(
970                predictor=self._predictor, input_=image, ndim=ndim, verbose=verbose
971            )
972
973        # Get the image embeddings from the predictor.
974        self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i)
975        embeddings = self._predictor.features
976        input_shape = tuple(self._predictor.input_size)
977        original_shape = tuple(self._predictor.original_size)
978
979        # Run prediction with the UNETR decoder.
980        output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
981        assert output.shape[0] == 3, f"{output.shape}"
982        pbar_update(1)
983        pbar_close()
984
985        # Set the state.
986        self._foreground = output[0]
987        self._center_distances = output[1]
988        self._boundary_distances = output[2]
989        self._i = i
990        self._is_initialized = True

Initialize image embeddings and decoder predictions for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • verbose: Whether to be verbose. By default, set to 'False'.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle a napari progress bar in another thread. This enables using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
  • ndim: The dimensionality of the data. If not given, it will be deduced from the input data. By default, 2.
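For volumetric data, `initialize` can reuse precomputed embeddings together with the index `i`; a sketch assuming a 3D volume and an example model name:

```python
import numpy as np
from micro_sam import util
from micro_sam.instance_segmentation import (
    InstanceSegmentationWithDecoder, get_predictor_and_decoder
)

volume = np.random.randint(0, 255, size=(16, 512, 512)).astype("uint8")  # example z-stack
predictor, decoder = get_predictor_and_decoder(model_type="vit_b_lm")  # example model

# Precompute the embeddings for the whole volume once.
embeddings = util.precompute_image_embeddings(predictor, volume, ndim=3)

segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
# The index selects the slice; it is required because the data has three spatial dimensions.
segmenter.initialize(volume, image_embeddings=embeddings, i=0, ndim=3)
segmentation = segmenter.generate()
```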
def generate( self, center_distance_threshold: float = 0.5, boundary_distance_threshold: float = 0.5, foreground_threshold: float = 0.5, foreground_smoothing: float = 1.0, distance_smoothing: float = 1.6, min_size: int = 0, output_mode: str = 'instance_segmentation', tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, n_threads: Optional[int] = None) -> Union[List[Dict[str, Any]], numpy.ndarray]:
1034    def generate(
1035        self,
1036        center_distance_threshold: float = 0.5,
1037        boundary_distance_threshold: float = 0.5,
1038        foreground_threshold: float = 0.5,
1039        foreground_smoothing: float = 1.0,
1040        distance_smoothing: float = 1.6,
1041        min_size: int = 0,
1042        output_mode: str = "instance_segmentation",
1043        tile_shape: Optional[Tuple[int, int]] = None,
1044        halo: Optional[Tuple[int, int]] = None,
1045        n_threads: Optional[int] = None,
1046    ) -> Union[List[Dict[str, Any]], np.ndarray]:
1047        """Generate instance segmentation for the currently initialized image.
1048
1049        Args:
1050            center_distance_threshold: Center distance predictions below this value will be
1051                used to find seeds (intersected with thresholded boundary distance predictions).
1052                By default, set to '0.5'.
1053            boundary_distance_threshold: Boundary distance predictions below this value will be
1054                used to find seeds (intersected with thresholded center distance predictions).
1055                By default, set to '0.5'.
1056            foreground_threshold: Foreground predictions above this value will be used as foreground mask.
1057                By default, set to '0.5'.
1058            foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid
1059                checkerboard artifacts in the prediction. By default, set to '1.0'.
1060            distance_smoothing: Sigma value for smoothing the distance predictions.
1061            min_size: Minimal object size in the segmentation result. By default, set to '0'.
1062            output_mode: The form masks are returned in. Possible values are:
1063                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
1064                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
1065                By default, set to 'instance_segmentation'.
1066            tile_shape: Tile shape for parallelizing the instance segmentation post-processing.
1067                This parameter is independent from the tile shape for computing the embeddings.
1068                If not given then post-processing will not be parallelized.
1069            halo: Halo for parallel post-processing. See also `tile_shape`.
1070            n_threads: Number of threads for parallel post-processing. See also `tile_shape`.
1071
1072        Returns:
1073            The segmentation masks.
1074        """
1075        if not self.is_initialized:
1076            raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.")
1077
1078        if foreground_smoothing > 0:
1079            foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing)
1080        else:
1081            foreground = self._foreground
1082
1083        if tile_shape is None:
1084            segmentation = watershed_from_center_and_boundary_distances(
1085                center_distances=self._center_distances,
1086                boundary_distances=self._boundary_distances,
1087                foreground_map=foreground,
1088                center_distance_threshold=center_distance_threshold,
1089                boundary_distance_threshold=boundary_distance_threshold,
1090                foreground_threshold=foreground_threshold,
1091                distance_smoothing=distance_smoothing,
1092                min_size=min_size,
1093            )
1094        else:
1095            if halo is None:
1096                raise ValueError("You must pass a value for halo if tile_shape is given.")
1097            segmentation = _watershed_from_center_and_boundary_distances_parallel(
1098                center_distances=self._center_distances,
1099                boundary_distances=self._boundary_distances,
1100                foreground_map=foreground,
1101                center_distance_threshold=center_distance_threshold,
1102                boundary_distance_threshold=boundary_distance_threshold,
1103                foreground_threshold=foreground_threshold,
1104                distance_smoothing=distance_smoothing,
1105                min_size=min_size,
1106                tile_shape=tile_shape,
1107                halo=halo,
1108                n_threads=n_threads,
1109                verbose=False,
1110            )
1111
1112        if output_mode != "instance_segmentation":
1113            segmentation = self._to_masks(segmentation, output_mode)
1114        return segmentation

Generate instance segmentation for the currently initialized image.

Arguments:
  • center_distance_threshold: Center distance predictions below this value will be used to find seeds (intersected with thresholded boundary distance predictions). By default, set to '0.5'.
  • boundary_distance_threshold: Boundary distance predictions below this value will be used to find seeds (intersected with thresholded center distance predictions). By default, set to '0.5'.
  • foreground_threshold: Foreground predictions above this value will be used as foreground mask. By default, set to '0.5'.
  • foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid checkerboard artifacts in the prediction. By default, set to '1.0'.
  • distance_smoothing: Sigma value for smoothing the distance predictions.
  • min_size: Minimal object size in the segmentation result. By default, set to '0'.
  • output_mode: The form masks are returned in. By default, set to 'instance_segmentation'. Possible values are:
    • 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
    • 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
  • tile_shape: Tile shape for parallelizing the instance segmentation post-processing. This parameter is independent from the tile shape for computing the embeddings. If not given then post-processing will not be parallelized.
  • halo: Halo for parallel post-processing. See also tile_shape.
  • n_threads: Number of threads for parallel post-processing. See also tile_shape.
Returns:

The segmentation masks.
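A sketch of tuning `generate`, including parallelized post-processing; the thresholds and tile values are example choices, and `segmenter` is assumed to be initialized as above:

```python
# Stricter seeds and a minimal object size.
segmentation = segmenter.generate(
    center_distance_threshold=0.4,
    boundary_distance_threshold=0.6,
    foreground_threshold=0.5,
    min_size=50,
)

# For large images, the watershed post-processing can be parallelized over tiles.
# A halo is required whenever a tile_shape is given.
segmentation = segmenter.generate(tile_shape=(1024, 1024), halo=(128, 128), n_threads=8)
```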

def get_state(self) -> Dict[str, Any]:
1116    def get_state(self) -> Dict[str, Any]:
1117        """Get the initialized state of the instance segmenter.
1118
1119        Returns:
1120            Instance segmentation state.
1121        """
1122        if not self.is_initialized:
1123            raise RuntimeError("The state has not been computed yet. Call initialize first.")
1124
1125        return {
1126            "foreground": self._foreground,
1127            "center_distances": self._center_distances,
1128            "boundary_distances": self._boundary_distances,
1129        }

Get the initialized state of the instance segmenter.

Returns:

Instance segmentation state.

def set_state(self, state: Dict[str, Any]) -> None:
1131    def set_state(self, state: Dict[str, Any]) -> None:
1132        """Set the state of the instance segmenter.
1133
1134        Args:
1135            state: The instance segmentation state
1136        """
1137        self._foreground = state["foreground"]
1138        self._center_distances = state["center_distances"]
1139        self._boundary_distances = state["boundary_distances"]
1140        self._is_initialized = True

Set the state of the instance segmenter.

Arguments:
  • state: The instance segmentation state
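A sketch of caching the decoder predictions via `get_state` / `set_state`, so that `generate` can be re-run with different parameters without recomputing embeddings; the `.npz` file name is hypothetical and `predictor` / `decoder` are assumed from the examples above:

```python
import numpy as np
from micro_sam.instance_segmentation import InstanceSegmentationWithDecoder

# Save the initialized state (all entries are numpy arrays).
state = segmenter.get_state()
np.savez("segmenter_state.npz", **state)

# Later: restore the state into a fresh segmenter and generate again.
loaded = np.load("segmenter_state.npz")
new_segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
new_segmenter.set_state({key: loaded[key] for key in loaded.files})
segmentation = new_segmenter.generate(center_distance_threshold=0.6)
```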
def clear_state(self):
1142    def clear_state(self):
1143        """Clear the state of the instance segmenter.
1144        """
1145        self._foreground = None
1146        self._center_distances = None
1147        self._boundary_distances = None
1148        self._is_initialized = False

Clear the state of the instance segmenter.

class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder):
1151class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder):
1152    """Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings.
1153    """
1154
1155    # Apply the decoder in a batched fashion, and then perform the resizing independently per output.
1156    # This is necessary, because the individual tiles may have different tile shapes due to border tiles.
1157    def _predict_decoder(self, batched_embeddings, input_shapes, original_shapes):
1158        batched_embeddings = torch.cat(batched_embeddings)
1159        output = self._decoder._forward_impl(batched_embeddings)
1160
1161        batched_output = []
1162        for x, input_shape, original_shape in zip(output, input_shapes, original_shapes):
1163            x = self._decoder.postprocess_masks(x.unsqueeze(0), input_shape, original_shape).squeeze(0)
1164            batched_output.append(x.cpu().numpy())
1165        return batched_output
1166
1167    @torch.no_grad()
1168    def initialize(
1169        self,
1170        image: np.ndarray,
1171        image_embeddings: Optional[util.ImageEmbeddings] = None,
1172        i: Optional[int] = None,
1173        tile_shape: Optional[Tuple[int, int]] = None,
1174        halo: Optional[Tuple[int, int]] = None,
1175        verbose: bool = False,
1176        pbar_init: Optional[callable] = None,
1177        pbar_update: Optional[callable] = None,
1178        batch_size: int = 1,
1179        mask: Optional[np.typing.ArrayLike] = None,
1180    ) -> None:
1181        """Initialize image embeddings and decoder predictions for an image.
1182
1183        Args:
1184            image: The input image, volume or timeseries.
1185            image_embeddings: Optional precomputed image embeddings.
1186                See `util.precompute_image_embeddings` for details.
1187            i: Index for the image data. Required if `image` has three spatial dimensions
1188                or a time dimension and two spatial dimensions.
1189            tile_shape: Shape of the tiles for precomputing image embeddings.
1190            halo: Overlap of the tiles for tiled precomputation of image embeddings.
1191            verbose: Whether to print computation progress. By default, set to 'False'.
1192            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
1193                Can be used together with pbar_update to handle a napari progress bar in another thread.
1194                This enables using this function within a threadworker.
1195            pbar_update: Callback to update an external progress bar.
1196            batch_size: The batch size for image embedding computation and segmentation decoder prediction.
1197            mask: An optional mask to define areas that are ignored in the segmentation.
1198        """
1199        original_size = image.shape[:2]
1200        self._image_embeddings, tile_shape, halo, tiles_in_mask = _process_tiled_embeddings(
1201            self._predictor, image, image_embeddings, tile_shape, halo,
1202            verbose=verbose, batch_size=batch_size, mask=mask, i=i,
1203        )
1204        tiling = blocking([0, 0], original_size, tile_shape)
1205
1206        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
1207
1208        foreground = np.zeros(original_size, dtype="float32")
1209        center_distances = np.zeros(original_size, dtype="float32")
1210        boundary_distances = np.zeros(original_size, dtype="float32")
1211
1212        msg = "Initialize tiled instance segmentation with decoder"
1213        if tiles_in_mask is None:
1214            n_tiles = tiling.numberOfBlocks
1215            all_tile_ids = list(range(n_tiles))
1216        else:
1217            n_tiles = len(tiles_in_mask)
1218            all_tile_ids = tiles_in_mask
1219            msg += " and mask"
1220
1221        n_batches = int(np.ceil(n_tiles / batch_size))
1222        pbar_init(n_tiles, msg)
1223        tile_ids_for_batches = np.array_split(all_tile_ids, n_batches)
1224
1225        for tile_ids in tile_ids_for_batches:
1226            batched_embeddings, input_shapes, original_shapes = [], [], []
1227            for tile_id in tile_ids:
1228                # Get the image embeddings from the predictor for this tile.
1229                self._predictor = util.set_precomputed(self._predictor, self._image_embeddings, i=i, tile_id=tile_id)
1230
1231                batched_embeddings.append(self._predictor.features)
1232                input_shapes.append(tuple(self._predictor.input_size))
1233                original_shapes.append(tuple(self._predictor.original_size))
1234
1235            batched_output = self._predict_decoder(batched_embeddings, input_shapes, original_shapes)
1236
1237            for output_id, tile_id in enumerate(tile_ids):
1238                output = batched_output[output_id]
1239                assert output.shape[0] == 3
1240
1241                # Set the predictions in the output for this tile.
1242                block = tiling.getBlockWithHalo(tile_id, halo=list(halo))
1243                local_bb = tuple(
1244                    slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end)
1245                )
1246                inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end))
1247
1248                foreground[inner_bb] = output[0][local_bb]
1249                center_distances[inner_bb] = output[1][local_bb]
1250                boundary_distances[inner_bb] = output[2][local_bb]
1251                pbar_update(1)
1252
1253        pbar_close()
1254
1255        # Set the state.
1256        self._i = i
1257        self._foreground = foreground
1258        self._center_distances = center_distances
1259        self._boundary_distances = boundary_distances
1260        self._is_initialized = True

Same as InstanceSegmentationWithDecoder but for tiled image embeddings.
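A minimal sketch for the tiled variant, mirroring the non-tiled usage; image size, tile shape and halo are example values:

```python
import numpy as np
from micro_sam.instance_segmentation import (
    TiledInstanceSegmentationWithDecoder, get_predictor_and_decoder
)

image = np.random.randint(0, 255, size=(4096, 4096)).astype("uint8")  # example large image
predictor, decoder = get_predictor_and_decoder(model_type="vit_b_lm")  # example model

segmenter = TiledInstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image, tile_shape=(1024, 1024), halo=(256, 256), batch_size=4)
segmentation = segmenter.generate()
```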

@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = False, pbar_init: Optional[callable] = None, pbar_update: Optional[callable] = None, batch_size: int = 1, mask: Optional[numpy.typing.ArrayLike] = None) -> None:
1167    @torch.no_grad()
1168    def initialize(
1169        self,
1170        image: np.ndarray,
1171        image_embeddings: Optional[util.ImageEmbeddings] = None,
1172        i: Optional[int] = None,
1173        tile_shape: Optional[Tuple[int, int]] = None,
1174        halo: Optional[Tuple[int, int]] = None,
1175        verbose: bool = False,
1176        pbar_init: Optional[callable] = None,
1177        pbar_update: Optional[callable] = None,
1178        batch_size: int = 1,
1179        mask: Optional[np.typing.ArrayLike] = None,
1180    ) -> None:
1181        """Initialize image embeddings and decoder predictions for an image.
1182
1183        Args:
1184            image: The input image, volume or timeseries.
1185            image_embeddings: Optional precomputed image embeddings.
1186                See `util.precompute_image_embeddings` for details.
1187            i: Index for the image data. Required if `image` has three spatial dimensions
1188                or a time dimension and two spatial dimensions.
1189            tile_shape: Shape of the tiles for precomputing image embeddings.
1190            halo: Overlap of the tiles for tiled precomputation of image embeddings.
1191            verbose: Dummy input to be compatible with other function signatures. By default, set to 'False'.
1192            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
1193                Can be used together with pbar_update to handle the napari progress bar in another thread.
1194                This enables using this function within a threadworker.
1195            pbar_update: Callback to update an external progress bar.
1196            batch_size: The batch size for image embedding computation and segmentation decoder prediction.
1197            mask: An optional mask to define areas that are ignored in the segmentation.
1198        """
1199        original_size = image.shape[:2]
1200        self._image_embeddings, tile_shape, halo, tiles_in_mask = _process_tiled_embeddings(
1201            self._predictor, image, image_embeddings, tile_shape, halo,
1202            verbose=verbose, batch_size=batch_size, mask=mask, i=i,
1203        )
1204        tiling = blocking([0, 0], original_size, tile_shape)
1205
1206        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
1207
1208        foreground = np.zeros(original_size, dtype="float32")
1209        center_distances = np.zeros(original_size, dtype="float32")
1210        boundary_distances = np.zeros(original_size, dtype="float32")
1211
1212        msg = "Initialize tiled instance segmentation with decoder"
1213        if tiles_in_mask is None:
1214            n_tiles = tiling.numberOfBlocks
1215            all_tile_ids = list(range(n_tiles))
1216        else:
1217            n_tiles = len(tiles_in_mask)
1218            all_tile_ids = tiles_in_mask
1219            msg += " and mask"
1220
1221        n_batches = int(np.ceil(n_tiles / batch_size))
1222        pbar_init(n_tiles, msg)
1223        tile_ids_for_batches = np.array_split(all_tile_ids, n_batches)
1224
1225        for tile_ids in tile_ids_for_batches:
1226            batched_embeddings, input_shapes, original_shapes = [], [], []
1227            for tile_id in tile_ids:
1228                # Get the image embeddings from the predictor for this tile.
1229                self._predictor = util.set_precomputed(self._predictor, self._image_embeddings, i=i, tile_id=tile_id)
1230
1231                batched_embeddings.append(self._predictor.features)
1232                input_shapes.append(tuple(self._predictor.input_size))
1233                original_shapes.append(tuple(self._predictor.original_size))
1234
1235            batched_output = self._predict_decoder(batched_embeddings, input_shapes, original_shapes)
1236
1237            for output_id, tile_id in enumerate(tile_ids):
1238                output = batched_output[output_id]
1239                assert output.shape[0] == 3
1240
1241                # Set the predictions in the output for this tile.
1242                block = tiling.getBlockWithHalo(tile_id, halo=list(halo))
1243                local_bb = tuple(
1244                    slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end)
1245                )
1246                inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end))
1247
1248                foreground[inner_bb] = output[0][local_bb]
1249                center_distances[inner_bb] = output[1][local_bb]
1250                boundary_distances[inner_bb] = output[2][local_bb]
1251                pbar_update(1)
1252
1253        pbar_close()
1254
1255        # Set the state.
1256        self._i = i
1257        self._foreground = foreground
1258        self._center_distances = center_distances
1259        self._boundary_distances = boundary_distances
1260        self._is_initialized = True

Initialize image embeddings and decoder predictions for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • tile_shape: Shape of the tiles for precomputing image embeddings.
  • halo: Overlap of the tiles for tiled precomputation of image embeddings.
  • verbose: Dummy input to be compatible with other function signatures. By default, set to 'False'.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
  • batch_size: The batch size for image embedding computation and segmentation decoder prediction.
  • mask: An optional mask to define areas that are ignored in the segmentation.
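
A minimal usage sketch (not part of the original docstring): it assumes that a matching SAM predictor and segmentation decoder have already been loaded, e.g. with a helper such as get_predictor_and_decoder; the helper call, model type, and checkpoint handling are illustrative and may differ in your setup.

    import numpy as np
    from micro_sam.instance_segmentation import (
        TiledInstanceSegmentationWithDecoder, get_predictor_and_decoder,
    )

    # Illustrative: load a SAM predictor and the matching segmentation decoder.
    predictor, decoder = get_predictor_and_decoder(model_type="vit_b_lm", checkpoint_path=None)

    # Placeholder for a large 2D image that requires tiled embeddings.
    image = np.random.rand(2048, 2048).astype("float32")

    # Compute tiled embeddings and decoder predictions, then run the segmentation.
    ais = TiledInstanceSegmentationWithDecoder(predictor, decoder)
    ais.initialize(image, tile_shape=(512, 512), halo=(64, 64), verbose=True)
    segmentation = ais.generate()
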
class AutomaticPromptGenerator(InstanceSegmentationWithDecoder):
1335class AutomaticPromptGenerator(InstanceSegmentationWithDecoder):
1336    """Generates an instance segmentation automatically, using automatically generated prompts from a decoder.
1337
1338    This class is used in the same way as `InstanceSegmentationWithDecoder` and `AutomaticMaskGenerator`.
1339
1340    Args:
1341        predictor: The segment anything predictor.
1342        decoder: The decoder to derive prompts for automatic instance segmentation.
1343    """
1344    def generate(
1345        self,
1346        min_size: int = 25,
1347        center_distance_threshold: float = 0.5,
1348        boundary_distance_threshold: float = 0.5,
1349        foreground_threshold: float = 0.5,
1350        multimasking: bool = False,
1351        batch_size: int = 32,
1352        nms_threshold: float = 0.9,
1353        intersection_over_min: bool = False,
1354        output_mode: str = "instance_segmentation",
1355        mask_threshold: Optional[Union[float, str]] = None,
1356        refine_with_box_prompts: bool = False,
1357        prompt_function: Optional[callable] = None,
1358    ) -> Union[List[Dict[str, Any]], np.ndarray]:
1359        """Generate the instance segmentation for the currently initialized image.
1360
1361        The instance segmentation is generated by deriving prompts from the foreground and
1362        distance predictions of the segmentation decoder by thresholding these predictions,
1363        intersecting them, computing connected components, and then using the components'
1364        centers as point prompts. The masks are then filtered via NMS and merged into a segmentation.
1365
1366        Args:
1367            min_size: Minimal object size in the segmentation result. By default, set to '25'.
1368            center_distance_threshold: The threshold for the center distance predictions.
1369            boundary_distance_threshold: The threshold for the boundary distance predictions.
1370            multimasking: Whether to use multi-mask prediction for turning the prompts into masks.
1371            batch_size: The batch size for parallelizing the prediction based on prompts.
1372            nms_threshold: The threshold for non-maximum suppression (NMS).
1373            intersection_over_min: Whether to use the minimum area of the two objects or the
1374                intersection over union area (default) in NMS.
1375            output_mode: The form masks are returned in. Possible values are:
1376                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
1377                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
1378                By default, set to 'instance_segmentation'.
1379            mask_threshold: The threshold for turning logits into masks in `micro_sam.inference.batched_inference`.
1380            refine_with_box_prompts: Whether to refine the mask outputs with another round of box prompts
1381                derived from the segmentations after point prompts.
1382            prompt_function: A custom function for deriving prompts from the segmentation decoder predictions.
1383                If given, the default prompt derivation procedure is not used. Must have the following signature:
1384                ```
1385                    def prompt_function(foreground, center_distances, boundary_distances, **kwargs)
1386                ```
1387                where `foreground`, `center_distances`, and `boundary_distances` are the respective
1388                predictions from the segmentation decoder. It must return a dictionary containing
1389                either point, box, or mask prompts in a format compatible with `micro_sam.inference.batched_inference`.
1390
1391        Returns:
1392            The instance segmentation masks.
1393        """
1394        if not self.is_initialized:
1395            raise RuntimeError("AutomaticPromptGenerator has not been initialized. Call initialize first.")
1396        foreground, center_distances, boundary_distances =\
1397            self._foreground, self._center_distances, self._boundary_distances
1398
1399        # 1.) Derive prompts from the decoder predictions.
1400        prompt_function = _derive_point_prompts if prompt_function is None else prompt_function
1401        prompts = prompt_function(
1402            foreground=foreground,
1403            center_distances=center_distances,
1404            boundary_distances=boundary_distances,
1405            foreground_threshold=foreground_threshold,
1406            center_distance_threshold=center_distance_threshold,
1407            boundary_distance_threshold=boundary_distance_threshold,
1408        )
1409
1410        # 2.) Apply the predictor to the prompts.
1411        if prompts is None:  # No prompts were derived, we can't do much further and return empty masks.
1412            return np.zeros(foreground.shape, dtype="uint32") if output_mode == "instance_segmentation" else []
1413        else:
1414            predictions = batched_inference(
1415                self._predictor,
1416                image=None,
1417                batch_size=batch_size,
1418                return_instance_segmentation=False,
1419                multimasking=multimasking,
1420                mask_threshold=mask_threshold,
1421                i=getattr(self, "_i", None),
1422                **prompts,
1423            )
1424
1425        # 3.) Refine the segmentation with box prompts.
1426        if refine_with_box_prompts:
1427            box_extension = 0.01  # expose as hyperparam?
1428            prompts = _derive_box_prompts(predictions, box_extension)
1429            predictions = batched_inference(
1430                self._predictor,
1431                image=None,
1432                batch_size=batch_size,
1433                return_instance_segmentation=False,
1434                multimasking=multimasking,
1435                mask_threshold=mask_threshold,
1436                i=getattr(self, "_i", None),
1437                **prompts,
1438            )
1439
1440        # 4.) Apply non-max suppression to the masks.
1441        segmentation = util.apply_nms(
1442            predictions, min_size=min_size, nms_thresh=nms_threshold, intersection_over_min=intersection_over_min
1443        )
1444        if output_mode != "instance_segmentation":
1445            segmentation = self._to_masks(segmentation, output_mode)
1446        return segmentation

Generates an instance segmentation automatically, using automatically generated prompts from a decoder.

This class is used in the same way as InstanceSegmentationWithDecoder and AutomaticMaskGenerator.

Arguments:
  • predictor: The segment anything predictor.
  • decoder: The decoder to derive prompts for automatic instance segmentation.
def generate( self, min_size: int = 25, center_distance_threshold: float = 0.5, boundary_distance_threshold: float = 0.5, foreground_threshold: float = 0.5, multimasking: bool = False, batch_size: int = 32, nms_threshold: float = 0.9, intersection_over_min: bool = False, output_mode: str = 'instance_segmentation', mask_threshold: Union[float, str, NoneType] = None, refine_with_box_prompts: bool = False, prompt_function: Optional[callable] = None) -> Union[List[Dict[str, Any]], numpy.ndarray]:
1344    def generate(
1345        self,
1346        min_size: int = 25,
1347        center_distance_threshold: float = 0.5,
1348        boundary_distance_threshold: float = 0.5,
1349        foreground_threshold: float = 0.5,
1350        multimasking: bool = False,
1351        batch_size: int = 32,
1352        nms_threshold: float = 0.9,
1353        intersection_over_min: bool = False,
1354        output_mode: str = "instance_segmentation",
1355        mask_threshold: Optional[Union[float, str]] = None,
1356        refine_with_box_prompts: bool = False,
1357        prompt_function: Optional[callable] = None,
1358    ) -> Union[List[Dict[str, Any]], np.ndarray]:
1359        """Generate the instance segmentation for the currently initialized image.
1360
1361        The instance segmentation is generated by deriving prompts from the foreground and
1362        distance predictions of the segmentation decoder by thresholding these predictions,
1363        intersecting them, computing connected components, and then using the components'
1364        centers as point prompts. The masks are then filtered via NMS and merged into a segmentation.
1365
1366        Args:
1367            min_size: Minimal object size in the segmentation result. By default, set to '25'.
1368            center_distance_threshold: The threshold for the center distance predictions.
1369            boundary_distance_threshold: The threshold for the boundary distance predictions.
1370            multimasking: Whether to use multi-mask prediction for turning the prompts into masks.
1371            batch_size: The batch size for parallelizing the prediction based on prompts.
1372            nms_threshold: The threshold for non-maximum suppression (NMS).
1373            intersection_over_min: Whether to use the minimum area of the two objects or the
1374                intersection over union area (default) in NMS.
1375            output_mode: The form masks are returned in. Possible values are:
1376                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
1377                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
1378                By default, set to 'instance_segmentation'.
1379            mask_threshold: The threshold for turning logits into masks in `micro_sam.inference.batched_inference`.
1380            refine_with_box_prompts: Whether to refine the mask outputs with another round of box prompts
1381                derived from the segmentations after point prompts.
1382            prompt_function: A custom function for deriving prompts from the segmentation decoder predictions.
1383                If given, the default prompt derivation procedure is not used. Must have the following signature:
1384                ```
1385                    def prompt_function(foreground, center_distances, boundary_distances, **kwargs)
1386                ```
1387                where `foreground`, `center_distances`, and `boundary_distances` are the respective
1388                predictions from the segmentation decoder. It must return a dictionary containing
1389                either point, box, or mask prompts in a format compatible with `micro_sam.inference.batched_inference`.
1390
1391        Returns:
1392            The instance segmentation masks.
1393        """
1394        if not self.is_initialized:
1395            raise RuntimeError("AutomaticPromptGenerator has not been initialized. Call initialize first.")
1396        foreground, center_distances, boundary_distances =\
1397            self._foreground, self._center_distances, self._boundary_distances
1398
1399        # 1.) Derive prompts from the decoder predictions.
1400        prompt_function = _derive_point_prompts if prompt_function is None else prompt_function
1401        prompts = prompt_function(
1402            foreground=foreground,
1403            center_distances=center_distances,
1404            boundary_distances=boundary_distances,
1405            foreground_threshold=foreground_threshold,
1406            center_distance_threshold=center_distance_threshold,
1407            boundary_distance_threshold=boundary_distance_threshold,
1408        )
1409
1410        # 2.) Apply the predictor to the prompts.
1411        if prompts is None:  # No prompts were derived, we can't do much further and return empty masks.
1412            return np.zeros(foreground.shape, dtype="uint32") if output_mode == "instance_segmentation" else []
1413        else:
1414            predictions = batched_inference(
1415                self._predictor,
1416                image=None,
1417                batch_size=batch_size,
1418                return_instance_segmentation=False,
1419                multimasking=multimasking,
1420                mask_threshold=mask_threshold,
1421                i=getattr(self, "_i", None),
1422                **prompts,
1423            )
1424
1425        # 3.) Refine the segmentation with box prompts.
1426        if refine_with_box_prompts:
1427            box_extension = 0.01  # expose as hyperparam?
1428            prompts = _derive_box_prompts(predictions, box_extension)
1429            predictions = batched_inference(
1430                self._predictor,
1431                image=None,
1432                batch_size=batch_size,
1433                return_instance_segmentation=False,
1434                multimasking=multimasking,
1435                mask_threshold=mask_threshold,
1436                i=getattr(self, "_i", None),
1437                **prompts,
1438            )
1439
1440        # 4.) Apply non-max suppression to the masks.
1441        segmentation = util.apply_nms(
1442            predictions, min_size=min_size, nms_thresh=nms_threshold, intersection_over_min=intersection_over_min
1443        )
1444        if output_mode != "instance_segmentation":
1445            segmentation = self._to_masks(segmentation, output_mode)
1446        return segmentation

Generate the instance segmentation for the currently initialized image.

The instance segmentation is generated by deriving prompts from the foreground and distance predictions of the segmentation decoder by thresholding these predictions, intersecting them, computing connected components, and then using the components' centers as point prompts. The masks are then filtered via NMS and merged into a segmentation.

Arguments:
  • min_size: Minimal object size in the segmentation result. By default, set to '25'.
  • center_distance_threshold: The threshold for the center distance predictions.
  • boundary_distance_threshold: The threshold for the boundary distance predictions.
  • multimasking: Whether to use multi-mask prediction for turning the prompts into masks.
  • batch_size: The batch size for parallelizing the prediction based on prompts.
  • nms_threshold: The threshold for non-maximum suppression (NMS).
  • intersection_over_min: Whether to use the minimum area of the two objects or the intersection over union area (default) in NMS.
  • output_mode: The form masks are returned in. Possible values are:
    • 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
    • 'instance_segmentation': Return masks merged into an instance segmentation in a single array. By default, set to 'instance_segmentation'.
  • mask_threshold: The threshold for turning logits into masks in micro_sam.inference.batched_inference.
  • refine_with_box_prompts: Whether to refine the mask outputs with another round of box prompts derived from the segmentations after point prompts.
  • prompt_function: A custom function for deriving prompts from the segmentation decoder predictions. If given, the default prompt derivation procedure is not used. Must have the following signature:
        def prompt_function(foreground, center_distances, boundary_distances, **kwargs)
    
    where foreground, center_distances, and boundary_distances are the respective predictions from the segmentation decoder. It must return a dictionary containing either point, box, or mask prompts in a format compatible with micro_sam.inference.batched_inference.
Returns:

The instance segmentation masks.
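
A usage sketch for this workflow (illustrative, not from the original documentation): the helper used to load the predictor and decoder is an assumption, and a checkpoint may be required in practice. A custom prompt_function with the signature shown above can be passed to replace the default point-prompt derivation.

    import numpy as np
    from micro_sam.instance_segmentation import AutomaticPromptGenerator, get_predictor_and_decoder

    # Illustrative helper call; the model type and checkpoint handling may differ in your setup.
    predictor, decoder = get_predictor_and_decoder(model_type="vit_b_lm", checkpoint_path=None)

    image = np.random.rand(512, 512).astype("float32")  # placeholder 2D image

    apg = AutomaticPromptGenerator(predictor, decoder)
    apg.initialize(image)

    # Threshold the decoder predictions, derive point prompts, predict masks, and merge them via NMS.
    segmentation = apg.generate(
        center_distance_threshold=0.5,
        boundary_distance_threshold=0.5,
        foreground_threshold=0.5,
        output_mode="instance_segmentation",
    )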

class TiledAutomaticPromptGenerator(TiledInstanceSegmentationWithDecoder):
1449class TiledAutomaticPromptGenerator(TiledInstanceSegmentationWithDecoder):
1450    """Same as `AutomaticPromptGenerator` but for tiled image embeddings.
1451    """
1452    def generate(
1453        self,
1454        min_size: int = 25,
1455        center_distance_threshold: float = 0.5,
1456        boundary_distance_threshold: float = 0.5,
1457        foreground_threshold: float = 0.5,
1458        multimasking: bool = False,
1459        batch_size: int = 32,
1460        nms_threshold: float = 0.9,
1461        intersection_over_min: bool = False,
1462        output_mode: str = "instance_segmentation",
1463        mask_threshold: Optional[Union[float, str]] = None,
1464        refine_with_box_prompts: bool = False,
1465        prompt_function: Optional[callable] = None,
1466        optimize_memory: bool = False,
1467    ) -> List[Dict[str, Any]]:
1468        """Generate tiling-based instance segmentation for the currently initialized image.
1469
1470        Args:
1471            min_size: Minimal object size in the segmentation result. By default, set to '25'.
1472            center_distance_threshold: The threshold for the center distance predictions.
1473            boundary_distance_threshold: The threshold for the boundary distance predictions.
1474            multimasking: Whether to use multi-mask prediction for turning the prompts into masks.
1475            batch_size: The batch size for parallelizing the prediction based on prompts.
1476            nms_threshold: The threshold for non-maximum suppression (NMS).
1477            intersection_over_min: Whether to use the minimum area of the two objects or the
1478                intersection over union area (default) in NMS.
1479            output_mode: The form masks are returned in. Possible values are:
1480                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
1481                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
1482                By default, set to 'instance_segmentation'.
1483            mask_threshold: The threshold for turning logits into masks in `micro_sam.inference.batched_inference`.
1484            refine_with_box_prompts: Whether to refine the mask outputs with another round of box prompts
1485                derived from the segmentations after point prompts. Currently not supported for tiled segmentation.
1486            prompt_function: A custom function for deriving prompts from the segmentation decoder predictions.
1487                If given, the default prompt derivation procedure is not used. Must have the following signature:
1488                ```
1489                    def prompt_function(foreground, center_distances, boundary_distances, **kwargs)
1490                ```
1491                where `foreground`, `center_distances`, and `boundary_distances` are the respective
1492                predictions from the segmentation decoder. It must return a dictionary containing
1493                either point, box, or mask prompts in a format compatible with `micro_sam.inference.batched_inference`.
1494            optimize_memory: Whether to optimize the memory consumption by merging the per-slice
1495                segmentation results immediately with NMS, rather than running NMS for all results.
1496                This may lead to a slightly different segmentation result and is only compatible with
1497                `refine_with_box_prompts=False` and `output_mode="instance_segmentation"`.
1498
1499        Returns:
1500            The instance segmentation masks.
1501        """
1502        if not self.is_initialized:
1503            raise RuntimeError("TiledAutomaticPromptGenerator has not been initialized. Call initialize first.")
1504        if optimize_memory and (output_mode != "instance_segmentation" or refine_with_box_prompts):
1505            raise ValueError("optimize_memory=True requires output_mode='instance_segmentation' and refine_with_box_prompts=False.")
1506        foreground, center_distances, boundary_distances =\
1507            self._foreground, self._center_distances, self._boundary_distances
1508
1509        # 1.) Derive prompts from the decoder predictions.
1510        prompt_function = _derive_point_prompts if prompt_function is None else prompt_function
1511        prompts = prompt_function(
1512            foreground,
1513            center_distances,
1514            boundary_distances,
1515            foreground_threshold=foreground_threshold,
1516            center_distance_threshold=center_distance_threshold,
1517            boundary_distance_threshold=boundary_distance_threshold,
1518        )
1519
1520        # 2.) Apply the predictor to the prompts.
1521        shape = foreground.shape
1522        if prompts is None:  # No prompts were derived, we can't do much further and return empty masks.
1523            return np.zeros(shape, dtype="uint32") if output_mode == "instance_segmentation" else []
1524        else:
1525            if optimize_memory:
1526                prompts.update(dict(
1527                    min_size=min_size, nms_thresh=nms_threshold, intersection_over_min=intersection_over_min
1528                ))
1529            predictions = batched_tiled_inference(
1530                self._predictor,
1531                image=None,
1532                batch_size=batch_size,
1533                image_embeddings=self._image_embeddings,
1534                return_instance_segmentation=False,
1535                multimasking=multimasking,
1536                optimize_memory=optimize_memory,
1537                i=getattr(self, "_i", None),
1538                **prompts
1539            )
1540        # Optimize memory directly returns an instance segmentation and does not
1541        # allow for any further refinements.
1542        if optimize_memory:
1543            return predictions
1544
1545        # 3.) Refine the segmentation with box prompts.
1546        if refine_with_box_prompts:
1547            # TODO
1548            raise NotImplementedError
1549
1550        # 4.) Apply non-max suppression to the masks.
1551        segmentation = util.apply_nms(
1552            predictions, shape=shape, min_size=min_size, nms_thresh=nms_threshold,
1553            intersection_over_min=intersection_over_min,
1554        )
1555        if output_mode != "instance_segmentation":
1556            segmentation = self._to_masks(segmentation, output_mode)
1557        return segmentation
1558
1559    # Set state and get state are not implemented yet, as this generator relies on having the image embeddings
1560    # in the state. However, they should not be serialized here and we have to address this a bit differently.
1561    def get_state(self):
1562        """@private
1563        """
1564        raise NotImplementedError
1565
1566    def set_state(self, state):
1567        """@private
1568        """
1569        raise NotImplementedError

Same as AutomaticPromptGenerator but for tiled image embeddings.

def generate( self, min_size: int = 25, center_distance_threshold: float = 0.5, boundary_distance_threshold: float = 0.5, foreground_threshold: float = 0.5, multimasking: bool = False, batch_size: int = 32, nms_threshold: float = 0.9, intersection_over_min: bool = False, output_mode: str = 'instance_segmentation', mask_threshold: Union[float, str, NoneType] = None, refine_with_box_prompts: bool = False, prompt_function: Optional[callable] = None, optimize_memory: bool = False) -> List[Dict[str, Any]]:
1452    def generate(
1453        self,
1454        min_size: int = 25,
1455        center_distance_threshold: float = 0.5,
1456        boundary_distance_threshold: float = 0.5,
1457        foreground_threshold: float = 0.5,
1458        multimasking: bool = False,
1459        batch_size: int = 32,
1460        nms_threshold: float = 0.9,
1461        intersection_over_min: bool = False,
1462        output_mode: str = "instance_segmentation",
1463        mask_threshold: Optional[Union[float, str]] = None,
1464        refine_with_box_prompts: bool = False,
1465        prompt_function: Optional[callable] = None,
1466        optimize_memory: bool = False,
1467    ) -> List[Dict[str, Any]]:
1468        """Generate tiling-based instance segmentation for the currently initialized image.
1469
1470        Args:
1471            min_size: Minimal object size in the segmentation result. By default, set to '25'.
1472            center_distance_threshold: The threshold for the center distance predictions.
1473            boundary_distance_threshold: The threshold for the boundary distance predictions.
1474            multimasking: Whether to use multi-mask prediction for turning the prompts into masks.
1475            batch_size: The batch size for parallelizing the prediction based on prompts.
1476            nms_threshold: The threshold for non-maximum suppression (NMS).
1477            intersection_over_min: Whether to use the minimum area of the two objects or the
1478                intersection over union area (default) in NMS.
1479            output_mode: The form masks are returned in. Possible values are:
1480                - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
1481                - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
1482                By default, set to 'instance_segmentation'.
1483            mask_threshold: The threshold for turning logits into masks in `micro_sam.inference.batched_inference`.
1484            refine_with_box_prompts: Whether to refine the mask outputs with another round of box prompts
1485                derived from the segmentations after point prompts. Currently not supported for tiled segmentation.
1486            prompt_function: A custom function for deriving prompts from the segmentation decoder predictions.
1487                If given, the default prompt derivation procedure is not used. Must have the following signature:
1488                ```
1489                    def prompt_function(foreground, center_distances, boundary_distances, **kwargs)
1490                ```
1491                where `foreground`, `center_distances`, and `boundary_distances` are the respective
1492                predictions from the segmentation decoder. It must return a dictionary containing
1493                either point, box, or mask prompts in a format compatible with `micro_sam.inference.batched_inference`.
1494            optimize_memory: Whether to optimize the memory consumption by merging the per-slice
1495                segmentation results immediately with NMS, rather than running NMS for all results.
1496                This may lead to a slightly different segmentation result and is only compatible with
1497                `refine_with_box_prompts=False` and `output_mode="instance_segmentation"`.
1498
1499        Returns:
1500            The instance segmentation masks.
1501        """
1502        if not self.is_initialized:
1503            raise RuntimeError("TiledAutomaticPromptGenerator has not been initialized. Call initialize first.")
1504        if optimize_memory and (output_mode != "instance_segmentation" or refine_with_box_prompts):
1505            raise ValueError("optimize_memory=True requires output_mode='instance_segmentation' and refine_with_box_prompts=False.")
1506        foreground, center_distances, boundary_distances =\
1507            self._foreground, self._center_distances, self._boundary_distances
1508
1509        # 1.) Derive prompts from the decoder predictions.
1510        prompt_function = _derive_point_prompts if prompt_function is None else prompt_function
1511        prompts = prompt_function(
1512            foreground,
1513            center_distances,
1514            boundary_distances,
1515            foreground_threshold=foreground_threshold,
1516            center_distance_threshold=center_distance_threshold,
1517            boundary_distance_threshold=boundary_distance_threshold,
1518        )
1519
1520        # 2.) Apply the predictor to the prompts.
1521        shape = foreground.shape
1522        if prompts is None:  # No prompts were derived, we can't do much further and return empty masks.
1523            return np.zeros(shape, dtype="uint32") if output_mode == "instance_segmentation" else []
1524        else:
1525            if optimize_memory:
1526                prompts.update(dict(
1527                    min_size=min_size, nms_thresh=nms_threshold, intersection_over_min=intersection_over_min
1528                ))
1529            predictions = batched_tiled_inference(
1530                self._predictor,
1531                image=None,
1532                batch_size=batch_size,
1533                image_embeddings=self._image_embeddings,
1534                return_instance_segmentation=False,
1535                multimasking=multimasking,
1536                optimize_memory=optimize_memory,
1537                i=getattr(self, "_i", None),
1538                **prompts
1539            )
1540        # Optimize memory directly returns an instance segmentation and does not
1541        # allow for any further refinements.
1542        if optimize_memory:
1543            return predictions
1544
1545        # 3.) Refine the segmentation with box prompts.
1546        if refine_with_box_prompts:
1547            # TODO
1548            raise NotImplementedError
1549
1550        # 4.) Apply non-max suppression to the masks.
1551        segmentation = util.apply_nms(
1552            predictions, shape=shape, min_size=min_size, nms_thresh=nms_threshold,
1553            intersection_over_min=intersection_over_min,
1554        )
1555        if output_mode != "instance_segmentation":
1556            segmentation = self._to_masks(segmentation, output_mode)
1557        return segmentation

Generate tiling-based instance segmentation for the currently initialized image.

Arguments:
  • min_size: Minimal object size in the segmentation result. By default, set to '25'.
  • center_distance_threshold: The threshold for the center distance predictions.
  • boundary_distance_threshold: The threshold for the boundary distance predictions.
  • multimasking: Whether to use multi-mask prediction for turning the prompts into masks.
  • batch_size: The batch size for parallelizing the prediction based on prompts.
  • nms_threshold: The threshold for non-maximum suppression (NMS).
  • intersection_over_min: Whether to use the minimum area of the two objects or the intersection over union area (default) in NMS.
  • output_mode: The form masks are returned in. Possible values are:
    • 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
    • 'instance_segmentation': Return masks merged into an instance segmentation in a single array. By default, set to 'instance_segmentation'.
  • mask_threshold: The threshold for turning logits into masks in micro_sam.inference.batched_inference.
  • refine_with_box_prompts: Whether to refine the mask outputs with another round of box prompts derived from the segmentations after point prompts. Currently not supported for tiled segmentation.
  • prompt_function: A custom function for deriving prompts from the segmentation decoder predictions. If given, the default prompt derivation procedure is not used. Must have the following signature:
        def prompt_function(foreground, center_distances, boundary_distances, **kwargs)
    
    where foreground, center_distances, and boundary_distances are the respective predictions from the segmentation decoder. It must return a dictionary containing either point, box, or mask prompts in a format compatible with micro_sam.inference.batched_inference.
  • optimize_memory: Whether to optimize the memory consumption by merging the per-slice segmentation results immediately with NMS, rather than running NMS for all results. This may lead to a slightly different segmentation result and is only compatible with refine_with_box_prompts=False and output_mode="instance_segmentation".
Returns:

The instance segmentation masks.
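
A usage sketch for the tiled variant (illustrative; the helper for loading the predictor and decoder, the model type, and the tile sizes are assumptions): tiled embeddings are computed in initialize, and optimize_memory=True merges results with NMS on the fly and returns the instance segmentation directly.

    import numpy as np
    from micro_sam.instance_segmentation import TiledAutomaticPromptGenerator, get_predictor_and_decoder

    # Illustrative helper call; adjust to your installation.
    predictor, decoder = get_predictor_and_decoder(model_type="vit_b_lm", checkpoint_path=None)

    # Placeholder for a large 2D image that requires tiled embeddings.
    image = np.random.rand(4096, 4096).astype("float32")

    apg = TiledAutomaticPromptGenerator(predictor, decoder)
    apg.initialize(image, tile_shape=(1024, 1024), halo=(128, 128), batch_size=4)

    # optimize_memory=True is only valid with output_mode='instance_segmentation'
    # and refine_with_box_prompts=False (the defaults).
    segmentation = apg.generate(optimize_memory=True)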

def get_instance_segmentation_generator( predictor: segment_anything.predictor.SamPredictor, is_tiled: bool, decoder: Optional[torch.nn.modules.module.Module] = None, segmentation_mode: Optional[str] = None, **kwargs) -> Union[AMGBase, InstanceSegmentationWithDecoder]:
1572def get_instance_segmentation_generator(
1573    predictor: SamPredictor,
1574    is_tiled: bool,
1575    decoder: Optional[torch.nn.Module] = None,
1576    segmentation_mode: Optional[str] = None,
1577    **kwargs,
1578) -> Union[AMGBase, InstanceSegmentationWithDecoder]:
1579    f"""Get the automatic instance segmentation generator.
1580
1581    Args:
1582        predictor: The segment anything predictor.
1583        is_tiled: Whether tiled embeddings are used.
1584        decoder: Decoder to predict instance segmentation.
1585        segmentation_mode: The segmentation mode. One of 'amg', 'ais', or 'apg'.
1586            By default, '{DEFAULT_SEGMENTATION_MODE_WITH_DECODER}' is used if a decoder is passed,
1587            otherwise 'amg' is used.
1588        kwargs: The keyword arguments of the segmentation generator class.
1589
1590    Returns:
1591        The segmentation generator instance.
1592    """
1593    # Choose the segmentation decoder default depending on whether we have a decoder.
1594    if segmentation_mode is None:
1595        segmentation_mode = "amg" if decoder is None else DEFAULT_SEGMENTATION_MODE_WITH_DECODER
1596
1597    if segmentation_mode.lower() == "amg":
1598        segmenter_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator
1599        segmenter = segmenter_class(predictor, **kwargs)
1600    elif segmentation_mode.lower() == "ais":
1601        assert decoder is not None
1602        segmenter_class = TiledInstanceSegmentationWithDecoder if is_tiled else InstanceSegmentationWithDecoder
1603        segmenter = segmenter_class(predictor, decoder, **kwargs)
1604    elif segmentation_mode.lower() == "apg":
1605        assert decoder is not None
1606        segmenter_class = TiledAutomaticPromptGenerator if is_tiled else AutomaticPromptGenerator
1607        segmenter = segmenter_class(predictor, decoder, **kwargs)
1608    else:
1609        raise ValueError(f"Invalid segmentation_mode: {segmentation_mode}. Choose one of 'amg', 'ais', or 'apg'.")
1610
1611    return segmenter
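
A short sketch of how this factory function can be used (the helper for loading the predictor and decoder is an assumption, as above). Without a decoder the mode defaults to 'amg'; with a decoder it defaults to DEFAULT_SEGMENTATION_MODE_WITH_DECODER ('ais').

    import numpy as np
    from micro_sam.instance_segmentation import get_instance_segmentation_generator, get_predictor_and_decoder

    # Illustrative helper call; the model type and checkpoint handling may differ in your setup.
    predictor, decoder = get_predictor_and_decoder(model_type="vit_b_lm", checkpoint_path=None)

    # Explicitly request the automatic prompt generator for non-tiled embeddings.
    segmenter = get_instance_segmentation_generator(predictor, is_tiled=False, decoder=decoder, segmentation_mode="apg")

    image = np.random.rand(512, 512).astype("float32")  # placeholder 2D image
    segmenter.initialize(image)
    segmentation = segmenter.generate()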