micro_sam.instance_segmentation

Automated instance segmentation functionality. The classes implemented here extend the automatic instance segmentation from Segment Anything: https://computational-cell-analytics.github.io/micro-sam/micro_sam.html
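
The typical workflow combines the pieces implemented below: load a Segment Anything predictor, initialize a mask generator on an image (the expensive step), generate masks, and convert them to a label image. A minimal sketch, assuming a 2D `image` array and the downloadable `vit_b` model (the random placeholder image only serves to make the snippet runnable):

```python
import numpy as np

from micro_sam import util
from micro_sam.instance_segmentation import AutomaticMaskGenerator, mask_data_to_segmentation

# Placeholder input; in practice load a real 2D (or RGB) image instead.
image = np.random.randint(0, 255, (512, 512), dtype="uint8")

# Load the Segment Anything predictor for the chosen model type.
predictor = util.get_sam_model(model_type="vit_b")

# Run the expensive computations: image embeddings and mask prediction for the point grid prompts.
amg = AutomaticMaskGenerator(predictor)
amg.initialize(image)

# Generate the masks; this step is cheap, so post-processing parameters can be explored interactively.
masks = amg.generate(pred_iou_thresh=0.88)

# Convert the mask data to an instance segmentation (a label image).
segmentation = mask_data_to_segmentation(masks, with_background=True, min_object_size=50)
```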

   1"""
   2Automated instance segmentation functionality.
   3The classes implemented here extend the automatic instance segmentation from Segment Anything:
   4https://computational-cell-analytics.github.io/micro-sam/micro_sam.html
   5"""
   6
   7import os
   8from abc import ABC
   9from copy import deepcopy
  10from collections import OrderedDict
  11from typing import Any, Dict, List, Optional, Tuple, Union
  12
  13import vigra
  14import numpy as np
  15import torch
  16
  17from skimage.measure import label, regionprops
  18from skimage.segmentation import relabel_sequential
  19from torchvision.ops.boxes import batched_nms, box_area
  20
  21from torch_em.model import UNETR
  22from torch_em.util.segmentation import watershed_from_center_and_boundary_distances
  23
  24from nifty.tools import blocking
  25
  26import segment_anything.utils.amg as amg_utils
  27from segment_anything.predictor import SamPredictor
  28
  29from . import util
  30from ._vendored import batched_mask_to_box, mask_to_rle_pytorch
  31
  32#
  33# Utility Functionality
  34#
  35
  36
  37class _FakeInput:
  38    def __init__(self, shape):
  39        self.shape = shape
  40
  41    def __getitem__(self, index):
  42        block_shape = tuple(ind.stop - ind.start for ind in index)
  43        return np.zeros(block_shape, dtype="float32")
  44
  45
  46def mask_data_to_segmentation(
  47    masks: List[Dict[str, Any]],
  48    with_background: bool,
  49    min_object_size: int = 0,
  50    max_object_size: Optional[int] = None,
  51    label_masks: bool = True,
  52) -> np.ndarray:
  53    """Convert the output of the automatic mask generation to an instance segmentation.
  54
  55    Args:
  56        masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator.
  57            Only supports output_mode=binary_mask.
   58        with_background: Whether the segmentation has background. If yes, this function ensures that the largest
  59            object in the output will be mapped to zero (the background value).
  60        min_object_size: The minimal size of an object in pixels.
  61        max_object_size: The maximal size of an object in pixels.
  62        label_masks: Whether to apply connected components to the result before removing small objects.
  63
  64    Returns:
  65        The instance segmentation.
  66    """
  67
  68    masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
  69    # we could also get the shape from the crop box
  70    shape = next(iter(masks))["segmentation"].shape
  71    segmentation = np.zeros(shape, dtype="uint32")
  72
  73    def require_numpy(mask):
  74        return mask.cpu().numpy() if torch.is_tensor(mask) else mask
  75
  76    seg_id = 1
  77    for mask in masks:
  78        if mask["area"] < min_object_size:
  79            continue
  80        if max_object_size is not None and mask["area"] > max_object_size:
  81            continue
  82
  83        this_seg_id = mask.get("seg_id", seg_id)
  84        segmentation[require_numpy(mask["segmentation"])] = this_seg_id
  85        seg_id = this_seg_id + 1
  86
  87    if label_masks:
  88        segmentation = label(segmentation).astype(segmentation.dtype)
  89
  90    seg_ids, sizes = np.unique(segmentation, return_counts=True)
  91
   92    # In some cases objects may be smaller than previously calculated,
  93    # since they are covered by other objects. We ensure these also get
  94    # filtered out here.
  95    filter_ids = seg_ids[sizes < min_object_size]
  96
  97    # If we run segmentation with background we also map the largest segment
  98    # (the most likely background object) to zero. This is often zero already,
  99    # but it does not hurt to reset that to zero either.
 100    if with_background:
 101        bg_id = seg_ids[np.argmax(sizes)]
 102        filter_ids = np.concatenate([filter_ids, [bg_id]])
 103
 104    segmentation[np.isin(segmentation, filter_ids)] = 0
 105    segmentation = relabel_sequential(segmentation)[0]
 106
 107    return segmentation
 108
 109
 110#
 111# Classes for automatic instance segmentation
 112#
 113
 114
 115class AMGBase(ABC):
 116    """Base class for the automatic mask generators.
 117    """
 118    def __init__(self):
 119        # the state that has to be computed by the 'initialize' method of the child classes
 120        self._is_initialized = False
 121        self._crop_list = None
 122        self._crop_boxes = None
 123        self._original_size = None
 124
 125    @property
 126    def is_initialized(self):
 127        """Whether the mask generator has already been initialized.
 128        """
 129        return self._is_initialized
 130
 131    @property
 132    def crop_list(self):
 133        """The list of mask data after initialization.
 134        """
 135        return self._crop_list
 136
 137    @property
 138    def crop_boxes(self):
 139        """The list of crop boxes.
 140        """
 141        return self._crop_boxes
 142
 143    @property
 144    def original_size(self):
 145        """The original image size.
 146        """
 147        return self._original_size
 148
 149    def _postprocess_batch(
 150        self,
 151        data,
 152        crop_box,
 153        original_size,
 154        pred_iou_thresh,
 155        stability_score_thresh,
 156        box_nms_thresh,
 157    ):
 158        orig_h, orig_w = original_size
 159
 160        # filter by predicted IoU
 161        if pred_iou_thresh > 0.0:
 162            keep_mask = data["iou_preds"] > pred_iou_thresh
 163            data.filter(keep_mask)
 164
 165        # filter by stability score
 166        if stability_score_thresh > 0.0:
 167            keep_mask = data["stability_score"] >= stability_score_thresh
 168            data.filter(keep_mask)
 169
 170        # filter boxes that touch crop boundaries
 171        keep_mask = ~amg_utils.is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
 172        if not torch.all(keep_mask):
 173            data.filter(keep_mask)
 174
 175        # remove duplicates within this crop.
 176        keep_by_nms = batched_nms(
 177            data["boxes"].float(),
 178            data["iou_preds"],
 179            torch.zeros_like(data["boxes"][:, 0]),  # categories
 180            iou_threshold=box_nms_thresh,
 181        )
 182        data.filter(keep_by_nms)
 183
 184        # return to the original image frame
 185        data["boxes"] = amg_utils.uncrop_boxes_xyxy(data["boxes"], crop_box)
 186        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
 187        # the data from embedding based segmentation doesn't have the points
 188        # so we skip if the corresponding key can't be found
 189        try:
 190            data["points"] = amg_utils.uncrop_points(data["points"], crop_box)
 191        except KeyError:
 192            pass
 193
 194        return data
 195
 196    def _postprocess_small_regions(self, mask_data, min_area, nms_thresh):
 197
 198        if len(mask_data["rles"]) == 0:
 199            return mask_data
 200
 201        # filter small disconnected regions and holes
 202        new_masks = []
 203        scores = []
 204        for rle in mask_data["rles"]:
 205            mask = amg_utils.rle_to_mask(rle)
 206
 207            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="holes")
 208            unchanged = not changed
 209            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="islands")
 210            unchanged = unchanged and not changed
 211
 212            new_masks.append(torch.as_tensor(mask, dtype=torch.int).unsqueeze(0))
 213            # give score=0 to changed masks and score=1 to unchanged masks
 214            # so NMS will prefer ones that didn't need postprocessing
 215            scores.append(float(unchanged))
 216
 217        # recalculate boxes and remove any new duplicates
 218        masks = torch.cat(new_masks, dim=0)
 219        boxes = batched_mask_to_box(masks.to(torch.bool))  # Casting this to boolean as we work with one-hot labels.
 220        keep_by_nms = batched_nms(
 221            boxes.float(),
 222            torch.as_tensor(scores, dtype=torch.float),
 223            torch.zeros_like(boxes[:, 0]),  # categories
 224            iou_threshold=nms_thresh,
 225        )
 226
 227        # only recalculate RLEs for masks that have changed
 228        for i_mask in keep_by_nms:
 229            if scores[i_mask] == 0.0:
 230                mask_torch = masks[i_mask].unsqueeze(0)
 231                # mask_data["rles"][i_mask] = amg_utils.mask_to_rle_pytorch(mask_torch)[0]
 232                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
 233                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
 234        mask_data.filter(keep_by_nms)
 235
 236        return mask_data
 237
 238    def _postprocess_masks(self, mask_data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode):
 239        # filter small disconnected regions and holes in masks
 240        if min_mask_region_area > 0:
 241            mask_data = self._postprocess_small_regions(
 242                mask_data,
 243                min_mask_region_area,
 244                max(box_nms_thresh, crop_nms_thresh),
 245            )
 246
 247        # encode masks
 248        if output_mode == "coco_rle":
 249            mask_data["segmentations"] = [amg_utils.coco_encode_rle(rle) for rle in mask_data["rles"]]
 250        elif output_mode == "binary_mask":
 251            mask_data["segmentations"] = [amg_utils.rle_to_mask(rle) for rle in mask_data["rles"]]
 252        else:
 253            mask_data["segmentations"] = mask_data["rles"]
 254
 255        # write mask records
 256        curr_anns = []
 257        for idx in range(len(mask_data["segmentations"])):
 258            ann = {
 259                "segmentation": mask_data["segmentations"][idx],
 260                "area": amg_utils.area_from_rle(mask_data["rles"][idx]),
 261                "bbox": amg_utils.box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
 262                "predicted_iou": mask_data["iou_preds"][idx].item(),
 263                "stability_score": mask_data["stability_score"][idx].item(),
 264                "crop_box": amg_utils.box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
 265            }
 266            # the data from embedding based segmentation doesn't have the points
 267            # so we skip if the corresponding key can't be found
 268            try:
 269                ann["point_coords"] = [mask_data["points"][idx].tolist()]
 270            except KeyError:
 271                pass
 272            curr_anns.append(ann)
 273
 274        return curr_anns
 275
 276    def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None):
 277        orig_h, orig_w = original_size
 278
 279        # serialize predictions and store in MaskData
 280        data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1))
 281        if points is not None:
 282            data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0), dtype=torch.float)
 283
 284        del masks
 285
 286        # calculate the stability scores
 287        data["stability_score"] = amg_utils.calculate_stability_score(
 288            data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset
 289        )
 290
 291        # threshold masks and calculate boxes
 292        data["masks"] = data["masks"] > self._predictor.model.mask_threshold
 293        data["masks"] = data["masks"].type(torch.bool)
 294        data["boxes"] = batched_mask_to_box(data["masks"])
 295
 296        # compress to RLE
 297        data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
 298        # data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"])
 299        data["rles"] = mask_to_rle_pytorch(data["masks"])
 300        del data["masks"]
 301
 302        return data
 303
 304    def get_state(self) -> Dict[str, Any]:
 305        """Get the initialized state of the mask generator.
 306
 307        Returns:
 308            State of the mask generator.
 309        """
 310        if not self.is_initialized:
 311            raise RuntimeError("The state has not been computed yet. Call initialize first.")
 312
 313        return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}
 314
 315    def set_state(self, state: Dict[str, Any]) -> None:
 316        """Set the state of the mask generator.
 317
 318        Args:
 319            state: The state of the mask generator, e.g. from serialized state.
 320        """
 321        self._crop_list = state["crop_list"]
 322        self._crop_boxes = state["crop_boxes"]
 323        self._original_size = state["original_size"]
 324        self._is_initialized = True
 325
 326    def clear_state(self):
 327        """Clear the state of the mask generator.
 328        """
 329        self._crop_list = None
 330        self._crop_boxes = None
 331        self._original_size = None
 332        self._is_initialized = False
 333
 334
 335class AutomaticMaskGenerator(AMGBase):
 336    """Generates an instance segmentation without prompts, using a point grid.
 337
 338    This class implements the same logic as
 339    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
 340    It decouples the computationally expensive steps of generating masks from the cheap post-processing operation
 341    to filter these masks to enable grid search and interactively changing the post-processing.
 342
 343    Use this class as follows:
 344    ```python
 345    amg = AutomaticMaskGenerator(predictor)
 346    amg.initialize(image)  # Initialize the masks, this takes care of all expensive computations.
 347    masks = amg.generate(pred_iou_thresh=0.8)  # Generate the masks. This is fast and enables testing parameters
 348    ```
 349
 350    Args:
 351        predictor: The segment anything predictor.
 352        points_per_side: The number of points to be sampled along one side of the image.
 353            If None, `point_grids` must provide explicit point sampling.
 354        points_per_batch: The number of points run simultaneously by the model.
 355            Higher numbers may be faster but use more GPU memory.
 356        crop_n_layers: If >0, the mask prediction will be run again on crops of the image.
 357        crop_overlap_ratio: Sets the degree to which crops overlap.
 358        crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops.
  359        point_grids: A list of explicit grids of points used for sampling masks.
 360            Normalized to [0, 1] with respect to the image coordinate system.
 361        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
 362    """
 363    def __init__(
 364        self,
 365        predictor: SamPredictor,
 366        points_per_side: Optional[int] = 32,
 367        points_per_batch: Optional[int] = None,
 368        crop_n_layers: int = 0,
 369        crop_overlap_ratio: float = 512 / 1500,
 370        crop_n_points_downscale_factor: int = 1,
 371        point_grids: Optional[List[np.ndarray]] = None,
 372        stability_score_offset: float = 1.0,
 373    ):
 374        super().__init__()
 375
 376        if points_per_side is not None:
 377            self.point_grids = amg_utils.build_all_layer_point_grids(
 378                points_per_side,
 379                crop_n_layers,
 380                crop_n_points_downscale_factor,
 381            )
 382        elif point_grids is not None:
 383            self.point_grids = point_grids
 384        else:
  385            raise ValueError("Can't have both points_per_side and point_grids be None.")
 386
 387        self._predictor = predictor
 388        self._points_per_side = points_per_side
 389
 390        # we set the points per batch to 16 for mps for performance reasons
 391        # and otherwise keep them at the default of 64
 392        if points_per_batch is None:
 393            points_per_batch = 16 if str(predictor.device) == "mps" else 64
 394        self._points_per_batch = points_per_batch
 395
 396        self._crop_n_layers = crop_n_layers
 397        self._crop_overlap_ratio = crop_overlap_ratio
 398        self._crop_n_points_downscale_factor = crop_n_points_downscale_factor
 399        self._stability_score_offset = stability_score_offset
 400
 401    def _process_batch(self, points, im_size, crop_box, original_size):
 402        # run model on this batch
 403        transformed_points = self._predictor.transform.apply_coords(points, im_size)
 404        in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float)
 405        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
 406        masks, iou_preds, _ = self._predictor.predict_torch(
 407            in_points[:, None, :],
 408            in_labels[:, None],
 409            multimask_output=True,
 410            return_logits=True,
 411        )
 412        data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points)
 413        del masks
 414        return data
 415
 416    def _process_crop(self, image, crop_box, crop_layer_idx, precomputed_embeddings, pbar_init=None, pbar_update=None):
 417        # Crop the image and calculate embeddings.
 418        x0, y0, x1, y1 = crop_box
 419        cropped_im = image[y0:y1, x0:x1, :]
 420        cropped_im_size = cropped_im.shape[:2]
 421
 422        if not precomputed_embeddings:
 423            self._predictor.set_image(cropped_im)
 424
 425        # Get the points for this crop.
 426        points_scale = np.array(cropped_im_size)[None, ::-1]
 427        points_for_image = self.point_grids[crop_layer_idx] * points_scale
 428
 429        # Generate masks for this crop in batches.
 430        data = amg_utils.MaskData()
 431        n_batches = len(points_for_image) // self._points_per_batch +\
 432            int(len(points_for_image) % self._points_per_batch != 0)
 433        if pbar_init is not None:
 434            pbar_init(n_batches, "Predict masks for point grid prompts")
 435
 436        for (points,) in amg_utils.batch_iterator(self._points_per_batch, points_for_image):
 437            batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size)
 438            data.cat(batch_data)
 439            del batch_data
 440            if pbar_update is not None:
 441                pbar_update(1)
 442
 443        if not precomputed_embeddings:
 444            self._predictor.reset_image()
 445
 446        return data
 447
 448    @torch.no_grad()
 449    def initialize(
 450        self,
 451        image: np.ndarray,
 452        image_embeddings: Optional[util.ImageEmbeddings] = None,
 453        i: Optional[int] = None,
 454        verbose: bool = False,
 455        pbar_init: Optional[callable] = None,
 456        pbar_update: Optional[callable] = None,
 457    ) -> None:
 458        """Initialize image embeddings and masks for an image.
 459
 460        Args:
 461            image: The input image, volume or timeseries.
 462            image_embeddings: Optional precomputed image embeddings.
 463                See `util.precompute_image_embeddings` for details.
 464            i: Index for the image data. Required if `image` has three spatial dimensions
 465                or a time dimension and two spatial dimensions.
 466            verbose: Whether to print computation progress.
 467            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
  468                Can be used together with pbar_update to handle the napari progress bar in another thread.
  469                This enables using this function within a threadworker.
 470            pbar_update: Callback to update an external progress bar.
 471        """
 472        original_size = image.shape[:2]
 473        self._original_size = original_size
 474
 475        crop_boxes, layer_idxs = amg_utils.generate_crop_boxes(
 476            original_size, self._crop_n_layers, self._crop_overlap_ratio
 477        )
 478
 479        # We can set fixed image embeddings if we only have a single crop box (the default setting).
 480        # Otherwise we have to recompute the embeddings for each crop and can't precompute.
 481        if len(crop_boxes) == 1:
 482            if image_embeddings is None:
 483                image_embeddings = util.precompute_image_embeddings(self._predictor, image)
 484            util.set_precomputed(self._predictor, image_embeddings, i=i)
 485            precomputed_embeddings = True
 486        else:
 487            precomputed_embeddings = False
 488
 489        # we need to cast to the image representation that is compatible with SAM
 490        image = util._to_image(image)
 491
 492        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
 493
 494        crop_list = []
 495        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
 496            crop_data = self._process_crop(
 497                image, crop_box, layer_idx,
 498                precomputed_embeddings=precomputed_embeddings,
 499                pbar_init=pbar_init, pbar_update=pbar_update,
 500            )
 501            crop_list.append(crop_data)
 502        pbar_close()
 503
 504        self._is_initialized = True
 505        self._crop_list = crop_list
 506        self._crop_boxes = crop_boxes
 507
 508    @torch.no_grad()
 509    def generate(
 510        self,
 511        pred_iou_thresh: float = 0.88,
 512        stability_score_thresh: float = 0.95,
 513        box_nms_thresh: float = 0.7,
 514        crop_nms_thresh: float = 0.7,
 515        min_mask_region_area: int = 0,
 516        output_mode: str = "binary_mask",
 517    ) -> List[Dict[str, Any]]:
 518        """Generate instance segmentation for the currently initialized image.
 519
 520        Args:
 521            pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
 522            stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
 523                under changes to the cutoff used to binarize the model prediction.
 524            box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
 525            crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
 526            min_mask_region_area: Minimal size for the predicted masks.
 527            output_mode: The form masks are returned in.
 528
 529        Returns:
 530            The instance segmentation masks.
 531        """
 532        if not self.is_initialized:
 533            raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.")
 534
 535        data = amg_utils.MaskData()
 536        for data_, crop_box in zip(self.crop_list, self.crop_boxes):
 537            crop_data = self._postprocess_batch(
 538                data=deepcopy(data_),
 539                crop_box=crop_box, original_size=self.original_size,
 540                pred_iou_thresh=pred_iou_thresh,
 541                stability_score_thresh=stability_score_thresh,
 542                box_nms_thresh=box_nms_thresh
 543            )
 544            data.cat(crop_data)
 545
 546        if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0:
 547            # Prefer masks from smaller crops
 548            scores = 1 / box_area(data["crop_boxes"])
 549            scores = scores.to(data["boxes"].device)
 550            keep_by_nms = batched_nms(
 551                data["boxes"].float(),
 552                scores,
 553                torch.zeros_like(data["boxes"][:, 0]),  # categories
 554                iou_threshold=crop_nms_thresh,
 555            )
 556            data.filter(keep_by_nms)
 557
 558        data.to_numpy()
 559        masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode)
 560        return masks
 561
 562
 563# Helper function for tiled embedding computation and checking consistent state.
 564def _process_tiled_embeddings(predictor, image, image_embeddings, tile_shape, halo):
 565    if image_embeddings is None:
 566        if tile_shape is None or halo is None:
 567            raise ValueError("To compute tiled embeddings the parameters tile_shape and halo have to be passed.")
 568        image_embeddings = util.precompute_image_embeddings(predictor, image, tile_shape=tile_shape, halo=halo)
 569
 570    # Use tile shape and halo from the precomputed embeddings if not given.
 571    # Otherwise check that they are consistent.
 572    feats = image_embeddings["features"]
 573    tile_shape_, halo_ = feats.attrs["tile_shape"], feats.attrs["halo"]
 574    if tile_shape is None:
 575        tile_shape = tile_shape_
 576    elif tile_shape != tile_shape_:
 577        raise ValueError(
  578            f"Inconsistent tile_shape parameter {tile_shape} with precomputed embeddings: {tile_shape_}."
 579        )
 580    if halo is None:
 581        halo = halo_
 582    elif halo != halo_:
  583        raise ValueError(f"Inconsistent halo parameter {halo} with precomputed embeddings: {halo_}.")
 584
 585    return image_embeddings, tile_shape, halo
 586
 587
 588class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
 589    """Generates an instance segmentation without prompts, using a point grid.
 590
 591    Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings.
 592
 593    Args:
 594        predictor: The segment anything predictor.
 595        points_per_side: The number of points to be sampled along one side of the image.
 596            If None, `point_grids` must provide explicit point sampling.
 597        points_per_batch: The number of points run simultaneously by the model.
 598            Higher numbers may be faster but use more GPU memory.
  599        point_grids: A list of explicit grids of points used for sampling masks.
 600            Normalized to [0, 1] with respect to the image coordinate system.
 601        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
 602    """
 603
 604    # We only expose the arguments that make sense for the tiled mask generator.
 605    # Anything related to crops doesn't make sense, because we re-use that functionality
 606    # for tiling, so these parameters wouldn't have any effect.
 607    def __init__(
 608        self,
 609        predictor: SamPredictor,
 610        points_per_side: Optional[int] = 32,
 611        points_per_batch: int = 64,
 612        point_grids: Optional[List[np.ndarray]] = None,
 613        stability_score_offset: float = 1.0,
 614    ) -> None:
 615        super().__init__(
 616            predictor=predictor,
 617            points_per_side=points_per_side,
 618            points_per_batch=points_per_batch,
 619            point_grids=point_grids,
 620            stability_score_offset=stability_score_offset,
 621        )
 622
 623    @torch.no_grad()
 624    def initialize(
 625        self,
 626        image: np.ndarray,
 627        image_embeddings: Optional[util.ImageEmbeddings] = None,
 628        i: Optional[int] = None,
 629        tile_shape: Optional[Tuple[int, int]] = None,
 630        halo: Optional[Tuple[int, int]] = None,
 631        verbose: bool = False,
 632        pbar_init: Optional[callable] = None,
 633        pbar_update: Optional[callable] = None,
 634    ) -> None:
 635        """Initialize image embeddings and masks for an image.
 636
 637        Args:
 638            image: The input image, volume or timeseries.
 639            image_embeddings: Optional precomputed image embeddings.
 640                See `util.precompute_image_embeddings` for details.
 641            i: Index for the image data. Required if `image` has three spatial dimensions
 642                or a time dimension and two spatial dimensions.
 643            tile_shape: The tile shape for embedding prediction.
  644            halo: The overlap between tiles.
 645            verbose: Whether to print computation progress.
 646            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
  647                Can be used together with pbar_update to handle the napari progress bar in another thread.
  648                This enables using this function within a threadworker.
 649            pbar_update: Callback to update an external progress bar.
 650        """
 651        original_size = image.shape[:2]
 652        self._original_size = original_size
 653
 654        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
 655            self._predictor, image, image_embeddings, tile_shape, halo
 656        )
 657
 658        tiling = blocking([0, 0], original_size, tile_shape)
 659        n_tiles = tiling.numberOfBlocks
 660
 661        # The crop box is always the full local tile.
 662        tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)]
 663        crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles]
 664
 665        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
 666        pbar_init(n_tiles, "Compute masks for tile")
 667
 668        # We need to cast to the image representation that is compatible with SAM.
 669        image = util._to_image(image)
 670
 671        mask_data = []
 672        for tile_id in range(n_tiles):
 673            # set the pre-computed embeddings for this tile
 674            features = image_embeddings["features"][tile_id]
 675            tile_embeddings = {
 676                "features": features,
 677                "input_size": features.attrs["input_size"],
 678                "original_size": features.attrs["original_size"],
 679            }
 680            util.set_precomputed(self._predictor, tile_embeddings, i)
 681
 682            # compute the mask data for this tile and append it
 683            this_mask_data = self._process_crop(
 684                image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True
 685            )
 686            mask_data.append(this_mask_data)
 687            pbar_update(1)
 688        pbar_close()
 689
 690        # set the initialized data
 691        self._is_initialized = True
 692        self._crop_list = mask_data
 693        self._crop_boxes = crop_boxes
 694
 695
 696#
 697# Instance segmentation functionality based on fine-tuned decoder
 698#
 699
 700
 701class DecoderAdapter(torch.nn.Module):
 702    """Adapter to contain the UNETR decoder in a single module.
 703
 704    To apply the decoder on top of pre-computed embeddings for
 705    the segmentation functionality.
 706    See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py
 707    """
 708    def __init__(self, unetr):
 709        super().__init__()
 710
 711        self.base = unetr.base
 712        self.out_conv = unetr.out_conv
 713        self.deconv_out = unetr.deconv_out
 714        self.decoder_head = unetr.decoder_head
 715        self.final_activation = unetr.final_activation
 716        self.postprocess_masks = unetr.postprocess_masks
 717
 718        self.decoder = unetr.decoder
 719        self.deconv1 = unetr.deconv1
 720        self.deconv2 = unetr.deconv2
 721        self.deconv3 = unetr.deconv3
 722        self.deconv4 = unetr.deconv4
 723
 724    def forward(self, input_, input_shape, original_shape):
 725        z12 = input_
 726
 727        z9 = self.deconv1(z12)
 728        z6 = self.deconv2(z9)
 729        z3 = self.deconv3(z6)
 730        z0 = self.deconv4(z3)
 731
 732        updated_from_encoder = [z9, z6, z3]
 733
 734        x = self.base(z12)
 735        x = self.decoder(x, encoder_inputs=updated_from_encoder)
 736        x = self.deconv_out(x)
 737
 738        x = torch.cat([x, z0], dim=1)
 739        x = self.decoder_head(x)
 740
 741        x = self.out_conv(x)
 742        if self.final_activation is not None:
 743            x = self.final_activation(x)
 744
 745        x = self.postprocess_masks(x, input_shape, original_shape)
 746        return x
 747
 748
 749def get_unetr(
 750    image_encoder: torch.nn.Module,
 751    decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None,
 752    device: Optional[Union[str, torch.device]] = None,
 753) -> torch.nn.Module:
 754    """Get UNETR model for automatic instance segmentation.
 755
 756    Args:
 757        image_encoder: The image encoder of the SAM model.
 758            This is used as encoder by the UNETR too.
 759        decoder_state: Optional decoder state to initialize the weights
 760            of the UNETR decoder.
 761        device: The device.
 762    Returns:
 763        The UNETR model.
 764    """
 765    device = util.get_device(device)
 766
 767    unetr = UNETR(
 768        backbone="sam",
 769        encoder=image_encoder,
 770        out_channels=3,
 771        use_sam_stats=True,
 772        final_activation="Sigmoid",
 773        use_skip_connection=False,
 774        resize_input=True,
 775    )
 776    if decoder_state is not None:
 777        unetr_state_dict = unetr.state_dict()
 778        for k, v in unetr_state_dict.items():
 779            if not k.startswith("encoder"):
 780                unetr_state_dict[k] = decoder_state[k]
 781        unetr.load_state_dict(unetr_state_dict)
 782
 783    unetr.to(device)
 784    return unetr
 785
 786
 787def get_decoder(
 788    image_encoder: torch.nn.Module,
 789    decoder_state: OrderedDict[str, torch.Tensor],
 790    device: Optional[Union[str, torch.device]] = None,
 791) -> DecoderAdapter:
  792    """Get decoder to predict outputs for automatic instance segmentation.
 793
 794    Args:
 795        image_encoder: The image encoder of the SAM model.
 796        decoder_state: State to initialize the weights of the UNETR decoder.
 797        device: The device.
 798    Returns:
 799        The decoder for instance segmentation.
 800    """
 801    unetr = get_unetr(image_encoder, decoder_state, device)
 802    return DecoderAdapter(unetr)
 803
 804
 805def get_predictor_and_decoder(
 806    model_type: str,
 807    checkpoint_path: Union[str, os.PathLike],
 808    device: Optional[Union[str, torch.device]] = None,
 809    peft_kwargs: Optional[Dict] = None,
 810) -> Tuple[SamPredictor, DecoderAdapter]:
 811    """Load the SAM model (predictor) and instance segmentation decoder.
 812
 813    This requires a checkpoint that contains the state for both predictor
 814    and decoder.
 815
 816    Args:
 817        model_type: The type of the image encoder used in the SAM model.
 818        checkpoint_path: Path to the checkpoint from which to load the data.
 819        device: The device.
  820        peft_kwargs: Keyword arguments for parameter efficient fine-tuning (e.g. the rank for low rank adaptation of the attention layers).
 821
 822    Returns:
 823        The SAM predictor.
 824        The decoder for instance segmentation.
 825    """
 826    device = util.get_device(device)
 827    predictor, state = util.get_sam_model(
 828        model_type=model_type,
 829        checkpoint_path=checkpoint_path,
 830        device=device,
 831        return_state=True,
 832        peft_kwargs=peft_kwargs,
 833    )
 834    if "decoder_state" not in state:
 835        raise ValueError(
 836            f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state"
 837        )
 838    decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device)
 839    return predictor, decoder
 840
 841
 842class InstanceSegmentationWithDecoder:
 843    """Generates an instance segmentation without prompts, using a decoder.
 844
 845    Implements the same interface as `AutomaticMaskGenerator`.
 846
 847    Use this class as follows:
 848    ```python
 849    segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
 850    segmenter.initialize(image)   # Predict the image embeddings and decoder outputs.
 851    masks = segmenter.generate(center_distance_threshold=0.75)  # Generate the instance segmentation.
 852    ```
 853
 854    Args:
 855        predictor: The segment anything predictor.
 856        decoder: The decoder to predict intermediate representations
 857            for instance segmentation.
 858    """
 859    def __init__(
 860        self,
 861        predictor: SamPredictor,
 862        decoder: torch.nn.Module,
 863    ) -> None:
 864        self._predictor = predictor
 865        self._decoder = decoder
 866
 867        # The decoder outputs.
 868        self._foreground = None
 869        self._center_distances = None
 870        self._boundary_distances = None
 871
 872        self._is_initialized = False
 873
 874    @property
 875    def is_initialized(self):
 876        """Whether the mask generator has already been initialized.
 877        """
 878        return self._is_initialized
 879
 880    @torch.no_grad()
 881    def initialize(
 882        self,
 883        image: np.ndarray,
 884        image_embeddings: Optional[util.ImageEmbeddings] = None,
 885        i: Optional[int] = None,
 886        verbose: bool = False,
 887        pbar_init: Optional[callable] = None,
 888        pbar_update: Optional[callable] = None,
 889    ) -> None:
 890        """Initialize image embeddings and decoder predictions for an image.
 891
 892        Args:
 893            image: The input image, volume or timeseries.
 894            image_embeddings: Optional precomputed image embeddings.
 895                See `util.precompute_image_embeddings` for details.
 896            i: Index for the image data. Required if `image` has three spatial dimensions
 897                or a time dimension and two spatial dimensions.
 898            verbose: Whether to be verbose.
 899            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
  900                Can be used together with pbar_update to handle the napari progress bar in another thread.
  901                This enables using this function within a threadworker.
 902            pbar_update: Callback to update an external progress bar.
 903        """
 904        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
 905        pbar_init(1, "Initialize instance segmentation with decoder")
 906
 907        if image_embeddings is None:
 908            image_embeddings = util.precompute_image_embeddings(self._predictor, image)
 909
 910        # Get the image embeddings from the predictor.
 911        self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i)
 912        embeddings = self._predictor.features
 913        input_shape = tuple(self._predictor.input_size)
 914        original_shape = tuple(self._predictor.original_size)
 915
 916        # Run prediction with the UNETR decoder.
 917        output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
 918        assert output.shape[0] == 3, f"{output.shape}"
 919        pbar_update(1)
 920        pbar_close()
 921
 922        # Set the state.
 923        self._foreground = output[0]
 924        self._center_distances = output[1]
 925        self._boundary_distances = output[2]
 926        self._is_initialized = True
 927
 928    def _to_masks(self, segmentation, output_mode):
 929        if output_mode != "binary_mask":
 930            raise NotImplementedError
 931
 932        props = regionprops(segmentation)
 933        ndim = segmentation.ndim
 934        assert ndim in (2, 3)
 935
 936        shape = segmentation.shape
 937        if ndim == 2:
 938            crop_box = [0, shape[1], 0, shape[0]]
 939        else:
 940            crop_box = [0, shape[2], 0, shape[1], 0, shape[0]]
 941
 942        # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h]
 943        def to_bbox_2d(bbox):
 944            y0, x0 = bbox[0], bbox[1]
 945            w = bbox[3] - x0
 946            h = bbox[2] - y0
 947            return [x0, w, y0, h]
 948
 949        def to_bbox_3d(bbox):
 950            z0, y0, x0 = bbox[0], bbox[1], bbox[2]
 951            w = bbox[5] - x0
 952            h = bbox[4] - y0
  953            d = bbox[3] - z0
 954            return [x0, w, y0, h, z0, d]
 955
 956        to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d
 957        masks = [
 958            {
 959                "segmentation": segmentation == prop.label,
 960                "area": prop.area,
 961                "bbox": to_bbox(prop.bbox),
 962                "crop_box": crop_box,
 963                "seg_id": prop.label,
 964            } for prop in props
 965        ]
 966        return masks
 967
 968    def generate(
 969        self,
 970        center_distance_threshold: float = 0.5,
 971        boundary_distance_threshold: float = 0.5,
 972        foreground_threshold: float = 0.5,
 973        foreground_smoothing: float = 1.0,
 974        distance_smoothing: float = 1.6,
 975        min_size: int = 0,
 976        output_mode: Optional[str] = "binary_mask",
 977    ) -> List[Dict[str, Any]]:
 978        """Generate instance segmentation for the currently initialized image.
 979
 980        Args:
 981            center_distance_threshold: Center distance predictions below this value will be
 982                used to find seeds (intersected with thresholded boundary distance predictions).
 983            boundary_distance_threshold: Boundary distance predictions below this value will be
 984                used to find seeds (intersected with thresholded center distance predictions).
 985            foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid
 986                checkerboard artifacts in the prediction.
 987            foreground_threshold: Foreground predictions above this value will be used as foreground mask.
 988            distance_smoothing: Sigma value for smoothing the distance predictions.
 989            min_size: Minimal object size in the segmentation result.
 990            output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
 991
 992        Returns:
 993            The instance segmentation masks.
 994        """
 995        if not self.is_initialized:
 996            raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.")
 997
 998        if foreground_smoothing > 0:
 999            foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing)
1000        else:
1001            foreground = self._foreground
1002        # Further optimization: parallel implementation using elf.parallel functionality.
1003        # (Make sure to expose n_threads to avoid over-subscription in case of outer parallelization)
1004        segmentation = watershed_from_center_and_boundary_distances(
1005            self._center_distances, self._boundary_distances, foreground,
1006            center_distance_threshold=center_distance_threshold,
1007            boundary_distance_threshold=boundary_distance_threshold,
1008            foreground_threshold=foreground_threshold,
1009            distance_smoothing=distance_smoothing,
1010            min_size=min_size,
1011        )
1012        if output_mode is not None:
1013            segmentation = self._to_masks(segmentation, output_mode)
1014        return segmentation
1015
1016    def get_state(self) -> Dict[str, Any]:
1017        """Get the initialized state of the instance segmenter.
1018
1019        Returns:
1020            Instance segmentation state.
1021        """
1022        if not self.is_initialized:
1023            raise RuntimeError("The state has not been computed yet. Call initialize first.")
1024
1025        return {
1026            "foreground": self._foreground,
1027            "center_distances": self._center_distances,
1028            "boundary_distances": self._boundary_distances,
1029        }
1030
1031    def set_state(self, state: Dict[str, Any]) -> None:
1032        """Set the state of the instance segmenter.
1033
1034        Args:
1035            state: The instance segmentation state
1036        """
1037        self._foreground = state["foreground"]
1038        self._center_distances = state["center_distances"]
1039        self._boundary_distances = state["boundary_distances"]
1040        self._is_initialized = True
1041
1042    def clear_state(self):
1043        """Clear the state of the instance segmenter.
1044        """
1045        self._foreground = None
1046        self._center_distances = None
1047        self._boundary_distances = None
1048        self._is_initialized = False
1049
1050
1051class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder):
1052    """Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings.
1053    """
1054
1055    @torch.no_grad()
1056    def initialize(
1057        self,
1058        image: np.ndarray,
1059        image_embeddings: Optional[util.ImageEmbeddings] = None,
1060        i: Optional[int] = None,
1061        tile_shape: Optional[Tuple[int, int]] = None,
1062        halo: Optional[Tuple[int, int]] = None,
1063        verbose: bool = False,
1064        pbar_init: Optional[callable] = None,
1065        pbar_update: Optional[callable] = None,
1066    ) -> None:
1067        """Initialize image embeddings and decoder predictions for an image.
1068
1069        Args:
1070            image: The input image, volume or timeseries.
1071            image_embeddings: Optional precomputed image embeddings.
1072                See `util.precompute_image_embeddings` for details.
1073            i: Index for the image data. Required if `image` has three spatial dimensions
1074                or a time dimension and two spatial dimensions.
 1075            verbose: Whether to print computation progress.
1076            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 1077                Can be used together with pbar_update to handle the napari progress bar in another thread.
 1078                This enables using this function within a threadworker.
1079            pbar_update: Callback to update an external progress bar.
1080        """
1081        original_size = image.shape[:2]
1082        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
1083            self._predictor, image, image_embeddings, tile_shape, halo
1084        )
1085        tiling = blocking([0, 0], original_size, tile_shape)
1086
1087        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
1088        pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder")
1089
1090        foreground = np.zeros(original_size, dtype="float32")
1091        center_distances = np.zeros(original_size, dtype="float32")
1092        boundary_distances = np.zeros(original_size, dtype="float32")
1093
1094        for tile_id in range(tiling.numberOfBlocks):
1095
1096            # Get the image embeddings from the predictor for this tile.
1097            self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id)
1098            embeddings = self._predictor.features
1099            input_shape = tuple(self._predictor.input_size)
1100            original_shape = tuple(self._predictor.original_size)
1101
1102            # Predict with the UNETR decoder for this tile.
1103            output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
1104            assert output.shape[0] == 3, f"{output.shape}"
1105
1106            # Set the predictions in the output for this tile.
1107            block = tiling.getBlockWithHalo(tile_id, halo=list(halo))
1108            local_bb = tuple(
1109                slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end)
1110            )
1111            inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end))
1112
1113            foreground[inner_bb] = output[0][local_bb]
1114            center_distances[inner_bb] = output[1][local_bb]
1115            boundary_distances[inner_bb] = output[2][local_bb]
1116        pbar_close()
1117
1118        # Set the state.
1119        self._foreground = foreground
1120        self._center_distances = center_distances
1121        self._boundary_distances = boundary_distances
1122        self._is_initialized = True
1123
1124
1125def get_amg(
1126    predictor: SamPredictor, is_tiled: bool, decoder: Optional[torch.nn.Module] = None, **kwargs,
1127) -> Union[AMGBase, InstanceSegmentationWithDecoder]:
1128    """Get the automatic mask generator class.
1129
1130    Args:
1131        predictor: The segment anything predictor.
1132        is_tiled: Whether tiled embeddings are used.
 1133        decoder: The decoder to predict instance segmentation.
1134        kwargs: The keyword arguments for the amg class.
1135
1136    Returns:
1137        The automatic mask generator.
1138    """
1139    if decoder is None:
1140        segmenter_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator
1141        segmenter = segmenter_class(predictor, **kwargs)
1142    else:
1143        segmenter_class = TiledInstanceSegmentationWithDecoder if is_tiled else InstanceSegmentationWithDecoder
1144        segmenter = segmenter_class(predictor, decoder, **kwargs)
1145
1146    return segmenter
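
As a rough sketch of how `get_amg` and `get_predictor_and_decoder` fit together (the checkpoint and image paths are placeholders, imageio is assumed to be available, and the decoder-based path requires a checkpoint that also stores a decoder state):

```python
import imageio.v3 as imageio

from micro_sam.instance_segmentation import get_amg, get_predictor_and_decoder

image = imageio.imread("/path/to/image.tif")  # placeholder path

# Decoder-based instance segmentation from a fine-tuned checkpoint.
predictor, decoder = get_predictor_and_decoder(
    model_type="vit_b", checkpoint_path="/path/to/checkpoint.pt"  # placeholder path
)
segmenter = get_amg(predictor, is_tiled=False, decoder=decoder)
segmenter.initialize(image)
segmentation = segmenter.generate(output_mode=None)  # None returns the label image directly

# Without a decoder, get_amg returns AutomaticMaskGenerator
# (or TiledAutomaticMaskGenerator if is_tiled=True).
amg = get_amg(predictor, is_tiled=False)
```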
def mask_data_to_segmentation( masks: List[Dict[str, Any]], with_background: bool, min_object_size: int = 0, max_object_size: Optional[int] = None, label_masks: bool = True) -> numpy.ndarray:

Convert the output of the automatic mask generation to an instance segmentation.

Arguments:
  • masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator. Only supports output_mode=binary_mask.
  • with_background: Whether the segmentation has background. If yes, this function ensures that the largest object in the output will be mapped to zero (the background value).
  • min_object_size: The minimal size of an object in pixels.
  • max_object_size: The maximal size of an object in pixels.
  • label_masks: Whether to apply connected components to the result before removing small objects.
Returns:

The instance segmentation.
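
For example, given the mask data returned by an initialized `AutomaticMaskGenerator` (with the default `output_mode="binary_mask"`), the conversion to a label image could look as follows; the size thresholds are illustrative values:

```python
masks = amg.generate()  # amg is an initialized AutomaticMaskGenerator
segmentation = mask_data_to_segmentation(
    masks,
    with_background=True,    # map the largest object (the background) to label 0
    min_object_size=100,     # discard objects smaller than 100 pixels
    max_object_size=10_000,  # discard objects larger than 10,000 pixels
)
```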

class AMGBase(abc.ABC):

Base class for the automatic mask generators.
278        orig_h, orig_w = original_size
279
280        # serialize predictions and store in MaskData
281        data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1))
282        if points is not None:
283            data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0), dtype=torch.float)
284
285        del masks
286
287        # calculate the stability scores
288        data["stability_score"] = amg_utils.calculate_stability_score(
289            data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset
290        )
291
292        # threshold masks and calculate boxes
293        data["masks"] = data["masks"] > self._predictor.model.mask_threshold
294        data["masks"] = data["masks"].type(torch.bool)
295        data["boxes"] = batched_mask_to_box(data["masks"])
296
297        # compress to RLE
298        data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
299        # data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"])
300        data["rles"] = mask_to_rle_pytorch(data["masks"])
301        del data["masks"]
302
303        return data
304
305    def get_state(self) -> Dict[str, Any]:
306        """Get the initialized state of the mask generator.
307
308        Returns:
309            State of the mask generator.
310        """
311        if not self.is_initialized:
312            raise RuntimeError("The state has not been computed yet. Call initialize first.")
313
314        return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}
315
316    def set_state(self, state: Dict[str, Any]) -> None:
317        """Set the state of the mask generator.
318
319        Args:
320            state: The state of the mask generator, e.g. from serialized state.
321        """
322        self._crop_list = state["crop_list"]
323        self._crop_boxes = state["crop_boxes"]
324        self._original_size = state["original_size"]
325        self._is_initialized = True
326
327    def clear_state(self):
328        """Clear the state of the mask generator.
329        """
330        self._crop_list = None
331        self._crop_boxes = None
332        self._original_size = None
333        self._is_initialized = False

Base class for the automatic mask generators.

is_initialized
126    @property
127    def is_initialized(self):
128        """Whether the mask generator has already been initialized.
129        """
130        return self._is_initialized

Whether the mask generator has already been initialized.

crop_list
132    @property
133    def crop_list(self):
134        """The list of mask data after initialization.
135        """
136        return self._crop_list

The list of mask data after initialization.

crop_boxes
138    @property
139    def crop_boxes(self):
140        """The list of crop boxes.
141        """
142        return self._crop_boxes

The list of crop boxes.

original_size
144    @property
145    def original_size(self):
146        """The original image size.
147        """
148        return self._original_size

The original image size.

def get_state(self) -> Dict[str, Any]:
305    def get_state(self) -> Dict[str, Any]:
306        """Get the initialized state of the mask generator.
307
308        Returns:
309            State of the mask generator.
310        """
311        if not self.is_initialized:
312            raise RuntimeError("The state has not been computed yet. Call initialize first.")
313
314        return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}

Get the initialized state of the mask generator.

Returns:

State of the mask generator.

def set_state(self, state: Dict[str, Any]) -> None:
316    def set_state(self, state: Dict[str, Any]) -> None:
317        """Set the state of the mask generator.
318
319        Args:
320            state: The state of the mask generator, e.g. from serialized state.
321        """
322        self._crop_list = state["crop_list"]
323        self._crop_boxes = state["crop_boxes"]
324        self._original_size = state["original_size"]
325        self._is_initialized = True

Set the state of the mask generator.

Arguments:
  • state: The state of the mask generator, e.g. from serialized state.
def clear_state(self):
327    def clear_state(self):
328        """Clear the state of the mask generator.
329        """
330        self._crop_list = None
331        self._crop_boxes = None
332        self._original_size = None
333        self._is_initialized = False

Clear the state of the mask generator.

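Since `initialize` is the expensive step, the state accessors above can be used to cache the computed state and restore it later. A small sketch, assuming the state can simply be pickled; the file name and the use of `pickle` are illustrative, not part of the API:

```python
import pickle

import numpy as np

from micro_sam import util
from micro_sam.instance_segmentation import AutomaticMaskGenerator

image = np.random.rand(512, 512).astype("float32")  # stand-in image
amg = AutomaticMaskGenerator(util.get_sam_model(model_type="vit_b"))
amg.initialize(image)  # the expensive step

# Serialize the computed state (pickle is just one option; the state holds mask data and tensors).
with open("amg_state.pkl", "wb") as f:
    pickle.dump(amg.get_state(), f)

# Later: restore the state and generate masks without re-running initialize.
amg.clear_state()
with open("amg_state.pkl", "rb") as f:
    amg.set_state(pickle.load(f))
masks = amg.generate(pred_iou_thresh=0.88)
```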
class AutomaticMaskGenerator(AMGBase):
336class AutomaticMaskGenerator(AMGBase):
337    """Generates an instance segmentation without prompts, using a point grid.
338
339    This class implements the same logic as
340    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
341    It decouples the computationally expensive steps of generating masks from the cheap post-processing operation
342    to filter these masks to enable grid search and interactively changing the post-processing.
343
344    Use this class as follows:
345    ```python
346    amg = AutomaticMaskGenerator(predictor)
347    amg.initialize(image)  # Initialize the masks, this takes care of all expensive computations.
348    masks = amg.generate(pred_iou_thresh=0.8)  # Generate the masks. This is fast and enables testing parameters
349    ```
350
351    Args:
352        predictor: The segment anything predictor.
353        points_per_side: The number of points to be sampled along one side of the image.
354            If None, `point_grids` must provide explicit point sampling.
355        points_per_batch: The number of points run simultaneously by the model.
356            Higher numbers may be faster but use more GPU memory.
357        crop_n_layers: If >0, the mask prediction will be run again on crops of the image.
358        crop_overlap_ratio: Sets the degree to which crops overlap.
359        crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops.
360        point_grids: A list of explicit grids of points used for sampling masks.
361            Normalized to [0, 1] with respect to the image coordinate system.
362        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
363    """
364    def __init__(
365        self,
366        predictor: SamPredictor,
367        points_per_side: Optional[int] = 32,
368        points_per_batch: Optional[int] = None,
369        crop_n_layers: int = 0,
370        crop_overlap_ratio: float = 512 / 1500,
371        crop_n_points_downscale_factor: int = 1,
372        point_grids: Optional[List[np.ndarray]] = None,
373        stability_score_offset: float = 1.0,
374    ):
375        super().__init__()
376
377        if points_per_side is not None:
378            self.point_grids = amg_utils.build_all_layer_point_grids(
379                points_per_side,
380                crop_n_layers,
381                crop_n_points_downscale_factor,
382            )
383        elif point_grids is not None:
384            self.point_grids = point_grids
385        else:
386            raise ValueError("Can't have both points_per_side and point_grid be None or not None.")
387
388        self._predictor = predictor
389        self._points_per_side = points_per_side
390
391        # we set the points per batch to 16 for mps for performance reasons
392        # and otherwise keep them at the default of 64
393        if points_per_batch is None:
394            points_per_batch = 16 if str(predictor.device) == "mps" else 64
395        self._points_per_batch = points_per_batch
396
397        self._crop_n_layers = crop_n_layers
398        self._crop_overlap_ratio = crop_overlap_ratio
399        self._crop_n_points_downscale_factor = crop_n_points_downscale_factor
400        self._stability_score_offset = stability_score_offset
401
402    def _process_batch(self, points, im_size, crop_box, original_size):
403        # run model on this batch
404        transformed_points = self._predictor.transform.apply_coords(points, im_size)
405        in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float)
406        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
407        masks, iou_preds, _ = self._predictor.predict_torch(
408            in_points[:, None, :],
409            in_labels[:, None],
410            multimask_output=True,
411            return_logits=True,
412        )
413        data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points)
414        del masks
415        return data
416
417    def _process_crop(self, image, crop_box, crop_layer_idx, precomputed_embeddings, pbar_init=None, pbar_update=None):
418        # Crop the image and calculate embeddings.
419        x0, y0, x1, y1 = crop_box
420        cropped_im = image[y0:y1, x0:x1, :]
421        cropped_im_size = cropped_im.shape[:2]
422
423        if not precomputed_embeddings:
424            self._predictor.set_image(cropped_im)
425
426        # Get the points for this crop.
427        points_scale = np.array(cropped_im_size)[None, ::-1]
428        points_for_image = self.point_grids[crop_layer_idx] * points_scale
429
430        # Generate masks for this crop in batches.
431        data = amg_utils.MaskData()
432        n_batches = len(points_for_image) // self._points_per_batch +\
433            int(len(points_for_image) % self._points_per_batch != 0)
434        if pbar_init is not None:
435            pbar_init(n_batches, "Predict masks for point grid prompts")
436
437        for (points,) in amg_utils.batch_iterator(self._points_per_batch, points_for_image):
438            batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size)
439            data.cat(batch_data)
440            del batch_data
441            if pbar_update is not None:
442                pbar_update(1)
443
444        if not precomputed_embeddings:
445            self._predictor.reset_image()
446
447        return data
448
449    @torch.no_grad()
450    def initialize(
451        self,
452        image: np.ndarray,
453        image_embeddings: Optional[util.ImageEmbeddings] = None,
454        i: Optional[int] = None,
455        verbose: bool = False,
456        pbar_init: Optional[callable] = None,
457        pbar_update: Optional[callable] = None,
458    ) -> None:
459        """Initialize image embeddings and masks for an image.
460
461        Args:
462            image: The input image, volume or timeseries.
463            image_embeddings: Optional precomputed image embeddings.
464                See `util.precompute_image_embeddings` for details.
465            i: Index for the image data. Required if `image` has three spatial dimensions
466                or a time dimension and two spatial dimensions.
467            verbose: Whether to print computation progress.
468            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
469                Can be used together with pbar_update to handle the napari progress bar in another thread.
470                This enables using this function within a threadworker.
471            pbar_update: Callback to update an external progress bar.
472        """
473        original_size = image.shape[:2]
474        self._original_size = original_size
475
476        crop_boxes, layer_idxs = amg_utils.generate_crop_boxes(
477            original_size, self._crop_n_layers, self._crop_overlap_ratio
478        )
479
480        # We can set fixed image embeddings if we only have a single crop box (the default setting).
481        # Otherwise we have to recompute the embeddings for each crop and can't precompute.
482        if len(crop_boxes) == 1:
483            if image_embeddings is None:
484                image_embeddings = util.precompute_image_embeddings(self._predictor, image)
485            util.set_precomputed(self._predictor, image_embeddings, i=i)
486            precomputed_embeddings = True
487        else:
488            precomputed_embeddings = False
489
490        # we need to cast to the image representation that is compatible with SAM
491        image = util._to_image(image)
492
493        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
494
495        crop_list = []
496        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
497            crop_data = self._process_crop(
498                image, crop_box, layer_idx,
499                precomputed_embeddings=precomputed_embeddings,
500                pbar_init=pbar_init, pbar_update=pbar_update,
501            )
502            crop_list.append(crop_data)
503        pbar_close()
504
505        self._is_initialized = True
506        self._crop_list = crop_list
507        self._crop_boxes = crop_boxes
508
509    @torch.no_grad()
510    def generate(
511        self,
512        pred_iou_thresh: float = 0.88,
513        stability_score_thresh: float = 0.95,
514        box_nms_thresh: float = 0.7,
515        crop_nms_thresh: float = 0.7,
516        min_mask_region_area: int = 0,
517        output_mode: str = "binary_mask",
518    ) -> List[Dict[str, Any]]:
519        """Generate instance segmentation for the currently initialized image.
520
521        Args:
522            pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
523            stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
524                under changes to the cutoff used to binarize the model prediction.
525            box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
526            crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
527            min_mask_region_area: Minimal size for the predicted masks.
528            output_mode: The form masks are returned in.
529
530        Returns:
531            The instance segmentation masks.
532        """
533        if not self.is_initialized:
534            raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.")
535
536        data = amg_utils.MaskData()
537        for data_, crop_box in zip(self.crop_list, self.crop_boxes):
538            crop_data = self._postprocess_batch(
539                data=deepcopy(data_),
540                crop_box=crop_box, original_size=self.original_size,
541                pred_iou_thresh=pred_iou_thresh,
542                stability_score_thresh=stability_score_thresh,
543                box_nms_thresh=box_nms_thresh
544            )
545            data.cat(crop_data)
546
547        if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0:
548            # Prefer masks from smaller crops
549            scores = 1 / box_area(data["crop_boxes"])
550            scores = scores.to(data["boxes"].device)
551            keep_by_nms = batched_nms(
552                data["boxes"].float(),
553                scores,
554                torch.zeros_like(data["boxes"][:, 0]),  # categories
555                iou_threshold=crop_nms_thresh,
556            )
557            data.filter(keep_by_nms)
558
559        data.to_numpy()
560        masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode)
561        return masks

Generates an instance segmentation without prompts, using a point grid.

This class implements the same logic as https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py. It decouples the computationally expensive step of generating masks from the cheap post-processing that filters them, to enable grid search and interactive changes to the post-processing.

Use this class as follows:

amg = AutomaticMaskGenerator(predictor)
amg.initialize(image)  # Initialize the masks, this takes care of all expensive computations.
masks = amg.generate(pred_iou_thresh=0.8)  # Generate the masks. This is fast and enables testing parameters
Arguments:
  • predictor: The segment anything predictor.
  • points_per_side: The number of points to be sampled along one side of the image. If None, point_grids must provide explicit point sampling.
  • points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory.
  • crop_n_layers: If >0, the mask prediction will be run again on crops of the image.
  • crop_overlap_ratio: Sets the degree to which crops overlap.
  • crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops.
  • point_grids: A list of explicit grids of points used for sampling masks. Normalized to [0, 1] with respect to the image coordinate system.
  • stability_score_offset: The amount to shift the cutoff when calculating the stability score.
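The grid density and batch size trade off runtime, GPU memory, and how many small objects are hit by a point prompt. A hedged sketch of a denser configuration; all concrete values are illustrative and `util.get_sam_model` is assumed from `micro_sam.util`:

```python
import numpy as np

from micro_sam import util
from micro_sam.instance_segmentation import AutomaticMaskGenerator

predictor = util.get_sam_model(model_type="vit_b")  # assumed helper; model type is a placeholder

# Denser point grid for small, crowded objects; smaller batches to limit GPU memory.
amg = AutomaticMaskGenerator(
    predictor,
    points_per_side=64,    # 64 x 64 grid of point prompts
    points_per_batch=32,   # prompts processed per forward pass
)

image = np.random.rand(768, 768).astype("float32")  # stand-in for real image data
amg.initialize(image, verbose=True)
masks = amg.generate(pred_iou_thresh=0.85, stability_score_thresh=0.9)
```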
AutomaticMaskGenerator( predictor: segment_anything.predictor.SamPredictor, points_per_side: Optional[int] = 32, points_per_batch: Optional[int] = None, crop_n_layers: int = 0, crop_overlap_ratio: float = 0.3413333333333333, crop_n_points_downscale_factor: int = 1, point_grids: Optional[List[numpy.ndarray]] = None, stability_score_offset: float = 1.0)
364    def __init__(
365        self,
366        predictor: SamPredictor,
367        points_per_side: Optional[int] = 32,
368        points_per_batch: Optional[int] = None,
369        crop_n_layers: int = 0,
370        crop_overlap_ratio: float = 512 / 1500,
371        crop_n_points_downscale_factor: int = 1,
372        point_grids: Optional[List[np.ndarray]] = None,
373        stability_score_offset: float = 1.0,
374    ):
375        super().__init__()
376
377        if points_per_side is not None:
378            self.point_grids = amg_utils.build_all_layer_point_grids(
379                points_per_side,
380                crop_n_layers,
381                crop_n_points_downscale_factor,
382            )
383        elif point_grids is not None:
384            self.point_grids = point_grids
385        else:
386            raise ValueError("Can't have both points_per_side and point_grid be None or not None.")
387
388        self._predictor = predictor
389        self._points_per_side = points_per_side
390
391        # we set the points per batch to 16 for mps for performance reasons
392        # and otherwise keep them at the default of 64
393        if points_per_batch is None:
394            points_per_batch = 16 if str(predictor.device) == "mps" else 64
395        self._points_per_batch = points_per_batch
396
397        self._crop_n_layers = crop_n_layers
398        self._crop_overlap_ratio = crop_overlap_ratio
399        self._crop_n_points_downscale_factor = crop_n_points_downscale_factor
400        self._stability_score_offset = stability_score_offset
@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, verbose: bool = False, pbar_init: Optional[callable] = None, pbar_update: Optional[callable] = None) -> None:
449    @torch.no_grad()
450    def initialize(
451        self,
452        image: np.ndarray,
453        image_embeddings: Optional[util.ImageEmbeddings] = None,
454        i: Optional[int] = None,
455        verbose: bool = False,
456        pbar_init: Optional[callable] = None,
457        pbar_update: Optional[callable] = None,
458    ) -> None:
459        """Initialize image embeddings and masks for an image.
460
461        Args:
462            image: The input image, volume or timeseries.
463            image_embeddings: Optional precomputed image embeddings.
464                See `util.precompute_image_embeddings` for details.
465            i: Index for the image data. Required if `image` has three spatial dimensions
466                or a time dimension and two spatial dimensions.
467            verbose: Whether to print computation progress.
468            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
469                Can be used together with pbar_update to handle the napari progress bar in another thread.
470                This enables using this function within a threadworker.
471            pbar_update: Callback to update an external progress bar.
472        """
473        original_size = image.shape[:2]
474        self._original_size = original_size
475
476        crop_boxes, layer_idxs = amg_utils.generate_crop_boxes(
477            original_size, self._crop_n_layers, self._crop_overlap_ratio
478        )
479
480        # We can set fixed image embeddings if we only have a single crop box (the default setting).
481        # Otherwise we have to recompute the embeddings for each crop and can't precompute.
482        if len(crop_boxes) == 1:
483            if image_embeddings is None:
484                image_embeddings = util.precompute_image_embeddings(self._predictor, image)
485            util.set_precomputed(self._predictor, image_embeddings, i=i)
486            precomputed_embeddings = True
487        else:
488            precomputed_embeddings = False
489
490        # we need to cast to the image representation that is compatible with SAM
491        image = util._to_image(image)
492
493        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
494
495        crop_list = []
496        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
497            crop_data = self._process_crop(
498                image, crop_box, layer_idx,
499                precomputed_embeddings=precomputed_embeddings,
500                pbar_init=pbar_init, pbar_update=pbar_update,
501            )
502            crop_list.append(crop_data)
503        pbar_close()
504
505        self._is_initialized = True
506        self._crop_list = crop_list
507        self._crop_boxes = crop_boxes

Initialize image embeddings and masks for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • verbose: Whether to print computation progress.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle the napari progress bar in another thread; this enables using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
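For volumes or timeseries the embeddings can be precomputed once and the index `i` used to select the embeddings of the current slice. A sketch of slice-wise segmentation; the `ndim` argument of `util.precompute_image_embeddings` is an assumption, so check the util documentation for its exact signature:

```python
import numpy as np

from micro_sam import util
from micro_sam.instance_segmentation import AutomaticMaskGenerator, mask_data_to_segmentation

volume = np.random.rand(5, 512, 512).astype("float32")  # stand-in z-stack
predictor = util.get_sam_model(model_type="vit_b")

# Precompute the embeddings for all slices once; 'ndim' is an assumption about the signature.
image_embeddings = util.precompute_image_embeddings(predictor, volume, ndim=3)

amg = AutomaticMaskGenerator(predictor)
segmentation = np.zeros(volume.shape, dtype="uint32")
for i in range(volume.shape[0]):
    # Pass the 2D slice together with the slice index into the precomputed embeddings.
    amg.initialize(volume[i], image_embeddings=image_embeddings, i=i)
    masks = amg.generate()
    segmentation[i] = mask_data_to_segmentation(masks, with_background=True)
```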
@torch.no_grad()
def generate( self, pred_iou_thresh: float = 0.88, stability_score_thresh: float = 0.95, box_nms_thresh: float = 0.7, crop_nms_thresh: float = 0.7, min_mask_region_area: int = 0, output_mode: str = 'binary_mask') -> List[Dict[str, Any]]:
509    @torch.no_grad()
510    def generate(
511        self,
512        pred_iou_thresh: float = 0.88,
513        stability_score_thresh: float = 0.95,
514        box_nms_thresh: float = 0.7,
515        crop_nms_thresh: float = 0.7,
516        min_mask_region_area: int = 0,
517        output_mode: str = "binary_mask",
518    ) -> List[Dict[str, Any]]:
519        """Generate instance segmentation for the currently initialized image.
520
521        Args:
522            pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
523            stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
524                under changes to the cutoff used to binarize the model prediction.
525            box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
526            crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
527            min_mask_region_area: Minimal size for the predicted masks.
528            output_mode: The form masks are returned in.
529
530        Returns:
531            The instance segmentation masks.
532        """
533        if not self.is_initialized:
534            raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.")
535
536        data = amg_utils.MaskData()
537        for data_, crop_box in zip(self.crop_list, self.crop_boxes):
538            crop_data = self._postprocess_batch(
539                data=deepcopy(data_),
540                crop_box=crop_box, original_size=self.original_size,
541                pred_iou_thresh=pred_iou_thresh,
542                stability_score_thresh=stability_score_thresh,
543                box_nms_thresh=box_nms_thresh
544            )
545            data.cat(crop_data)
546
547        if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0:
548            # Prefer masks from smaller crops
549            scores = 1 / box_area(data["crop_boxes"])
550            scores = scores.to(data["boxes"].device)
551            keep_by_nms = batched_nms(
552                data["boxes"].float(),
553                scores,
554                torch.zeros_like(data["boxes"][:, 0]),  # categories
555                iou_threshold=crop_nms_thresh,
556            )
557            data.filter(keep_by_nms)
558
559        data.to_numpy()
560        masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode)
561        return masks

Generate instance segmentation for the currently initialized image.

Arguments:
  • pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
  • stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask under changes to the cutoff used to binarize the model prediction.
  • box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
  • crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
  • min_mask_region_area: Minimal size for the predicted masks.
  • output_mode: The form masks are returned in.
Returns:

The instance segmentation masks.

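Because `initialize` caches the expensive mask computation, `generate` can be called repeatedly on the same state to search for good thresholds. A minimal sketch; the threshold values and the stand-in metric are illustrative:

```python
import numpy as np

from micro_sam import util
from micro_sam.instance_segmentation import AutomaticMaskGenerator

image = np.random.rand(512, 512).astype("float32")  # stand-in image
amg = AutomaticMaskGenerator(util.get_sam_model(model_type="vit_b"))
amg.initialize(image)  # run the expensive part once

results = {}
for pred_iou_thresh in (0.75, 0.88, 0.95):
    for stability_score_thresh in (0.9, 0.95):
        masks = amg.generate(
            pred_iou_thresh=pred_iou_thresh,
            stability_score_thresh=stability_score_thresh,
        )
        # Evaluate 'masks' against ground truth here to pick the best setting;
        # the mask count is only used as a stand-in metric.
        results[(pred_iou_thresh, stability_score_thresh)] = len(masks)
```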
class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
589class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
590    """Generates an instance segmentation without prompts, using a point grid.
591
592    Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings.
593
594    Args:
595        predictor: The segment anything predictor.
596        points_per_side: The number of points to be sampled along one side of the image.
597            If None, `point_grids` must provide explicit point sampling.
598        points_per_batch: The number of points run simultaneously by the model.
599            Higher numbers may be faster but use more GPU memory.
600        point_grids: A list of explicit grids of points used for sampling masks.
601            Normalized to [0, 1] with respect to the image coordinate system.
602        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
603    """
604
605    # We only expose the arguments that make sense for the tiled mask generator.
606    # Anything related to crops doesn't make sense, because we re-use that functionality
607    # for tiling, so these parameters wouldn't have any effect.
608    def __init__(
609        self,
610        predictor: SamPredictor,
611        points_per_side: Optional[int] = 32,
612        points_per_batch: int = 64,
613        point_grids: Optional[List[np.ndarray]] = None,
614        stability_score_offset: float = 1.0,
615    ) -> None:
616        super().__init__(
617            predictor=predictor,
618            points_per_side=points_per_side,
619            points_per_batch=points_per_batch,
620            point_grids=point_grids,
621            stability_score_offset=stability_score_offset,
622        )
623
624    @torch.no_grad()
625    def initialize(
626        self,
627        image: np.ndarray,
628        image_embeddings: Optional[util.ImageEmbeddings] = None,
629        i: Optional[int] = None,
630        tile_shape: Optional[Tuple[int, int]] = None,
631        halo: Optional[Tuple[int, int]] = None,
632        verbose: bool = False,
633        pbar_init: Optional[callable] = None,
634        pbar_update: Optional[callable] = None,
635    ) -> None:
636        """Initialize image embeddings and masks for an image.
637
638        Args:
639            image: The input image, volume or timeseries.
640            image_embeddings: Optional precomputed image embeddings.
641                See `util.precompute_image_embeddings` for details.
642            i: Index for the image data. Required if `image` has three spatial dimensions
643                or a time dimension and two spatial dimensions.
644            tile_shape: The tile shape for embedding prediction.
645            halo: The overlap between tiles.
646            verbose: Whether to print computation progress.
647            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
648                Can be used together with pbar_update to handle the napari progress bar in another thread.
649                This enables using this function within a threadworker.
650            pbar_update: Callback to update an external progress bar.
651        """
652        original_size = image.shape[:2]
653        self._original_size = original_size
654
655        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
656            self._predictor, image, image_embeddings, tile_shape, halo
657        )
658
659        tiling = blocking([0, 0], original_size, tile_shape)
660        n_tiles = tiling.numberOfBlocks
661
662        # The crop box is always the full local tile.
663        tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)]
664        crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles]
665
666        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
667        pbar_init(n_tiles, "Compute masks for tile")
668
669        # We need to cast to the image representation that is compatible with SAM.
670        image = util._to_image(image)
671
672        mask_data = []
673        for tile_id in range(n_tiles):
674            # set the pre-computed embeddings for this tile
675            features = image_embeddings["features"][tile_id]
676            tile_embeddings = {
677                "features": features,
678                "input_size": features.attrs["input_size"],
679                "original_size": features.attrs["original_size"],
680            }
681            util.set_precomputed(self._predictor, tile_embeddings, i)
682
683            # compute the mask data for this tile and append it
684            this_mask_data = self._process_crop(
685                image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True
686            )
687            mask_data.append(this_mask_data)
688            pbar_update(1)
689        pbar_close()
690
691        # set the initialized data
692        self._is_initialized = True
693        self._crop_list = mask_data
694        self._crop_boxes = crop_boxes

Generates an instance segmentation without prompts, using a point grid.

Implements the same functionality as AutomaticMaskGenerator but for tiled embeddings.

Arguments:
  • predictor: The segment anything predictor.
  • points_per_side: The number of points to be sampled along one side of the image. If None, point_grids must provide explicit point sampling.
  • points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory.
  • point_grids: A list of explicit grids of points used for sampling masks. Normalized to [0, 1] with respect to the image coordinate system.
  • stability_score_offset: The amount to shift the cutoff when calculating the stability score.
TiledAutomaticMaskGenerator( predictor: segment_anything.predictor.SamPredictor, points_per_side: Optional[int] = 32, points_per_batch: int = 64, point_grids: Optional[List[numpy.ndarray]] = None, stability_score_offset: float = 1.0)
608    def __init__(
609        self,
610        predictor: SamPredictor,
611        points_per_side: Optional[int] = 32,
612        points_per_batch: int = 64,
613        point_grids: Optional[List[np.ndarray]] = None,
614        stability_score_offset: float = 1.0,
615    ) -> None:
616        super().__init__(
617            predictor=predictor,
618            points_per_side=points_per_side,
619            points_per_batch=points_per_batch,
620            point_grids=point_grids,
621            stability_score_offset=stability_score_offset,
622        )
@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = False, pbar_init: Optional[callable] = None, pbar_update: Optional[callable] = None) -> None:
624    @torch.no_grad()
625    def initialize(
626        self,
627        image: np.ndarray,
628        image_embeddings: Optional[util.ImageEmbeddings] = None,
629        i: Optional[int] = None,
630        tile_shape: Optional[Tuple[int, int]] = None,
631        halo: Optional[Tuple[int, int]] = None,
632        verbose: bool = False,
633        pbar_init: Optional[callable] = None,
634        pbar_update: Optional[callable] = None,
635    ) -> None:
636        """Initialize image embeddings and masks for an image.
637
638        Args:
639            image: The input image, volume or timeseries.
640            image_embeddings: Optional precomputed image embeddings.
641                See `util.precompute_image_embeddings` for details.
642            i: Index for the image data. Required if `image` has three spatial dimensions
643                or a time dimension and two spatial dimensions.
644            tile_shape: The tile shape for embedding prediction.
645            halo: The overlap between tiles.
646            verbose: Whether to print computation progress.
647            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
648                Can be used together with pbar_update to handle the napari progress bar in another thread.
649                This enables using this function within a threadworker.
650            pbar_update: Callback to update an external progress bar.
651        """
652        original_size = image.shape[:2]
653        self._original_size = original_size
654
655        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
656            self._predictor, image, image_embeddings, tile_shape, halo
657        )
658
659        tiling = blocking([0, 0], original_size, tile_shape)
660        n_tiles = tiling.numberOfBlocks
661
662        # The crop box is always the full local tile.
663        tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)]
664        crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles]
665
666        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
667        pbar_init(n_tiles, "Compute masks for tile")
668
669        # We need to cast to the image representation that is compatible with SAM.
670        image = util._to_image(image)
671
672        mask_data = []
673        for tile_id in range(n_tiles):
674            # set the pre-computed embeddings for this tile
675            features = image_embeddings["features"][tile_id]
676            tile_embeddings = {
677                "features": features,
678                "input_size": features.attrs["input_size"],
679                "original_size": features.attrs["original_size"],
680            }
681            util.set_precomputed(self._predictor, tile_embeddings, i)
682
683            # compute the mask data for this tile and append it
684            this_mask_data = self._process_crop(
685                image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True
686            )
687            mask_data.append(this_mask_data)
688            pbar_update(1)
689        pbar_close()
690
691        # set the initialized data
692        self._is_initialized = True
693        self._crop_list = mask_data
694        self._crop_boxes = crop_boxes

Initialize image embeddings and masks for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • tile_shape: The tile shape for embedding prediction.
  • halo: The overlap between tiles.
  • verbose: Whether to print computation progress.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle the napari progress bar in another thread; this enables using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
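A sketch of tiled segmentation of a large image; the tile shape and halo are illustrative and should be chosen so that objects fit comfortably into a tile plus its halo:

```python
import numpy as np

from micro_sam import util
from micro_sam.instance_segmentation import TiledAutomaticMaskGenerator, mask_data_to_segmentation

large_image = np.random.rand(4096, 4096).astype("float32")  # stand-in for a large image
predictor = util.get_sam_model(model_type="vit_b")

amg = TiledAutomaticMaskGenerator(predictor)
amg.initialize(large_image, tile_shape=(1024, 1024), halo=(256, 256), verbose=True)
masks = amg.generate(pred_iou_thresh=0.88)
segmentation = mask_data_to_segmentation(masks, with_background=True, min_object_size=50)
```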
class DecoderAdapter(torch.nn.modules.module.Module):
702class DecoderAdapter(torch.nn.Module):
703    """Adapter to contain the UNETR decoder in a single module.
704
705    To apply the decoder on top of pre-computed embeddings for
706    the segmentation functionality.
707    See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py
708    """
709    def __init__(self, unetr):
710        super().__init__()
711
712        self.base = unetr.base
713        self.out_conv = unetr.out_conv
714        self.deconv_out = unetr.deconv_out
715        self.decoder_head = unetr.decoder_head
716        self.final_activation = unetr.final_activation
717        self.postprocess_masks = unetr.postprocess_masks
718
719        self.decoder = unetr.decoder
720        self.deconv1 = unetr.deconv1
721        self.deconv2 = unetr.deconv2
722        self.deconv3 = unetr.deconv3
723        self.deconv4 = unetr.deconv4
724
725    def forward(self, input_, input_shape, original_shape):
726        z12 = input_
727
728        z9 = self.deconv1(z12)
729        z6 = self.deconv2(z9)
730        z3 = self.deconv3(z6)
731        z0 = self.deconv4(z3)
732
733        updated_from_encoder = [z9, z6, z3]
734
735        x = self.base(z12)
736        x = self.decoder(x, encoder_inputs=updated_from_encoder)
737        x = self.deconv_out(x)
738
739        x = torch.cat([x, z0], dim=1)
740        x = self.decoder_head(x)
741
742        x = self.out_conv(x)
743        if self.final_activation is not None:
744            x = self.final_activation(x)
745
746        x = self.postprocess_masks(x, input_shape, original_shape)
747        return x

Adapter to contain the UNETR decoder in a single module.

To apply the decoder on top of pre-computed embeddings for the segmentation functionality. See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py

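In practice the adapter is applied to the predictor's precomputed features, as `InstanceSegmentationWithDecoder.initialize` does further below. A hedged sketch of a direct call; it assumes `predictor` and `decoder` already exist and that an image has been set on the predictor, so that `predictor.features`, `predictor.input_size` and `predictor.original_size` are populated:

```python
import torch

# 'predictor' is a SamPredictor on which set_image / set_precomputed has been called,
# and 'decoder' is a DecoderAdapter, e.g. from get_decoder or get_predictor_and_decoder below.
with torch.no_grad():
    output = decoder(
        predictor.features,             # image embeddings computed by the SAM encoder
        tuple(predictor.input_size),    # shape after SAM's internal resizing
        tuple(predictor.original_size), # original image shape
    )
# Three output channels: foreground, center distances, boundary distances.
foreground, center_distances, boundary_distances = output.cpu().numpy().squeeze(0)
```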
DecoderAdapter(unetr)
709    def __init__(self, unetr):
710        super().__init__()
711
712        self.base = unetr.base
713        self.out_conv = unetr.out_conv
714        self.deconv_out = unetr.deconv_out
715        self.decoder_head = unetr.decoder_head
716        self.final_activation = unetr.final_activation
717        self.postprocess_masks = unetr.postprocess_masks
718
719        self.decoder = unetr.decoder
720        self.deconv1 = unetr.deconv1
721        self.deconv2 = unetr.deconv2
722        self.deconv3 = unetr.deconv3
723        self.deconv4 = unetr.deconv4

Initialize internal Module state, shared by both nn.Module and ScriptModule.

base
out_conv
deconv_out
decoder_head
final_activation
postprocess_masks
decoder
deconv1
deconv2
deconv3
deconv4
def forward(self, input_, input_shape, original_shape):
725    def forward(self, input_, input_shape, original_shape):
726        z12 = input_
727
728        z9 = self.deconv1(z12)
729        z6 = self.deconv2(z9)
730        z3 = self.deconv3(z6)
731        z0 = self.deconv4(z3)
732
733        updated_from_encoder = [z9, z6, z3]
734
735        x = self.base(z12)
736        x = self.decoder(x, encoder_inputs=updated_from_encoder)
737        x = self.deconv_out(x)
738
739        x = torch.cat([x, z0], dim=1)
740        x = self.decoder_head(x)
741
742        x = self.out_conv(x)
743        if self.final_activation is not None:
744            x = self.final_activation(x)
745
746        x = self.postprocess_masks(x, input_shape, original_shape)
747        return x

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Inherited Members
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
def get_unetr( image_encoder: torch.nn.modules.module.Module, decoder_state: Optional[collections.OrderedDict[str, torch.Tensor]] = None, device: Union[str, torch.device, NoneType] = None) -> torch.nn.modules.module.Module:
750def get_unetr(
751    image_encoder: torch.nn.Module,
752    decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None,
753    device: Optional[Union[str, torch.device]] = None,
754) -> torch.nn.Module:
755    """Get UNETR model for automatic instance segmentation.
756
757    Args:
758        image_encoder: The image encoder of the SAM model.
759            This is used as encoder by the UNETR too.
760        decoder_state: Optional decoder state to initialize the weights
761            of the UNETR decoder.
762        device: The device.
763    Returns:
764        The UNETR model.
765    """
766    device = util.get_device(device)
767
768    unetr = UNETR(
769        backbone="sam",
770        encoder=image_encoder,
771        out_channels=3,
772        use_sam_stats=True,
773        final_activation="Sigmoid",
774        use_skip_connection=False,
775        resize_input=True,
776    )
777    if decoder_state is not None:
778        unetr_state_dict = unetr.state_dict()
779        for k, v in unetr_state_dict.items():
780            if not k.startswith("encoder"):
781                unetr_state_dict[k] = decoder_state[k]
782        unetr.load_state_dict(unetr_state_dict)
783
784    unetr.to(device)
785    return unetr

Get UNETR model for automatic instance segmentation.

Arguments:
  • image_encoder: The image encoder of the SAM model. This is used as encoder by the UNETR too.
  • decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
  • device: The device.
Returns:

The UNETR model.

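A sketch of building the UNETR from a custom checkpoint; it mirrors what `get_predictor_and_decoder` below does and assumes the checkpoint stores the decoder weights under the `decoder_state` key (the checkpoint path is a placeholder):

```python
from micro_sam import util
from micro_sam.instance_segmentation import get_unetr

# Load the predictor and the full checkpoint state (the path is a placeholder).
predictor, state = util.get_sam_model(
    model_type="vit_b", checkpoint_path="finetuned_model.pt", return_state=True
)
unetr = get_unetr(predictor.model.image_encoder, decoder_state=state.get("decoder_state"))
```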
def get_decoder( image_encoder: torch.nn.modules.module.Module, decoder_state: collections.OrderedDict[str, torch.Tensor], device: Union[str, torch.device, NoneType] = None) -> DecoderAdapter:
788def get_decoder(
789    image_encoder: torch.nn.Module,
790    decoder_state: OrderedDict[str, torch.Tensor],
791    device: Optional[Union[str, torch.device]] = None,
792) -> DecoderAdapter:
793    """Get decoder to predict outputs for automatic instance segmentation
794
795    Args:
796        image_encoder: The image encoder of the SAM model.
797        decoder_state: State to initialize the weights of the UNETR decoder.
798        device: The device.
799    Returns:
800        The decoder for instance segmentation.
801    """
802    unetr = get_unetr(image_encoder, decoder_state, device)
803    return DecoderAdapter(unetr)

Get decoder to predict outputs for automatic instance segmentation

Arguments:
  • image_encoder: The image encoder of the SAM model.
  • decoder_state: State to initialize the weights of the UNETR decoder.
  • device: The device.
Returns:

The decoder for instance segmentation.

def get_predictor_and_decoder( model_type: str, checkpoint_path: Union[str, os.PathLike], device: Union[str, torch.device, NoneType] = None, peft_kwargs: Optional[Dict] = None) -> Tuple[segment_anything.predictor.SamPredictor, DecoderAdapter]:
806def get_predictor_and_decoder(
807    model_type: str,
808    checkpoint_path: Union[str, os.PathLike],
809    device: Optional[Union[str, torch.device]] = None,
810    peft_kwargs: Optional[Dict] = None,
811) -> Tuple[SamPredictor, DecoderAdapter]:
812    """Load the SAM model (predictor) and instance segmentation decoder.
813
814    This requires a checkpoint that contains the state for both predictor
815    and decoder.
816
817    Args:
818        model_type: The type of the image encoder used in the SAM model.
819        checkpoint_path: Path to the checkpoint from which to load the data.
820        device: The device.
821        peft_kwargs: Keyword arguments for parameter-efficient fine-tuning, e.g. the rank for low rank adaptation of the attention layers.
822
823    Returns:
824        The SAM predictor.
825        The decoder for instance segmentation.
826    """
827    device = util.get_device(device)
828    predictor, state = util.get_sam_model(
829        model_type=model_type,
830        checkpoint_path=checkpoint_path,
831        device=device,
832        return_state=True,
833        peft_kwargs=peft_kwargs,
834    )
835    if "decoder_state" not in state:
836        raise ValueError(
837            f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state"
838        )
839    decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device)
840    return predictor, decoder

Load the SAM model (predictor) and instance segmentation decoder.

This requires a checkpoint that contains the state for both predictor and decoder.

Arguments:
  • model_type: The type of the image encoder used in the SAM model.
  • checkpoint_path: Path to the checkpoint from which to load the data.
  • device: The device.
  • peft_kwargs: Keyword arguments for parameter-efficient fine-tuning, e.g. the rank for low rank adaptation of the attention layers.
Returns:

The SAM predictor. The decoder for instance segmentation.

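Typically the returned pair is passed straight to `InstanceSegmentationWithDecoder`, documented below. A sketch; the model type and checkpoint path are placeholders:

```python
from micro_sam.instance_segmentation import (
    InstanceSegmentationWithDecoder, get_predictor_and_decoder,
)

predictor, decoder = get_predictor_and_decoder(
    model_type="vit_b",                    # must match the encoder the checkpoint was trained with
    checkpoint_path="finetuned_model.pt",  # placeholder path; must contain a decoder state
)
segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
```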
class InstanceSegmentationWithDecoder:
 843class InstanceSegmentationWithDecoder:
 844    """Generates an instance segmentation without prompts, using a decoder.
 845
 846    Implements the same interface as `AutomaticMaskGenerator`.
 847
 848    Use this class as follows:
 849    ```python
 850    segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
 851    segmenter.initialize(image)   # Predict the image embeddings and decoder outputs.
 852    masks = segmenter.generate(center_distance_threshold=0.75)  # Generate the instance segmentation.
 853    ```
 854
 855    Args:
 856        predictor: The segment anything predictor.
 857        decoder: The decoder to predict intermediate representations
 858            for instance segmentation.
 859    """
 860    def __init__(
 861        self,
 862        predictor: SamPredictor,
 863        decoder: torch.nn.Module,
 864    ) -> None:
 865        self._predictor = predictor
 866        self._decoder = decoder
 867
 868        # The decoder outputs.
 869        self._foreground = None
 870        self._center_distances = None
 871        self._boundary_distances = None
 872
 873        self._is_initialized = False
 874
 875    @property
 876    def is_initialized(self):
 877        """Whether the mask generator has already been initialized.
 878        """
 879        return self._is_initialized
 880
 881    @torch.no_grad()
 882    def initialize(
 883        self,
 884        image: np.ndarray,
 885        image_embeddings: Optional[util.ImageEmbeddings] = None,
 886        i: Optional[int] = None,
 887        verbose: bool = False,
 888        pbar_init: Optional[callable] = None,
 889        pbar_update: Optional[callable] = None,
 890    ) -> None:
 891        """Initialize image embeddings and decoder predictions for an image.
 892
 893        Args:
 894            image: The input image, volume or timeseries.
 895            image_embeddings: Optional precomputed image embeddings.
 896                See `util.precompute_image_embeddings` for details.
 897            i: Index for the image data. Required if `image` has three spatial dimensions
 898                or a time dimension and two spatial dimensions.
 899            verbose: Whether to be verbose.
 900            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 901                Can be used together with pbar_update to handle the napari progress bar in another thread.
 902                This enables using this function within a threadworker.
 903            pbar_update: Callback to update an external progress bar.
 904        """
 905        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
 906        pbar_init(1, "Initialize instance segmentation with decoder")
 907
 908        if image_embeddings is None:
 909            image_embeddings = util.precompute_image_embeddings(self._predictor, image)
 910
 911        # Get the image embeddings from the predictor.
 912        self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i)
 913        embeddings = self._predictor.features
 914        input_shape = tuple(self._predictor.input_size)
 915        original_shape = tuple(self._predictor.original_size)
 916
 917        # Run prediction with the UNETR decoder.
 918        output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
 919        assert output.shape[0] == 3, f"{output.shape}"
 920        pbar_update(1)
 921        pbar_close()
 922
 923        # Set the state.
 924        self._foreground = output[0]
 925        self._center_distances = output[1]
 926        self._boundary_distances = output[2]
 927        self._is_initialized = True
 928
 929    def _to_masks(self, segmentation, output_mode):
 930        if output_mode != "binary_mask":
 931            raise NotImplementedError
 932
 933        props = regionprops(segmentation)
 934        ndim = segmentation.ndim
 935        assert ndim in (2, 3)
 936
 937        shape = segmentation.shape
 938        if ndim == 2:
 939            crop_box = [0, shape[1], 0, shape[0]]
 940        else:
 941            crop_box = [0, shape[2], 0, shape[1], 0, shape[0]]
 942
 943        # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h]
 944        def to_bbox_2d(bbox):
 945            y0, x0 = bbox[0], bbox[1]
 946            w = bbox[3] - x0
 947            h = bbox[2] - y0
 948            return [x0, w, y0, h]
 949
 950        def to_bbox_3d(bbox):
 951            z0, y0, x0 = bbox[0], bbox[1], bbox[2]
 952            w = bbox[5] - x0
 953            h = bbox[4] - y0
 954            d = bbox[3] - z0
 955            return [x0, w, y0, h, z0, d]
 956
 957        to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d
 958        masks = [
 959            {
 960                "segmentation": segmentation == prop.label,
 961                "area": prop.area,
 962                "bbox": to_bbox(prop.bbox),
 963                "crop_box": crop_box,
 964                "seg_id": prop.label,
 965            } for prop in props
 966        ]
 967        return masks
 968
 969    def generate(
 970        self,
 971        center_distance_threshold: float = 0.5,
 972        boundary_distance_threshold: float = 0.5,
 973        foreground_threshold: float = 0.5,
 974        foreground_smoothing: float = 1.0,
 975        distance_smoothing: float = 1.6,
 976        min_size: int = 0,
 977        output_mode: Optional[str] = "binary_mask",
 978    ) -> List[Dict[str, Any]]:
 979        """Generate instance segmentation for the currently initialized image.
 980
 981        Args:
 982            center_distance_threshold: Center distance predictions below this value will be
 983                used to find seeds (intersected with thresholded boundary distance predictions).
 984            boundary_distance_threshold: Boundary distance predictions below this value will be
 985                used to find seeds (intersected with thresholded center distance predictions).
 986            foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid
 987                checkerboard artifacts in the prediction.
 988            foreground_threshold: Foreground predictions above this value will be used as foreground mask.
 989            distance_smoothing: Sigma value for smoothing the distance predictions.
 990            min_size: Minimal object size in the segmentation result.
 991            output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
 992
 993        Returns:
 994            The instance segmentation masks.
 995        """
 996        if not self.is_initialized:
 997            raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.")
 998
 999        if foreground_smoothing > 0:
1000            foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing)
1001        else:
1002            foreground = self._foreground
1003        # Further optimization: parallel implementation using elf.parallel functionality.
1004        # (Make sure to expose n_threads to avoid over-subscription in case of outer parallelization)
1005        segmentation = watershed_from_center_and_boundary_distances(
1006            self._center_distances, self._boundary_distances, foreground,
1007            center_distance_threshold=center_distance_threshold,
1008            boundary_distance_threshold=boundary_distance_threshold,
1009            foreground_threshold=foreground_threshold,
1010            distance_smoothing=distance_smoothing,
1011            min_size=min_size,
1012        )
1013        if output_mode is not None:
1014            segmentation = self._to_masks(segmentation, output_mode)
1015        return segmentation
1016
1017    def get_state(self) -> Dict[str, Any]:
1018        """Get the initialized state of the instance segmenter.
1019
1020        Returns:
1021            Instance segmentation state.
1022        """
1023        if not self.is_initialized:
1024            raise RuntimeError("The state has not been computed yet. Call initialize first.")
1025
1026        return {
1027            "foreground": self._foreground,
1028            "center_distances": self._center_distances,
1029            "boundary_distances": self._boundary_distances,
1030        }
1031
1032    def set_state(self, state: Dict[str, Any]) -> None:
1033        """Set the state of the instance segmenter.
1034
1035        Args:
1036            state: The instance segmentation state
1037        """
1038        self._foreground = state["foreground"]
1039        self._center_distances = state["center_distances"]
1040        self._boundary_distances = state["boundary_distances"]
1041        self._is_initialized = True
1042
1043    def clear_state(self):
1044        """Clear the state of the instance segmenter.
1045        """
1046        self._foreground = None
1047        self._center_distances = None
1048        self._boundary_distances = None
1049        self._is_initialized = False

Generates an instance segmentation without prompts, using a decoder.

Implements the same interface as AutomaticMaskGenerator.

Use this class as follows:

segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image)   # Predict the image embeddings and decoder outputs.
masks = segmenter.generate(center_distance_threshold=0.75)  # Generate the instance segmentation.
Arguments:
  • predictor: The segment anything predictor.
  • decoder: The decoder to predict intermediate representations for instance segmentation.
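
A slightly fuller usage sketch than the snippet above. The predictor and decoder are left as hypothetical placeholders (how to obtain them depends on your model and checkpoint setup); mask_data_to_segmentation from this module is used to turn the returned mask dictionaries into a label image:

import imageio.v3 as imageio

from micro_sam.instance_segmentation import InstanceSegmentationWithDecoder, mask_data_to_segmentation

# Hypothetical placeholders: a SamPredictor and the UNETR decoder trained for its image encoder.
predictor = ...
decoder = ...

image = imageio.imread("cells.tif")  # placeholder path to a 2d image

segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image)                                  # predict image embeddings and decoder outputs
masks = segmenter.generate(center_distance_threshold=0.75)   # list of mask dictionaries

# Convert the mask dictionaries into an instance segmentation (label image).
segmentation = mask_data_to_segmentation(masks, with_background=True, min_object_size=25)
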
InstanceSegmentationWithDecoder( predictor: segment_anything.predictor.SamPredictor, decoder: torch.nn.modules.module.Module)
860    def __init__(
861        self,
862        predictor: SamPredictor,
863        decoder: torch.nn.Module,
864    ) -> None:
865        self._predictor = predictor
866        self._decoder = decoder
867
868        # The decoder outputs.
869        self._foreground = None
870        self._center_distances = None
871        self._boundary_distances = None
872
873        self._is_initialized = False
is_initialized
875    @property
876    def is_initialized(self):
877        """Whether the mask generator has already been initialized.
878        """
879        return self._is_initialized

Whether the mask generator has already been initialized.

@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, verbose: bool = False, pbar_init: Optional[callable] = None, pbar_update: Optional[callable] = None) -> None:
881    @torch.no_grad()
882    def initialize(
883        self,
884        image: np.ndarray,
885        image_embeddings: Optional[util.ImageEmbeddings] = None,
886        i: Optional[int] = None,
887        verbose: bool = False,
888        pbar_init: Optional[callable] = None,
889        pbar_update: Optional[callable] = None,
890    ) -> None:
891        """Initialize image embeddings and decoder predictions for an image.
892
893        Args:
894            image: The input image, volume or timeseries.
895            image_embeddings: Optional precomputed image embeddings.
896                See `util.precompute_image_embeddings` for details.
897            i: Index for the image data. Required if `image` has three spatial dimensions
898                or a time dimension and two spatial dimensions.
899            verbose: Whether to be verbose.
900            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
901                Can be used together with pbar_update to handle napari progress bar in other thread.
 902                This enables using this function within a thread worker.
903            pbar_update: Callback to update an external progress bar.
904        """
905        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
906        pbar_init(1, "Initialize instance segmentation with decoder")
907
908        if image_embeddings is None:
909            image_embeddings = util.precompute_image_embeddings(self._predictor, image)
910
911        # Get the image embeddings from the predictor.
912        self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i)
913        embeddings = self._predictor.features
914        input_shape = tuple(self._predictor.input_size)
915        original_shape = tuple(self._predictor.original_size)
916
917        # Run prediction with the UNETR decoder.
918        output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
919        assert output.shape[0] == 3, f"{output.shape}"
920        pbar_update(1)
921        pbar_close()
922
923        # Set the state.
924        self._foreground = output[0]
925        self._center_distances = output[1]
926        self._boundary_distances = output[2]
927        self._is_initialized = True

Initialize image embeddings and decoder predictions for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • verbose: Whether to be verbose.
  • pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a thread worker.
  • pbar_update: Callback to update an external progress bar.
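
For volumes or timeseries the embeddings can be precomputed once and then reused per slice or timepoint via the i argument. A minimal sketch, assuming a timeseries of shape (T, Y, X) and reusing the predictor / segmenter placeholders from above; the save_path caching argument and per-slice handling of 3d inputs follow micro_sam's embedding precomputation, see util.precompute_image_embeddings for the exact options:

from micro_sam import util

timeseries = ...  # placeholder: numpy array of shape (T, Y, X)

# Precompute (and optionally cache) the embeddings for all timepoints once.
image_embeddings = util.precompute_image_embeddings(predictor, timeseries, save_path="embeddings.zarr")

segmentations = []
for t in range(timeseries.shape[0]):
    # Initialize the decoder predictions for timepoint t from the precomputed embeddings.
    segmenter.initialize(timeseries, image_embeddings=image_embeddings, i=t)
    segmentations.append(segmenter.generate(output_mode=None))  # label image per timepoint
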
def generate( self, center_distance_threshold: float = 0.5, boundary_distance_threshold: float = 0.5, foreground_threshold: float = 0.5, foreground_smoothing: float = 1.0, distance_smoothing: float = 1.6, min_size: int = 0, output_mode: Optional[str] = 'binary_mask') -> List[Dict[str, Any]]:
 969    def generate(
 970        self,
 971        center_distance_threshold: float = 0.5,
 972        boundary_distance_threshold: float = 0.5,
 973        foreground_threshold: float = 0.5,
 974        foreground_smoothing: float = 1.0,
 975        distance_smoothing: float = 1.6,
 976        min_size: int = 0,
 977        output_mode: Optional[str] = "binary_mask",
 978    ) -> List[Dict[str, Any]]:
 979        """Generate instance segmentation for the currently initialized image.
 980
 981        Args:
 982            center_distance_threshold: Center distance predictions below this value will be
 983                used to find seeds (intersected with thresholded boundary distance predictions).
 984            boundary_distance_threshold: Boundary distance predictions below this value will be
 985                used to find seeds (intersected with thresholded center distance predictions).
 986            foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid
 987                checkerboard artifacts in the prediction.
 988            foreground_threshold: Foreground predictions above this value will be used as foreground mask.
 989            distance_smoothing: Sigma value for smoothing the distance predictions.
 990            min_size: Minimal object size in the segmentation result.
 991            output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
 992
 993        Returns:
 994            The instance segmentation masks.
 995        """
 996        if not self.is_initialized:
 997            raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.")
 998
 999        if foreground_smoothing > 0:
1000            foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing)
1001        else:
1002            foreground = self._foreground
1003        # Further optimization: parallel implementation using elf.parallel functionality.
1004        # (Make sure to expose n_threads to avoid over-subscription in case of outer parallelization)
1005        segmentation = watershed_from_center_and_boundary_distances(
1006            self._center_distances, self._boundary_distances, foreground,
1007            center_distance_threshold=center_distance_threshold,
1008            boundary_distance_threshold=boundary_distance_threshold,
1009            foreground_threshold=foreground_threshold,
1010            distance_smoothing=distance_smoothing,
1011            min_size=min_size,
1012        )
1013        if output_mode is not None:
1014            segmentation = self._to_masks(segmentation, output_mode)
1015        return segmentation

Generate instance segmentation for the currently initialized image.

Arguments:
  • center_distance_threshold: Center distance predictions below this value will be used to find seeds (intersected with thresholded boundary distance predictions).
  • boundary_distance_threshold: Boundary distance predictions below this value will be used to find seeds (intersected with thresholded center distance predictions).
  • foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid checkerboard artifacts in the prediction.
  • foreground_threshold: Foreground predictions above this value will be used as foreground mask.
  • distance_smoothing: Sigma value for smoothing the distance predictions.
  • min_size: Minimal object size in the segmentation result.
  • output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
Returns:

The instance segmentation masks.
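
A small sketch of the two output modes, assuming an initialized segmenter as above: with the default output_mode="binary_mask" a list of SAM-style mask dictionaries is returned, while output_mode=None returns the watershed result directly as a label image:

# List of mask dictionaries with keys "segmentation", "area", "bbox", "crop_box" and "seg_id".
masks = segmenter.generate(
    center_distance_threshold=0.5,
    boundary_distance_threshold=0.5,
    min_size=50,
    output_mode="binary_mask",
)

# Skip the conversion and return the instance segmentation (label image) directly.
label_image = segmenter.generate(min_size=50, output_mode=None)
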

def get_state(self) -> Dict[str, Any]:
1017    def get_state(self) -> Dict[str, Any]:
1018        """Get the initialized state of the instance segmenter.
1019
1020        Returns:
1021            Instance segmentation state.
1022        """
1023        if not self.is_initialized:
1024            raise RuntimeError("The state has not been computed yet. Call initialize first.")
1025
1026        return {
1027            "foreground": self._foreground,
1028            "center_distances": self._center_distances,
1029            "boundary_distances": self._boundary_distances,
1030        }

Get the initialized state of the instance segmenter.

Returns:

Instance segmentation state.

def set_state(self, state: Dict[str, Any]) -> None:
1032    def set_state(self, state: Dict[str, Any]) -> None:
1033        """Set the state of the instance segmenter.
1034
1035        Args:
1036            state: The instance segmentation state
1037        """
1038        self._foreground = state["foreground"]
1039        self._center_distances = state["center_distances"]
1040        self._boundary_distances = state["boundary_distances"]
1041        self._is_initialized = True

Set the state of the instance segmenter.

Arguments:
  • state: The instance segmentation state
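
get_state and set_state can be used to cache the decoder predictions and restore them later, so that generate can be rerun (e.g. with different thresholds) without recomputing the image embeddings. A minimal sketch; pickling is just one possible way to persist the state:

import pickle

# After initialize(...) the state holds the foreground and distance predictions.
with open("decoder_state.pkl", "wb") as f:
    pickle.dump(segmenter.get_state(), f)

# Later: restore the state and generate without calling initialize again.
with open("decoder_state.pkl", "rb") as f:
    state = pickle.load(f)
segmenter.set_state(state)
masks = segmenter.generate()
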
def clear_state(self):
1043    def clear_state(self):
1044        """Clear the state of the instance segmenter.
1045        """
1046        self._foreground = None
1047        self._center_distances = None
1048        self._boundary_distances = None
1049        self._is_initialized = False

Clear the state of the instance segmenter.

class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder):
1052class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder):
1053    """Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings.
1054    """
1055
1056    @torch.no_grad()
1057    def initialize(
1058        self,
1059        image: np.ndarray,
1060        image_embeddings: Optional[util.ImageEmbeddings] = None,
1061        i: Optional[int] = None,
1062        tile_shape: Optional[Tuple[int, int]] = None,
1063        halo: Optional[Tuple[int, int]] = None,
1064        verbose: bool = False,
1065        pbar_init: Optional[callable] = None,
1066        pbar_update: Optional[callable] = None,
1067    ) -> None:
1068        """Initialize image embeddings and decoder predictions for an image.
1069
1070        Args:
1071            image: The input image, volume or timeseries.
1072            image_embeddings: Optional precomputed image embeddings.
1073                See `util.precompute_image_embeddings` for details.
1074            i: Index for the image data. Required if `image` has three spatial dimensions
1075                or a time dimension and two spatial dimensions.
1076            verbose: Whether to be verbose.
1077            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
1078                Can be used together with pbar_update to handle napari progress bar in other thread.
1079                This enables using this function within a thread worker.
1080            pbar_update: Callback to update an external progress bar.
1081        """
1082        original_size = image.shape[:2]
1083        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
1084            self._predictor, image, image_embeddings, tile_shape, halo
1085        )
1086        tiling = blocking([0, 0], original_size, tile_shape)
1087
1088        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
1089        pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder")
1090
1091        foreground = np.zeros(original_size, dtype="float32")
1092        center_distances = np.zeros(original_size, dtype="float32")
1093        boundary_distances = np.zeros(original_size, dtype="float32")
1094
1095        for tile_id in range(tiling.numberOfBlocks):
1096
1097            # Get the image embeddings from the predictor for this tile.
1098            self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id)
1099            embeddings = self._predictor.features
1100            input_shape = tuple(self._predictor.input_size)
1101            original_shape = tuple(self._predictor.original_size)
1102
1103            # Predict with the UNETR decoder for this tile.
1104            output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
1105            assert output.shape[0] == 3, f"{output.shape}"
1106
1107            # Set the predictions in the output for this tile.
1108            block = tiling.getBlockWithHalo(tile_id, halo=list(halo))
1109            local_bb = tuple(
1110                slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end)
1111            )
1112            inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end))
1113
1114            foreground[inner_bb] = output[0][local_bb]
1115            center_distances[inner_bb] = output[1][local_bb]
1116            boundary_distances[inner_bb] = output[2][local_bb]
1117        pbar_close()
1118
1119        # Set the state.
1120        self._foreground = foreground
1121        self._center_distances = center_distances
1122        self._boundary_distances = boundary_distances
1123        self._is_initialized = True

Same as InstanceSegmentationWithDecoder but for tiled image embeddings.

@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = False, pbar_init: Optional[callable] = None, pbar_update: Optional[callable] = None) -> None:
1056    @torch.no_grad()
1057    def initialize(
1058        self,
1059        image: np.ndarray,
1060        image_embeddings: Optional[util.ImageEmbeddings] = None,
1061        i: Optional[int] = None,
1062        tile_shape: Optional[Tuple[int, int]] = None,
1063        halo: Optional[Tuple[int, int]] = None,
1064        verbose: bool = False,
1065        pbar_init: Optional[callable] = None,
1066        pbar_update: Optional[callable] = None,
1067    ) -> None:
1068        """Initialize image embeddings and decoder predictions for an image.
1069
1070        Args:
1071            image: The input image, volume or timeseries.
1072            image_embeddings: Optional precomputed image embeddings.
1073                See `util.precompute_image_embeddings` for details.
1074            i: Index for the image data. Required if `image` has three spatial dimensions
1075                or a time dimension and two spatial dimensions.
1076            verbose: Whether to be verbose.
1077            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
1078                Can be used together with pbar_update to handle napari progress bar in other thread.
1079                This enables using this function within a thread worker.
1080            pbar_update: Callback to update an external progress bar.
1081        """
1082        original_size = image.shape[:2]
1083        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
1084            self._predictor, image, image_embeddings, tile_shape, halo
1085        )
1086        tiling = blocking([0, 0], original_size, tile_shape)
1087
1088        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
1089        pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder")
1090
1091        foreground = np.zeros(original_size, dtype="float32")
1092        center_distances = np.zeros(original_size, dtype="float32")
1093        boundary_distances = np.zeros(original_size, dtype="float32")
1094
1095        for tile_id in range(tiling.numberOfBlocks):
1096
1097            # Get the image embeddings from the predictor for this tile.
1098            self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id)
1099            embeddings = self._predictor.features
1100            input_shape = tuple(self._predictor.input_size)
1101            original_shape = tuple(self._predictor.original_size)
1102
1103            # Predict with the UNETR decoder for this tile.
1104            output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
1105            assert output.shape[0] == 3, f"{output.shape}"
1106
1107            # Set the predictions in the output for this tile.
1108            block = tiling.getBlockWithHalo(tile_id, halo=list(halo))
1109            local_bb = tuple(
1110                slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end)
1111            )
1112            inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end))
1113
1114            foreground[inner_bb] = output[0][local_bb]
1115            center_distances[inner_bb] = output[1][local_bb]
1116            boundary_distances[inner_bb] = output[2][local_bb]
1117        pbar_close()
1118
1119        # Set the state.
1120        self._foreground = foreground
1121        self._center_distances = center_distances
1122        self._boundary_distances = boundary_distances
1123        self._is_initialized = True

Initialize image embeddings and decoder predictions for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • tile_shape: The tile shape for precomputing the image embeddings.
  • halo: The overlap of the tiles used for precomputing the image embeddings.
  • verbose: Whether to be verbose.
  • pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a thread worker.
  • pbar_update: Callback to update an external progress bar.
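
A sketch of the tiled variant for images that are too large for a single embedding computation. The tile_shape and halo values are only illustrative, and predictor / decoder are placeholders as above:

from micro_sam.instance_segmentation import TiledInstanceSegmentationWithDecoder

large_image = ...  # placeholder for a large 2d image

segmenter = TiledInstanceSegmentationWithDecoder(predictor, decoder)
# Compute the embeddings tile by tile; the halo is the overlap that avoids artifacts at tile borders.
segmenter.initialize(large_image, tile_shape=(1024, 1024), halo=(256, 256), verbose=True)
masks = segmenter.generate()
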
def get_amg( predictor: segment_anything.predictor.SamPredictor, is_tiled: bool, decoder: Optional[torch.nn.modules.module.Module] = None, **kwargs) -> Union[AMGBase, InstanceSegmentationWithDecoder]:
1126def get_amg(
1127    predictor: SamPredictor, is_tiled: bool, decoder: Optional[torch.nn.Module] = None, **kwargs,
1128) -> Union[AMGBase, InstanceSegmentationWithDecoder]:
1129    """Get the automatic mask generator class.
1130
1131    Args:
1132        predictor: The segment anything predictor.
1133        is_tiled: Whether tiled embeddings are used.
1134        decoder: Decoder to predict instance segmentation.
1135        kwargs: The keyword arguments for the amg class.
1136
1137    Returns:
1138        The automatic mask generator.
1139    """
1140    if decoder is None:
1141        segmenter_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator
1142        segmenter = segmenter_class(predictor, **kwargs)
1143    else:
1144        segmenter_class = TiledInstanceSegmentationWithDecoder if is_tiled else InstanceSegmentationWithDecoder
1145        segmenter = segmenter_class(predictor, decoder, **kwargs)
1146
1147    return segmenter

Get the automatic mask generator class.

Arguments:
  • predictor: The segment anything predictor.
  • is_tiled: Whether tiled embeddings are used.
  • decoder: Decoder to predict instance segmentation.
  • kwargs: The keyword arguments for the amg class.
Returns:

The automatic mask generator.
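
A minimal sketch of get_amg, which dispatches to the matching segmenter class depending on is_tiled and on whether a decoder is passed (predictor and decoder are placeholders as above):

from micro_sam.instance_segmentation import get_amg

# Without a decoder: AutomaticMaskGenerator (or TiledAutomaticMaskGenerator for is_tiled=True).
amg = get_amg(predictor, is_tiled=False)

# With a decoder: InstanceSegmentationWithDecoder (or the tiled variant for is_tiled=True).
ais = get_amg(predictor, is_tiled=True, decoder=decoder)
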