micro_sam.instance_segmentation

Automated instance segmentation functionality. The classes implemented here extend the automatic instance segmentation from Segment Anything: https://computational-cell-analytics.github.io/micro-sam/micro_sam.html
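
A typical workflow combines `AutomaticMaskGenerator` with `mask_data_to_segmentation` (both defined below). The following is a minimal sketch, assuming a Segment Anything predictor obtained via `micro_sam.util.get_sam_model` and a 2D image loaded as a numpy array; the file path, model type, and parameter values are placeholders.

```python
import imageio.v3 as imageio

from micro_sam.util import get_sam_model
from micro_sam.instance_segmentation import AutomaticMaskGenerator, mask_data_to_segmentation

# Load a 2D image (placeholder path) and a Segment Anything model.
image = imageio.imread("example_image.tif")
predictor = get_sam_model(model_type="vit_b")

# Initialize the mask generator. This computes the image embeddings and mask data (expensive).
amg = AutomaticMaskGenerator(predictor)
amg.initialize(image)

# Generate the masks. This step is cheap and can be re-run with different post-processing parameters.
masks = amg.generate(pred_iou_thresh=0.88, stability_score_thresh=0.95)

# Convert the mask data into a single instance segmentation (a 2D label image).
instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=25)
```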

   1"""Automated instance segmentation functionality.
   2The classes implemented here extend the automatic instance segmentation from Segment Anything:
   3https://computational-cell-analytics.github.io/micro-sam/micro_sam.html
   4"""
   5
   6import os
   7import warnings
   8from abc import ABC
   9from copy import deepcopy
  10from collections import OrderedDict
  11from typing import Any, Dict, List, Optional, Tuple, Union
  12
  13import vigra
  14import numpy as np
  15from skimage.measure import label, regionprops
  16from skimage.segmentation import relabel_sequential
  17
  18import torch
  19from torchvision.ops.boxes import batched_nms, box_area
  20
  21from torch_em.model import UNETR
  22from torch_em.util.segmentation import watershed_from_center_and_boundary_distances
  23
  24import elf.parallel as parallel
  25from elf.parallel.filters import apply_filter
  26
  27from nifty.tools import blocking
  28
  29import segment_anything.utils.amg as amg_utils
  30from segment_anything.predictor import SamPredictor
  31
  32from . import util
  33from ._vendored import batched_mask_to_box, mask_to_rle_pytorch
  34
  35#
  36# Utility Functionality
  37#
  38
  39
  40class _FakeInput:
  41    def __init__(self, shape):
  42        self.shape = shape
  43
  44    def __getitem__(self, index):
  45        block_shape = tuple(ind.stop - ind.start for ind in index)
  46        return np.zeros(block_shape, dtype="float32")
  47
  48
  49def mask_data_to_segmentation(
  50    masks: List[Dict[str, Any]],
  51    with_background: bool,
  52    min_object_size: int = 0,
  53    max_object_size: Optional[int] = None,
  54    label_masks: bool = True,
  55) -> np.ndarray:
  56    """Convert the output of the automatic mask generation to an instance segmentation.
  57
  58    Args:
  59        masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator.
  60            Only supports output_mode=binary_mask.
  61        with_background: Whether the segmentation has background. If yes, this function ensures that the largest
  62            object in the output will be mapped to zero (the background value).
  63        min_object_size: The minimal size of an object in pixels. By default, set to '0'.
  64        max_object_size: The maximal size of an object in pixels.
  65        label_masks: Whether to apply connected components to the result before removing small objects.
  66            By default, set to 'True'.
  67
  68    Returns:
  69        The instance segmentation.
  70    """
  71
  72    masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
  73    # we could also get the shape from the crop box
  74    shape = next(iter(masks))["segmentation"].shape
  75    segmentation = np.zeros(shape, dtype="uint32")
  76
  77    def require_numpy(mask):
  78        return mask.cpu().numpy() if torch.is_tensor(mask) else mask
  79
  80    seg_id = 1
  81    for mask in masks:
  82        if mask["area"] < min_object_size:
  83            continue
  84        if max_object_size is not None and mask["area"] > max_object_size:
  85            continue
  86
  87        this_seg_id = mask.get("seg_id", seg_id)
  88        segmentation[require_numpy(mask["segmentation"])] = this_seg_id
  89        seg_id = this_seg_id + 1
  90
  91    if label_masks:
  92        segmentation = label(segmentation).astype(segmentation.dtype)
  93
  94    seg_ids, sizes = np.unique(segmentation, return_counts=True)
  95
  96    # In some cases objects may be smaller than previously calculated,
  97    # since they are covered by other objects. We ensure these also get
  98    # filtered out here.
  99    filter_ids = seg_ids[sizes < min_object_size]
 100
 101    # If we run segmentation with background we also map the largest segment
 102    # (the most likely background object) to zero. This is often zero already,
 103    # but it does not hurt to reset that to zero either.
 104    if with_background:
 105        bg_id = seg_ids[np.argmax(sizes)]
 106        filter_ids = np.concatenate([filter_ids, [bg_id]])
 107
 108    segmentation[np.isin(segmentation, filter_ids)] = 0
 109    segmentation = relabel_sequential(segmentation)[0]
 110
 111    return segmentation
 112
 113
 114#
 115# Classes for automatic instance segmentation
 116#
 117
 118
 119class AMGBase(ABC):
 120    """Base class for the automatic mask generators.
 121    """
 122    def __init__(self):
 123        # the state that has to be computed by the 'initialize' method of the child classes
 124        self._is_initialized = False
 125        self._crop_list = None
 126        self._crop_boxes = None
 127        self._original_size = None
 128
 129    @property
 130    def is_initialized(self):
 131        """Whether the mask generator has already been initialized.
 132        """
 133        return self._is_initialized
 134
 135    @property
 136    def crop_list(self):
 137        """The list of mask data after initialization.
 138        """
 139        return self._crop_list
 140
 141    @property
 142    def crop_boxes(self):
 143        """The list of crop boxes.
 144        """
 145        return self._crop_boxes
 146
 147    @property
 148    def original_size(self):
 149        """The original image size.
 150        """
 151        return self._original_size
 152
 153    def _postprocess_batch(
 154        self,
 155        data,
 156        crop_box,
 157        original_size,
 158        pred_iou_thresh,
 159        stability_score_thresh,
 160        box_nms_thresh,
 161    ):
 162        orig_h, orig_w = original_size
 163
 164        # filter by predicted IoU
 165        if pred_iou_thresh > 0.0:
 166            keep_mask = data["iou_preds"] > pred_iou_thresh
 167            data.filter(keep_mask)
 168
 169        # filter by stability score
 170        if stability_score_thresh > 0.0:
 171            keep_mask = data["stability_score"] >= stability_score_thresh
 172            data.filter(keep_mask)
 173
 174        # filter boxes that touch crop boundaries
 175        keep_mask = ~amg_utils.is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
 176        if not torch.all(keep_mask):
 177            data.filter(keep_mask)
 178
 179        # remove duplicates within this crop.
 180        keep_by_nms = batched_nms(
 181            data["boxes"].float(),
 182            data["iou_preds"],
 183            torch.zeros_like(data["boxes"][:, 0]),  # categories
 184            iou_threshold=box_nms_thresh,
 185        )
 186        data.filter(keep_by_nms)
 187
 188        # return to the original image frame
 189        data["boxes"] = amg_utils.uncrop_boxes_xyxy(data["boxes"], crop_box)
 190        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
 191        # the data from embedding based segmentation doesn't have the points
 192        # so we skip if the corresponding key can't be found
 193        try:
 194            data["points"] = amg_utils.uncrop_points(data["points"], crop_box)
 195        except KeyError:
 196            pass
 197
 198        return data
 199
 200    def _postprocess_small_regions(self, mask_data, min_area, nms_thresh):
 201
 202        if len(mask_data["rles"]) == 0:
 203            return mask_data
 204
 205        # filter small disconnected regions and holes
 206        new_masks = []
 207        scores = []
 208        for rle in mask_data["rles"]:
 209            mask = amg_utils.rle_to_mask(rle)
 210
 211            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="holes")
 212            unchanged = not changed
 213            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="islands")
 214            unchanged = unchanged and not changed
 215
 216            new_masks.append(torch.as_tensor(mask, dtype=torch.int).unsqueeze(0))
 217            # give score=0 to changed masks and score=1 to unchanged masks
 218            # so NMS will prefer ones that didn't need postprocessing
 219            scores.append(float(unchanged))
 220
 221        # recalculate boxes and remove any new duplicates
 222        masks = torch.cat(new_masks, dim=0)
 223        boxes = batched_mask_to_box(masks.to(torch.bool))  # Casting this to boolean as we work with one-hot labels.
 224        keep_by_nms = batched_nms(
 225            boxes.float(),
 226            torch.as_tensor(scores, dtype=torch.float),
 227            torch.zeros_like(boxes[:, 0]),  # categories
 228            iou_threshold=nms_thresh,
 229        )
 230
 231        # only recalculate RLEs for masks that have changed
 232        for i_mask in keep_by_nms:
 233            if scores[i_mask] == 0.0:
 234                mask_torch = masks[i_mask].unsqueeze(0)
 235                # mask_data["rles"][i_mask] = amg_utils.mask_to_rle_pytorch(mask_torch)[0]
 236                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
 237                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
 238        mask_data.filter(keep_by_nms)
 239
 240        return mask_data
 241
 242    def _postprocess_masks(self, mask_data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode):
 243        # filter small disconnected regions and holes in masks
 244        if min_mask_region_area > 0:
 245            mask_data = self._postprocess_small_regions(
 246                mask_data,
 247                min_mask_region_area,
 248                max(box_nms_thresh, crop_nms_thresh),
 249            )
 250
 251        # encode masks
 252        if output_mode == "coco_rle":
 253            mask_data["segmentations"] = [amg_utils.coco_encode_rle(rle) for rle in mask_data["rles"]]
 254        elif output_mode == "binary_mask":
 255            mask_data["segmentations"] = [amg_utils.rle_to_mask(rle) for rle in mask_data["rles"]]
 256        else:
 257            mask_data["segmentations"] = mask_data["rles"]
 258
 259        # write mask records
 260        curr_anns = []
 261        for idx in range(len(mask_data["segmentations"])):
 262            ann = {
 263                "segmentation": mask_data["segmentations"][idx],
 264                "area": amg_utils.area_from_rle(mask_data["rles"][idx]),
 265                "bbox": amg_utils.box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
 266                "predicted_iou": mask_data["iou_preds"][idx].item(),
 267                "stability_score": mask_data["stability_score"][idx].item(),
 268                "crop_box": amg_utils.box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
 269            }
 270            # the data from embedding based segmentation doesn't have the points
 271            # so we skip if the corresponding key can't be found
 272            try:
 273                ann["point_coords"] = [mask_data["points"][idx].tolist()]
 274            except KeyError:
 275                pass
 276            curr_anns.append(ann)
 277
 278        return curr_anns
 279
 280    def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None):
 281        orig_h, orig_w = original_size
 282
 283        # serialize predictions and store in MaskData
 284        data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1))
 285        if points is not None:
 286            data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0), dtype=torch.float)
 287
 288        del masks
 289
 290        # calculate the stability scores
 291        data["stability_score"] = amg_utils.calculate_stability_score(
 292            data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset
 293        )
 294
 295        # threshold masks and calculate boxes
 296        data["masks"] = data["masks"] > self._predictor.model.mask_threshold
 297        data["masks"] = data["masks"].type(torch.bool)
 298        data["boxes"] = batched_mask_to_box(data["masks"])
 299
 300        # compress to RLE
 301        data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
 302        # data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"])
 303        data["rles"] = mask_to_rle_pytorch(data["masks"])
 304        del data["masks"]
 305
 306        return data
 307
 308    def get_state(self) -> Dict[str, Any]:
 309        """Get the initialized state of the mask generator.
 310
 311        Returns:
 312            State of the mask generator.
 313        """
 314        if not self.is_initialized:
 315            raise RuntimeError("The state has not been computed yet. Call initialize first.")
 316
 317        return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}
 318
 319    def set_state(self, state: Dict[str, Any]) -> None:
 320        """Set the state of the mask generator.
 321
 322        Args:
 323            state: The state of the mask generator, e.g. from serialized state.
 324        """
 325        self._crop_list = state["crop_list"]
 326        self._crop_boxes = state["crop_boxes"]
 327        self._original_size = state["original_size"]
 328        self._is_initialized = True
 329
 330    def clear_state(self):
 331        """Clear the state of the mask generator.
 332        """
 333        self._crop_list = None
 334        self._crop_boxes = None
 335        self._original_size = None
 336        self._is_initialized = False
 337
 338
 339class AutomaticMaskGenerator(AMGBase):
 340    """Generates an instance segmentation without prompts, using a point grid.
 341
 342    This class implements the same logic as
 343    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
 344    It decouples the computationally expensive steps of generating masks from the cheap post-processing operation
 345    to filter these masks to enable grid search and interactively changing the post-processing.
 346
 347    Use this class as follows:
 348    ```python
 349    amg = AutomaticMaskGenerator(predictor)
 350    amg.initialize(image)  # Initialize the masks, this takes care of all expensive computations.
 351    masks = amg.generate(pred_iou_thresh=0.8)  # Generate the masks. This is fast and enables testing parameters
 352    ```
 353
 354    Args:
 355        predictor: The segment anything predictor.
 356        points_per_side: The number of points to be sampled along one side of the image.
 357            If None, `point_grids` must provide explicit point sampling. By default, set to '32'.
 358        points_per_batch: The number of points run simultaneously by the model.
 359            Higher numbers may be faster but use more GPU memory.
 360            By default, selects '64' for all devices except 'mps' (selects '16' for performance reasons).
 361        crop_n_layers: If >0, the mask prediction will be run again on crops of the image.
 362            By default, set to '0'.
 363        crop_overlap_ratio: Sets the degree to which crops overlap. By default, set to '512 / 1500'.
 364        crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops.
 365            By default, set to '1'.
 366        point_grids: A list over explicit grids of points used for sampling masks.
 367            Normalized to [0, 1] with respect to the image coordinate system.
 368        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
 369            By default, set to '1.0'.
 370    """
 371    def __init__(
 372        self,
 373        predictor: SamPredictor,
 374        points_per_side: Optional[int] = 32,
 375        points_per_batch: Optional[int] = None,
 376        crop_n_layers: int = 0,
 377        crop_overlap_ratio: float = 512 / 1500,
 378        crop_n_points_downscale_factor: int = 1,
 379        point_grids: Optional[List[np.ndarray]] = None,
 380        stability_score_offset: float = 1.0,
 381    ):
 382        super().__init__()
 383
 384        if points_per_side is not None:
 385            self.point_grids = amg_utils.build_all_layer_point_grids(
 386                points_per_side, crop_n_layers, crop_n_points_downscale_factor,
 387            )
 388        elif point_grids is not None:
 389            self.point_grids = point_grids
 390        else:
 391            raise ValueError("One of points_per_side or point_grids must be provided.")
 392
 393        self._predictor = predictor
 394        self._points_per_side = points_per_side
 395
 396        # we set the points per batch to 16 for mps for performance reasons
 397        # and otherwise keep them at the default of 64
 398        if points_per_batch is None:
 399            points_per_batch = 16 if str(predictor.device) == "mps" else 64
 400        self._points_per_batch = points_per_batch
 401
 402        self._crop_n_layers = crop_n_layers
 403        self._crop_overlap_ratio = crop_overlap_ratio
 404        self._crop_n_points_downscale_factor = crop_n_points_downscale_factor
 405        self._stability_score_offset = stability_score_offset
 406
 407    def _process_batch(self, points, im_size, crop_box, original_size):
 408        # run model on this batch
 409        transformed_points = self._predictor.transform.apply_coords(points, im_size)
 410        in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float)
 411        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
 412        masks, iou_preds, _ = self._predictor.predict_torch(
 413            point_coords=in_points[:, None, :],
 414            point_labels=in_labels[:, None],
 415            multimask_output=True,
 416            return_logits=True,
 417        )
 418        data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points)
 419        del masks
 420        return data
 421
 422    def _process_crop(self, image, crop_box, crop_layer_idx, precomputed_embeddings, pbar_init=None, pbar_update=None):
 423        # Crop the image and calculate embeddings.
 424        x0, y0, x1, y1 = crop_box
 425        cropped_im = image[y0:y1, x0:x1, :]
 426        cropped_im_size = cropped_im.shape[:2]
 427
 428        if not precomputed_embeddings:
 429            self._predictor.set_image(cropped_im)
 430
 431        # Get the points for this crop.
 432        points_scale = np.array(cropped_im_size)[None, ::-1]
 433        points_for_image = self.point_grids[crop_layer_idx] * points_scale
 434
 435        # Generate masks for this crop in batches.
 436        data = amg_utils.MaskData()
 437        n_batches = len(points_for_image) // self._points_per_batch +\
 438            int(len(points_for_image) % self._points_per_batch != 0)
 439        if pbar_init is not None:
 440            pbar_init(n_batches, "Predict masks for point grid prompts")
 441
 442        for (points,) in amg_utils.batch_iterator(self._points_per_batch, points_for_image):
 443            batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size)
 444            data.cat(batch_data)
 445            del batch_data
 446            if pbar_update is not None:
 447                pbar_update(1)
 448
 449        if not precomputed_embeddings:
 450            self._predictor.reset_image()
 451
 452        return data
 453
 454    @torch.no_grad()
 455    def initialize(
 456        self,
 457        image: np.ndarray,
 458        image_embeddings: Optional[util.ImageEmbeddings] = None,
 459        i: Optional[int] = None,
 460        verbose: bool = False,
 461        pbar_init: Optional[callable] = None,
 462        pbar_update: Optional[callable] = None,
 463    ) -> None:
 464        """Initialize image embeddings and masks for an image.
 465
 466        Args:
 467            image: The input image, volume or timeseries.
 468            image_embeddings: Optional precomputed image embeddings.
 469                See `util.precompute_image_embeddings` for details.
 470            i: Index for the image data. Required if `image` has three spatial dimensions
 471                or a time dimension and two spatial dimensions.
 472            verbose: Whether to print computation progress. By default, set to 'False'.
 473            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 474                Can be used together with pbar_update to handle the napari progress bar in another thread.
 475                This enables using this function within a threadworker.
 476            pbar_update: Callback to update an external progress bar.
 477        """
 478        original_size = image.shape[:2]
 479        self._original_size = original_size
 480
 481        crop_boxes, layer_idxs = amg_utils.generate_crop_boxes(
 482            original_size, self._crop_n_layers, self._crop_overlap_ratio
 483        )
 484
 485        # We can set fixed image embeddings if we only have a single crop box (the default setting).
 486        # Otherwise we have to recompute the embeddings for each crop and can't precompute.
 487        if len(crop_boxes) == 1:
 488            if image_embeddings is None:
 489                image_embeddings = util.precompute_image_embeddings(self._predictor, image)
 490            util.set_precomputed(self._predictor, image_embeddings, i=i)
 491            precomputed_embeddings = True
 492        else:
 493            precomputed_embeddings = False
 494
 495        # we need to cast to the image representation that is compatible with SAM
 496        image = util._to_image(image)
 497
 498        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
 499
 500        crop_list = []
 501        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
 502            crop_data = self._process_crop(
 503                image, crop_box, layer_idx,
 504                precomputed_embeddings=precomputed_embeddings,
 505                pbar_init=pbar_init, pbar_update=pbar_update,
 506            )
 507            crop_list.append(crop_data)
 508        pbar_close()
 509
 510        self._is_initialized = True
 511        self._crop_list = crop_list
 512        self._crop_boxes = crop_boxes
 513
 514    @torch.no_grad()
 515    def generate(
 516        self,
 517        pred_iou_thresh: float = 0.88,
 518        stability_score_thresh: float = 0.95,
 519        box_nms_thresh: float = 0.7,
 520        crop_nms_thresh: float = 0.7,
 521        min_mask_region_area: int = 0,
 522        output_mode: str = "binary_mask",
 523    ) -> List[Dict[str, Any]]:
 524        """Generate instance segmentation for the currently initialized image.
 525
 526        Args:
 527            pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
 528                By default, set to '0.88'.
 529            stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
 530                under changes to the cutoff used to binarize the model prediction. By default, set to '0.95'.
 531            box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
 532                By default, set to '0.7'.
 533            crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
 534                By default, set to '0.7'.
 535            min_mask_region_area: Minimal size for the predicted masks. By default, set to '0'.
 536            output_mode: The form masks are returned in. By default, set to 'binary_mask'.
 537
 538        Returns:
 539            The instance segmentation masks.
 540        """
 541        if not self.is_initialized:
 542            raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.")
 543
 544        data = amg_utils.MaskData()
 545        for data_, crop_box in zip(self.crop_list, self.crop_boxes):
 546            crop_data = self._postprocess_batch(
 547                data=deepcopy(data_),
 548                crop_box=crop_box, original_size=self.original_size,
 549                pred_iou_thresh=pred_iou_thresh,
 550                stability_score_thresh=stability_score_thresh,
 551                box_nms_thresh=box_nms_thresh
 552            )
 553            data.cat(crop_data)
 554
 555        if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0:
 556            # Prefer masks from smaller crops
 557            scores = 1 / box_area(data["crop_boxes"])
 558            scores = scores.to(data["boxes"].device)
 559            keep_by_nms = batched_nms(
 560                data["boxes"].float(),
 561                scores,
 562                torch.zeros_like(data["boxes"][:, 0]),  # categories
 563                iou_threshold=crop_nms_thresh,
 564            )
 565            data.filter(keep_by_nms)
 566
 567        data.to_numpy()
 568        masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode)
 569        return masks
 570
 571
 572# Helper function for tiled embedding computation and checking consistent state.
 573def _process_tiled_embeddings(predictor, image, image_embeddings, tile_shape, halo, verbose, batch_size):
 574    if image_embeddings is None:
 575        if tile_shape is None or halo is None:
 576            raise ValueError("To compute tiled embeddings the parameters tile_shape and halo have to be passed.")
 577        image_embeddings = util.precompute_image_embeddings(
 578            predictor, image, tile_shape=tile_shape, halo=halo, verbose=verbose, batch_size=batch_size
 579        )
 580
 581    # Use tile shape and halo from the precomputed embeddings if not given.
 582    # Otherwise check that they are consistent.
 583    feats = image_embeddings["features"]
 584    tile_shape_, halo_ = tuple(feats.attrs["tile_shape"]), tuple(feats.attrs["halo"])
 585    if tile_shape is None:
 586        tile_shape = tile_shape_
 587    elif tile_shape != tile_shape_:
 588        raise ValueError(
 589            f"Inconsistent tile_shape parameter {tile_shape} with precomputed embeddings: {tile_shape_}."
 590        )
 591    if halo is None:
 592        halo = halo_
 593    elif halo != halo_:
 594        raise ValueError(f"Inconsistent halo parameter {halo} with precomputed embeddings: {halo_}.")
 595
 596    return image_embeddings, tile_shape, halo
 597
 598
 599class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
 600    """Generates an instance segmentation without prompts, using a point grid.
 601
 602    Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings.
 603
 604    Args:
 605        predictor: The Segment Anything predictor.
 606        points_per_side: The number of points to be sampled along one side of the image.
 607            If None, `point_grids` must provide explicit point sampling. By default, set to '32'.
 608        points_per_batch: The number of points run simultaneously by the model.
 609            Higher numbers may be faster but use more GPU memory. By default, set to '64'.
 610        point_grids: A list over explicit grids of points used for sampling masks.
 611            Normalized to [0, 1] with respect to the image coordinate system.
 612        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
 613            By default, set to '1.0'.
 614    """
 615
 616    # We only expose the arguments that make sense for the tiled mask generator.
 617    # Anything related to crops doesn't make sense, because we re-use that functionality
 618    # for tiling, so these parameters wouldn't have any effect.
 619    def __init__(
 620        self,
 621        predictor: SamPredictor,
 622        points_per_side: Optional[int] = 32,
 623        points_per_batch: int = 64,
 624        point_grids: Optional[List[np.ndarray]] = None,
 625        stability_score_offset: float = 1.0,
 626    ) -> None:
 627        super().__init__(
 628            predictor=predictor,
 629            points_per_side=points_per_side,
 630            points_per_batch=points_per_batch,
 631            point_grids=point_grids,
 632            stability_score_offset=stability_score_offset,
 633        )
 634
 635    @torch.no_grad()
 636    def initialize(
 637        self,
 638        image: np.ndarray,
 639        image_embeddings: Optional[util.ImageEmbeddings] = None,
 640        i: Optional[int] = None,
 641        tile_shape: Optional[Tuple[int, int]] = None,
 642        halo: Optional[Tuple[int, int]] = None,
 643        verbose: bool = False,
 644        pbar_init: Optional[callable] = None,
 645        pbar_update: Optional[callable] = None,
 646        batch_size: int = 1,
 647    ) -> None:
 648        """Initialize image embeddings and masks for an image.
 649
 650        Args:
 651            image: The input image, volume or timeseries.
 652            image_embeddings: Optional precomputed image embeddings.
 653                See `util.precompute_image_embeddings` for details.
 654            i: Index for the image data. Required if `image` has three spatial dimensions
 655                or a time dimension and two spatial dimensions.
 656            tile_shape: The tile shape for embedding prediction.
 657            halo: The overlap between tiles.
 658            verbose: Whether to print computation progress. By default, set to 'False'.
 659            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 660                Can be used together with pbar_update to handle the napari progress bar in another thread.
 661                This enables using this function within a threadworker.
 662            pbar_update: Callback to update an external progress bar.
 663            batch_size: The batch size for image embedding prediction. By default, set to '1'.
 664        """
 665        original_size = image.shape[:2]
 666        self._original_size = original_size
 667
 668        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
 669            self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, batch_size=batch_size
 670        )
 671
 672        tiling = blocking([0, 0], original_size, tile_shape)
 673        n_tiles = tiling.numberOfBlocks
 674
 675        # The crop box is always the full local tile.
 676        tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)]
 677        crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles]
 678
 679        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
 680        pbar_init(n_tiles, "Compute masks for tile")
 681
 682        # We need to cast to the image representation that is compatible with SAM.
 683        image = util._to_image(image)
 684
 685        mask_data = []
 686        for tile_id in range(n_tiles):
 687            # set the pre-computed embeddings for this tile
 688            features = image_embeddings["features"][str(tile_id)]
 689            tile_embeddings = {
 690                "features": features,
 691                "input_size": features.attrs["input_size"],
 692                "original_size": features.attrs["original_size"],
 693            }
 694            util.set_precomputed(self._predictor, tile_embeddings, i)
 695
 696            # compute the mask data for this tile and append it
 697            this_mask_data = self._process_crop(
 698                image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True
 699            )
 700            mask_data.append(this_mask_data)
 701            pbar_update(1)
 702        pbar_close()
 703
 704        # set the initialized data
 705        self._is_initialized = True
 706        self._crop_list = mask_data
 707        self._crop_boxes = crop_boxes
 708
 709
 710#
 711# Instance segmentation functionality based on fine-tuned decoder
 712#
 713
 714
 715class DecoderAdapter(torch.nn.Module):
 716    """Adapter to contain the UNETR decoder in a single module.
 717
 718    To apply the decoder on top of pre-computed embeddings for the segmentation functionality.
 719    See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py
 720    """
 721    def __init__(self, unetr: torch.nn.Module):
 722        super().__init__()
 723
 724        self.base = unetr.base
 725        self.out_conv = unetr.out_conv
 726        self.deconv_out = unetr.deconv_out
 727        self.decoder_head = unetr.decoder_head
 728        self.final_activation = unetr.final_activation
 729        self.postprocess_masks = unetr.postprocess_masks
 730
 731        self.decoder = unetr.decoder
 732        self.deconv1 = unetr.deconv1
 733        self.deconv2 = unetr.deconv2
 734        self.deconv3 = unetr.deconv3
 735        self.deconv4 = unetr.deconv4
 736
 737    def _forward_impl(self, input_):
 738        z12 = input_
 739
 740        z9 = self.deconv1(z12)
 741        z6 = self.deconv2(z9)
 742        z3 = self.deconv3(z6)
 743        z0 = self.deconv4(z3)
 744
 745        updated_from_encoder = [z9, z6, z3]
 746
 747        x = self.base(z12)
 748        x = self.decoder(x, encoder_inputs=updated_from_encoder)
 749        x = self.deconv_out(x)
 750
 751        x = torch.cat([x, z0], dim=1)
 752        x = self.decoder_head(x)
 753
 754        x = self.out_conv(x)
 755        if self.final_activation is not None:
 756            x = self.final_activation(x)
 757        return x
 758
 759    def forward(self, input_, input_shape, original_shape):
 760        x = self._forward_impl(input_)
 761        x = self.postprocess_masks(x, input_shape, original_shape)
 762        return x
 763
 764
 765def get_unetr(
 766    image_encoder: torch.nn.Module,
 767    decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None,
 768    device: Optional[Union[str, torch.device]] = None,
 769    out_channels: int = 3,
 770    flexible_load_checkpoint: bool = False,
 771) -> torch.nn.Module:
 772    """Get UNETR model for automatic instance segmentation.
 773
 774    Args:
 775        image_encoder: The image encoder of the SAM model.
 776            This is used as encoder by the UNETR too.
 777        decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
 778        device: The device. By default, automatically chooses the best available device.
 779        out_channels: The number of output channels. By default, set to '3'.
 780        flexible_load_checkpoint: Whether to allow reinitialization of parameters
 781            which could not be found in the provided decoder state. By default, set to 'False'.
 782
 783    Returns:
 784        The UNETR model.
 785    """
 786    device = util.get_device(device)
 787
 788    if decoder_state is None:
 789        use_conv_transpose = False  # By default, we use interpolation for upsampling.
 790    else:
 791        # From the provided pretrained 'decoder_state', we check whether it uses transposed convolutions.
 792        # NOTE: Explanation to the logic below -
 793        # - We do this by looking for parameter names that contain '.block.' within the "decoder.samplers"
 794        #   submodules. This naming convention indicates that transposed convolutions are used,
 795        #   wrapped inside a custom block module.
 796        # - Otherwise, '.conv.' appears, which indicates a standard `Conv2d` applied after interpolation for upsampling.
 797        use_conv_transpose = any(".block." in k for k in decoder_state.keys() if k.startswith("decoder.samplers"))
 798
 799    unetr = UNETR(
 800        backbone="sam",
 801        encoder=image_encoder,
 802        out_channels=out_channels,
 803        use_sam_stats=True,
 804        final_activation="Sigmoid",
 805        use_skip_connection=False,
 806        resize_input=True,
 807        use_conv_transpose=use_conv_transpose,
 808
 809    )
 810
 811    if decoder_state is not None:
 812        unetr_state_dict = unetr.state_dict()
 813        for k, v in unetr_state_dict.items():
 814            if not k.startswith("encoder"):
 815                if flexible_load_checkpoint:  # Allow reinitialization of parameters that are not found in the decoder state.
 816                    if k in decoder_state:  # First check whether the key is available in the provided decoder state.
 817                        unetr_state_dict[k] = decoder_state[k]
 818                    else:  # Otherwise, fall back to the freshly initialized parameter.
 819                        warnings.warn(f"Could not find '{k}' in the pretrained state dict. Hence, we reinitialize it.")
 820                        unetr_state_dict[k] = v
 821
 822                else:  # Be strict: the parameter must be present in the decoder state.
 823                    if k not in decoder_state:
 824                        raise RuntimeError(f"The parameters for '{k}' could not be found.")
 825                    unetr_state_dict[k] = decoder_state[k]
 826
 827        unetr.load_state_dict(unetr_state_dict)
 828
 829    unetr.to(device)
 830    return unetr
 831
 832
 833def get_decoder(
 834    image_encoder: torch.nn.Module,
 835    decoder_state: OrderedDict[str, torch.Tensor],
 836    device: Optional[Union[str, torch.device]] = None,
 837) -> DecoderAdapter:
 838    """Get the decoder to predict outputs for automatic instance segmentation.
 839
 840    Args:
 841        image_encoder: The image encoder of the SAM model.
 842        decoder_state: State to initialize the weights of the UNETR decoder.
 843        device: The device. By default, automatically chooses the best available device.
 844
 845    Returns:
 846        The decoder for instance segmentation.
 847    """
 848    unetr = get_unetr(image_encoder, decoder_state, device)
 849    return DecoderAdapter(unetr)
 850
 851
 852def get_predictor_and_decoder(
 853    model_type: str,
 854    checkpoint_path: Union[str, os.PathLike],
 855    device: Optional[Union[str, torch.device]] = None,
 856    peft_kwargs: Optional[Dict] = None,
 857) -> Tuple[SamPredictor, DecoderAdapter]:
 858    """Load the SAM model (predictor) and instance segmentation decoder.
 859
 860    This requires a checkpoint that contains the state for both predictor
 861    and decoder.
 862
 863    Args:
 864        model_type: The type of the image encoder used in the SAM model.
 865        checkpoint_path: Path to the checkpoint from which to load the data.
 866        device: The device. By default, automatically chooses the best available device.
 867        peft_kwargs: Keyword arguments for the PEFT wrapper class.
 868
 869    Returns:
 870        The SAM predictor.
 871        The decoder for instance segmentation.
 872    """
 873    device = util.get_device(device)
 874    predictor, state = util.get_sam_model(
 875        model_type=model_type,
 876        checkpoint_path=checkpoint_path,
 877        device=device,
 878        return_state=True,
 879        peft_kwargs=peft_kwargs,
 880    )
 881
 882    if "decoder_state" not in state:
 883        raise ValueError(
 884            f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state"
 885        )
 886
 887    decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device)
 888    return predictor, decoder
 889
 890
 891def _watershed_from_center_and_boundary_distances_parallel(
 892    center_distances,
 893    boundary_distances,
 894    foreground_map,
 895    center_distance_threshold,
 896    boundary_distance_threshold,
 897    foreground_threshold,
 898    distance_smoothing,
 899    min_size,
 900    tile_shape,
 901    halo,
 902    n_threads,
 903    verbose=False,
 904):
 905    center_distances = apply_filter(
 906        center_distances, "gaussianSmoothing", sigma=distance_smoothing,
 907        block_shape=tile_shape, n_threads=n_threads
 908    )
 909    boundary_distances = apply_filter(
 910        boundary_distances, "gaussianSmoothing", sigma=distance_smoothing,
 911        block_shape=tile_shape, n_threads=n_threads
 912    )
 913
 914    fg_mask = foreground_map > foreground_threshold
 915
 916    marker_map = np.logical_and(
 917        center_distances < center_distance_threshold, boundary_distances < boundary_distance_threshold
 918    )
 919    marker_map[~fg_mask] = 0
 920
 921    markers = np.zeros(marker_map.shape, dtype="uint64")
 922    markers = parallel.label(
 923        marker_map, out=markers, block_shape=tile_shape, n_threads=n_threads, verbose=verbose,
 924    )
 925
 926    seg = np.zeros_like(markers, dtype="uint64")
 927    seg = parallel.seeded_watershed(
 928        boundary_distances, seeds=markers, out=seg, block_shape=tile_shape,
 929        halo=halo, n_threads=n_threads, verbose=verbose, mask=fg_mask,
 930    )
 931
 932    out = np.zeros_like(seg, dtype="uint64")
 933    out = parallel.size_filter(
 934        seg, out=out, min_size=min_size, block_shape=tile_shape, n_threads=n_threads, verbose=verbose
 935    )
 936
 937    return out
 938
 939
 940class InstanceSegmentationWithDecoder:
 941    """Generates an instance segmentation without prompts, using a decoder.
 942
 943    Implements the same interface as `AutomaticMaskGenerator`.
 944
 945    Use this class as follows:
 946    ```python
 947    segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
 948    segmenter.initialize(image)   # Predict the image embeddings and decoder outputs.
 949    masks = segmenter.generate(center_distance_threshold=0.75)  # Generate the instance segmentation.
 950    ```
 951
 952    Args:
 953        predictor: The segment anything predictor.
 954        decoder: The decoder to predict intermediate representations
 955            for instance segmentation.
 956    """
 957    def __init__(self, predictor: SamPredictor, decoder: torch.nn.Module) -> None:
 958        self._predictor = predictor
 959        self._decoder = decoder
 960
 961        # The decoder outputs.
 962        self._foreground = None
 963        self._center_distances = None
 964        self._boundary_distances = None
 965
 966        self._is_initialized = False
 967
 968    @property
 969    def is_initialized(self):
 970        """Whether the mask generator has already been initialized.
 971        """
 972        return self._is_initialized
 973
 974    @torch.no_grad()
 975    def initialize(
 976        self,
 977        image: np.ndarray,
 978        image_embeddings: Optional[util.ImageEmbeddings] = None,
 979        i: Optional[int] = None,
 980        verbose: bool = False,
 981        pbar_init: Optional[callable] = None,
 982        pbar_update: Optional[callable] = None,
 983    ) -> None:
 984        """Initialize image embeddings and decoder predictions for an image.
 985
 986        Args:
 987            image: The input image, volume or timeseries.
 988            image_embeddings: Optional precomputed image embeddings.
 989                See `util.precompute_image_embeddings` for details.
 990            i: Index for the image data. Required if `image` has three spatial dimensions
 991                or a time dimension and two spatial dimensions.
 992            verbose: Whether to be verbose. By default, set to 'False'.
 993            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 994                Can be used together with pbar_update to handle the napari progress bar in another thread.
 995                This enables using this function within a threadworker.
 996            pbar_update: Callback to update an external progress bar.
 997        """
 998        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
 999        pbar_init(1, "Initialize instance segmentation with decoder")
1000
1001        if image_embeddings is None:
1002            image_embeddings = util.precompute_image_embeddings(self._predictor, image)
1003
1004        # Get the image embeddings from the predictor.
1005        self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i)
1006        embeddings = self._predictor.features
1007        input_shape = tuple(self._predictor.input_size)
1008        original_shape = tuple(self._predictor.original_size)
1009
1010        # Run prediction with the UNETR decoder.
1011        output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
1012        assert output.shape[0] == 3, f"{output.shape}"
1013        pbar_update(1)
1014        pbar_close()
1015
1016        # Set the state.
1017        self._foreground = output[0]
1018        self._center_distances = output[1]
1019        self._boundary_distances = output[2]
1020        self._is_initialized = True
1021
1022    def _to_masks(self, segmentation, output_mode):
1023        if output_mode != "binary_mask":
1024            raise NotImplementedError
1025
1026        props = regionprops(segmentation)
1027        ndim = segmentation.ndim
1028        assert ndim in (2, 3)
1029
1030        shape = segmentation.shape
1031        if ndim == 2:
1032            crop_box = [0, shape[1], 0, shape[0]]
1033        else:
1034            crop_box = [0, shape[2], 0, shape[1], 0, shape[0]]
1035
1036        # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h]
1037        def to_bbox_2d(bbox):
1038            y0, x0 = bbox[0], bbox[1]
1039            w = bbox[3] - x0
1040            h = bbox[2] - y0
1041            return [x0, w, y0, h]
1042
1043        def to_bbox_3d(bbox):
1044            z0, y0, x0 = bbox[0], bbox[1], bbox[2]
1045            w = bbox[5] - x0
1046            h = bbox[4] - y0
1047            d = bbox[3] - z0  # depth from the z-extent of the bounding box
1048            return [x0, w, y0, h, z0, d]
1049
1050        to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d
1051        masks = [
1052            {
1053                "segmentation": segmentation == prop.label,
1054                "area": prop.area,
1055                "bbox": to_bbox(prop.bbox),
1056                "crop_box": crop_box,
1057                "seg_id": prop.label,
1058            } for prop in props
1059        ]
1060        return masks
1061
1062    def generate(
1063        self,
1064        center_distance_threshold: float = 0.5,
1065        boundary_distance_threshold: float = 0.5,
1066        foreground_threshold: float = 0.5,
1067        foreground_smoothing: float = 1.0,
1068        distance_smoothing: float = 1.6,
1069        min_size: int = 0,
1070        output_mode: Optional[str] = "binary_mask",
1071        tile_shape: Optional[Tuple[int, int]] = None,
1072        halo: Optional[Tuple[int, int]] = None,
1073        n_threads: Optional[int] = None,
1074    ) -> List[Dict[str, Any]]:
1075        """Generate instance segmentation for the currently initialized image.
1076
1077        Args:
1078            center_distance_threshold: Center distance predictions below this value will be
1079                used to find seeds (intersected with thresholded boundary distance predictions).
1080                By default, set to '0.5'.
1081            boundary_distance_threshold: Boundary distance predictions below this value will be
1082                used to find seeds (intersected with thresholded center distance predictions).
1083                By default, set to '0.5'.
1084            foreground_threshold: Foreground predictions above this value will be used as foreground mask.
1085                By default, set to '0.5'.
1086            foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid
1087                checkerboard artifacts in the prediction. By default, set to '1.0'.
1088            distance_smoothing: Sigma value for smoothing the distance predictions.
1089            min_size: Minimal object size in the segmentation result. By default, set to '0'.
1090            output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
1091                By default, set to 'binary_mask'.
1092            tile_shape: Tile shape for parallelizing the instance segmentation post-processing.
1093                This parameter is independent from the tile shape for computing the embeddings.
1094                If not given, the post-processing will not be parallelized.
1095            halo: Halo for parallel post-processing. See also `tile_shape`.
1096            n_threads: Number of threads for parallel post-processing. See also `tile_shape`.
1097
1098        Returns:
1099            The instance segmentation masks.
1100        """
1101        if not self.is_initialized:
1102            raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.")
1103
1104        if foreground_smoothing > 0:
1105            foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing)
1106        else:
1107            foreground = self._foreground
1108
1109        if tile_shape is None:
1110            segmentation = watershed_from_center_and_boundary_distances(
1111                center_distances=self._center_distances,
1112                boundary_distances=self._boundary_distances,
1113                foreground_map=foreground,
1114                center_distance_threshold=center_distance_threshold,
1115                boundary_distance_threshold=boundary_distance_threshold,
1116                foreground_threshold=foreground_threshold,
1117                distance_smoothing=distance_smoothing,
1118                min_size=min_size,
1119            )
1120        else:
1121            if halo is None:
1122                raise ValueError("You must pass a value for halo if tile_shape is given.")
1123            segmentation = _watershed_from_center_and_boundary_distances_parallel(
1124                center_distances=self._center_distances,
1125                boundary_distances=self._boundary_distances,
1126                foreground_map=foreground,
1127                center_distance_threshold=center_distance_threshold,
1128                boundary_distance_threshold=boundary_distance_threshold,
1129                foreground_threshold=foreground_threshold,
1130                distance_smoothing=distance_smoothing,
1131                min_size=min_size,
1132                tile_shape=tile_shape,
1133                halo=halo,
1134                n_threads=n_threads,
1135                verbose=False,
1136            )
1137
1138        if output_mode is not None:
1139            segmentation = self._to_masks(segmentation, output_mode)
1140        return segmentation
1141
1142    def get_state(self) -> Dict[str, Any]:
1143        """Get the initialized state of the instance segmenter.
1144
1145        Returns:
1146            Instance segmentation state.
1147        """
1148        if not self.is_initialized:
1149            raise RuntimeError("The state has not been computed yet. Call initialize first.")
1150
1151        return {
1152            "foreground": self._foreground,
1153            "center_distances": self._center_distances,
1154            "boundary_distances": self._boundary_distances,
1155        }
1156
1157    def set_state(self, state: Dict[str, Any]) -> None:
1158        """Set the state of the instance segmenter.
1159
1160        Args:
1161            state: The instance segmentation state
1162        """
1163        self._foreground = state["foreground"]
1164        self._center_distances = state["center_distances"]
1165        self._boundary_distances = state["boundary_distances"]
1166        self._is_initialized = True
1167
1168    def clear_state(self):
1169        """Clear the state of the instance segmenter.
1170        """
1171        self._foreground = None
1172        self._center_distances = None
1173        self._boundary_distances = None
1174        self._is_initialized = False
1175
1176
1177class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder):
1178    """Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings.
1179    """
1180
1181    # Apply the decoder in a batched fashion, and then perform the resizing independently per output.
1182    # This is necessary, because the individual tiles may have different tile shapes due to border tiles.
1183    def _predict_decoder(self, batched_embeddings, input_shapes, original_shapes):
1184        batched_embeddings = torch.cat(batched_embeddings)
1185        output = self._decoder._forward_impl(batched_embeddings)
1186
1187        batched_output = []
1188        for x, input_shape, original_shape in zip(output, input_shapes, original_shapes):
1189            x = self._decoder.postprocess_masks(x.unsqueeze(0), input_shape, original_shape).squeeze(0)
1190            batched_output.append(x.cpu().numpy())
1191        return batched_output
1192
1193    @torch.no_grad()
1194    def initialize(
1195        self,
1196        image: np.ndarray,
1197        image_embeddings: Optional[util.ImageEmbeddings] = None,
1198        i: Optional[int] = None,
1199        tile_shape: Optional[Tuple[int, int]] = None,
1200        halo: Optional[Tuple[int, int]] = None,
1201        verbose: bool = False,
1202        pbar_init: Optional[callable] = None,
1203        pbar_update: Optional[callable] = None,
1204        batch_size: int = 1,
1205    ) -> None:
1206        """Initialize image embeddings and decoder predictions for an image.
1207
1208        Args:
1209            image: The input image, volume or timeseries.
1210            image_embeddings: Optional precomputed image embeddings.
1211                See `util.precompute_image_embeddings` for details.
1212            i: Index for the image data. Required if `image` has three spatial dimensions
1213                or a time dimension and two spatial dimensions.
1214            tile_shape: Shape of the tiles for precomputing image embeddings.
1215            halo: Overlap of the tiles for tiled precomputation of image embeddings.
1216            verbose: Dummy input to be compatible with other function signatures. By default, set to 'False'.
1217            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
1218                Can be used together with pbar_update to handle the napari progress bar in another thread.
1219                This enables using this function within a threadworker.
1220            pbar_update: Callback to update an external progress bar.
1221            batch_size: The batch size for image embedding computation and segmentation decoder prediction.
1222        """
1223        original_size = image.shape[:2]
1224        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
1225            self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, batch_size=batch_size
1226        )
1227        tiling = blocking([0, 0], original_size, tile_shape)
1228
1229        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
1230        pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder")
1231
1232        foreground = np.zeros(original_size, dtype="float32")
1233        center_distances = np.zeros(original_size, dtype="float32")
1234        boundary_distances = np.zeros(original_size, dtype="float32")
1235
1236        n_tiles = tiling.numberOfBlocks
1237        n_batches = int(np.ceil(n_tiles / batch_size))
1238
1239        for batch_id in range(n_batches):
1240            tile_start = batch_id * batch_size
1241            tile_stop = min(tile_start + batch_size, n_tiles)
1242
1243            batched_embeddings, input_shapes, original_shapes = [], [], []
1244            for tile_id in range(tile_start, tile_stop):
1245                # Get the image embeddings from the predictor for this tile.
1246                self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id)
1247
1248                batched_embeddings.append(self._predictor.features)
1249                input_shapes.append(tuple(self._predictor.input_size))
1250                original_shapes.append(tuple(self._predictor.original_size))
1251
1252            batched_output = self._predict_decoder(batched_embeddings, input_shapes, original_shapes)
1253
1254            for output_id, tile_id in enumerate(range(tile_start, tile_stop)):
1255                output = batched_output[output_id]
1256                assert output.shape[0] == 3
1257
1258                # Set the predictions in the output for this tile.
1259                block = tiling.getBlockWithHalo(tile_id, halo=list(halo))
1260                local_bb = tuple(
1261                    slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end)
1262                )
1263                inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end))
1264
1265                foreground[inner_bb] = output[0][local_bb]
1266                center_distances[inner_bb] = output[1][local_bb]
1267                boundary_distances[inner_bb] = output[2][local_bb]
1268                pbar_update(1)
1269
1270        pbar_close()
1271
1272        # Set the state.
1273        self._foreground = foreground
1274        self._center_distances = center_distances
1275        self._boundary_distances = boundary_distances
1276        self._is_initialized = True
1277
1278
1279def get_amg(
1280    predictor: SamPredictor, is_tiled: bool, decoder: Optional[torch.nn.Module] = None, **kwargs,
1281) -> Union[AMGBase, InstanceSegmentationWithDecoder]:
1282    """Get the automatic mask generator class.
1283
1284    Args:
1285        predictor: The segment anything predictor.
1286        is_tiled: Whether tiled embeddings are used.
1287        decoder: The decoder to predict the instance segmentation.
1288        kwargs: The keyword arguments for the amg class.
1289
1290    Returns:
1291        The automatic mask generator.
1292    """
1293    if decoder is None:
1294        segmenter_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator
1295        segmenter = segmenter_class(predictor, **kwargs)
1296    else:
1297        segmenter_class = TiledInstanceSegmentationWithDecoder if is_tiled else InstanceSegmentationWithDecoder
1298        segmenter = segmenter_class(predictor, decoder, **kwargs)
1299
1300    return segmenter
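
The `get_amg` function above is the convenience entry point for picking the matching generator class. Below is a minimal sketch of its use; it assumes `util.get_sam_model` (with a `model_type` argument) for loading the predictor, which is not shown in this module.

```python
from micro_sam import util
from micro_sam.instance_segmentation import get_amg

# Assumption: util.get_sam_model loads a SamPredictor for the given model type.
predictor = util.get_sam_model(model_type="vit_b")

# Without a decoder this returns an AutomaticMaskGenerator
# (or a TiledAutomaticMaskGenerator if is_tiled=True).
amg = get_amg(predictor, is_tiled=False)

# With a decoder it returns an InstanceSegmentationWithDecoder (or its tiled
# counterpart); the decoder would come e.g. from get_decoder below.
# amg = get_amg(predictor, is_tiled=False, decoder=decoder)
```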
def mask_data_to_segmentation( masks: List[Dict[str, Any]], with_background: bool, min_object_size: int = 0, max_object_size: Optional[int] = None, label_masks: bool = True) -> numpy.ndarray:
 50def mask_data_to_segmentation(
 51    masks: List[Dict[str, Any]],
 52    with_background: bool,
 53    min_object_size: int = 0,
 54    max_object_size: Optional[int] = None,
 55    label_masks: bool = True,
 56) -> np.ndarray:
 57    """Convert the output of the automatic mask generation to an instance segmentation.
 58
 59    Args:
 60        masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator.
 61            Only supports output_mode=binary_mask.
 62        with_background: Whether the segmentation has background. If yes this function assures that the largest
 63            object in the output will be mapped to zero (the background value).
 64        min_object_size: The minimal size of an object in pixels. By default, set to '0'.
 65        max_object_size: The maximal size of an object in pixels.
 66        label_masks: Whether to apply connected components to the result before removing small objects.
 67            By default, set to 'True'.
 68
 69    Returns:
 70        The instance segmentation.
 71    """
 72
 73    masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
 74    # we could also get the shape from the crop box
 75    shape = next(iter(masks))["segmentation"].shape
 76    segmentation = np.zeros(shape, dtype="uint32")
 77
 78    def require_numpy(mask):
 79        return mask.cpu().numpy() if torch.is_tensor(mask) else mask
 80
 81    seg_id = 1
 82    for mask in masks:
 83        if mask["area"] < min_object_size:
 84            continue
 85        if max_object_size is not None and mask["area"] > max_object_size:
 86            continue
 87
 88        this_seg_id = mask.get("seg_id", seg_id)
 89        segmentation[require_numpy(mask["segmentation"])] = this_seg_id
 90        seg_id = this_seg_id + 1
 91
 92    if label_masks:
 93        segmentation = label(segmentation).astype(segmentation.dtype)
 94
 95    seg_ids, sizes = np.unique(segmentation, return_counts=True)
 96
 97    # In some cases objects may be smaller than previously calculated,
 98    # since they are covered by other objects. We ensure these also get
 99    # filtered out here.
100    filter_ids = seg_ids[sizes < min_object_size]
101
102    # If we run segmentation with background we also map the largest segment
103    # (the most likely background object) to zero. This is often zero already,
104    # but it does not hurt to reset that to zero either.
105    if with_background:
106        bg_id = seg_ids[np.argmax(sizes)]
107        filter_ids = np.concatenate([filter_ids, [bg_id]])
108
109    segmentation[np.isin(segmentation, filter_ids)] = 0
110    segmentation = relabel_sequential(segmentation)[0]
111
112    return segmentation

Convert the output of the automatic mask generation to an instance segmentation.

Arguments:
  • masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator. Only supports output_mode=binary_mask.
  • with_background: Whether the segmentation has background. If yes this function assures that the largest object in the output will be mapped to zero (the background value).
  • min_object_size: The minimal size of an object in pixels. By default, set to '0'.
  • max_object_size: The maximal size of an object in pixels.
  • label_masks: Whether to apply connected components to the result before removing small objects. By default, set to 'True'.
Returns:

The instance segmentation.
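
A minimal sketch of converting the generator output into a label image. The random input image is only a placeholder and `util.get_sam_model` is an assumption about how the predictor is loaded.

```python
import numpy as np
from micro_sam import util
from micro_sam.instance_segmentation import AutomaticMaskGenerator, mask_data_to_segmentation

image = np.random.randint(0, 255, size=(512, 512), dtype="uint8")  # placeholder image

predictor = util.get_sam_model(model_type="vit_b")  # assumption about micro_sam.util
amg = AutomaticMaskGenerator(predictor)
amg.initialize(image)
masks = amg.generate(output_mode="binary_mask")  # only binary_mask is supported here

# Convert the list of mask dictionaries into a single instance segmentation,
# mapping the largest object (most likely the background) to zero.
segmentation = mask_data_to_segmentation(masks, with_background=True, min_object_size=25)
```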

class AMGBase(abc.ABC):
120class AMGBase(ABC):
121    """Base class for the automatic mask generators.
122    """
123    def __init__(self):
124        # the state that has to be computed by the 'initialize' method of the child classes
125        self._is_initialized = False
126        self._crop_list = None
127        self._crop_boxes = None
128        self._original_size = None
129
130    @property
131    def is_initialized(self):
132        """Whether the mask generator has already been initialized.
133        """
134        return self._is_initialized
135
136    @property
137    def crop_list(self):
138        """The list of mask data after initialization.
139        """
140        return self._crop_list
141
142    @property
143    def crop_boxes(self):
144        """The list of crop boxes.
145        """
146        return self._crop_boxes
147
148    @property
149    def original_size(self):
150        """The original image size.
151        """
152        return self._original_size
153
154    def _postprocess_batch(
155        self,
156        data,
157        crop_box,
158        original_size,
159        pred_iou_thresh,
160        stability_score_thresh,
161        box_nms_thresh,
162    ):
163        orig_h, orig_w = original_size
164
165        # filter by predicted IoU
166        if pred_iou_thresh > 0.0:
167            keep_mask = data["iou_preds"] > pred_iou_thresh
168            data.filter(keep_mask)
169
170        # filter by stability score
171        if stability_score_thresh > 0.0:
172            keep_mask = data["stability_score"] >= stability_score_thresh
173            data.filter(keep_mask)
174
175        # filter boxes that touch crop boundaries
176        keep_mask = ~amg_utils.is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
177        if not torch.all(keep_mask):
178            data.filter(keep_mask)
179
180        # remove duplicates within this crop.
181        keep_by_nms = batched_nms(
182            data["boxes"].float(),
183            data["iou_preds"],
184            torch.zeros_like(data["boxes"][:, 0]),  # categories
185            iou_threshold=box_nms_thresh,
186        )
187        data.filter(keep_by_nms)
188
189        # return to the original image frame
190        data["boxes"] = amg_utils.uncrop_boxes_xyxy(data["boxes"], crop_box)
191        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
192        # the data from embedding based segmentation doesn't have the points
193        # so we skip if the corresponding key can't be found
194        try:
195            data["points"] = amg_utils.uncrop_points(data["points"], crop_box)
196        except KeyError:
197            pass
198
199        return data
200
201    def _postprocess_small_regions(self, mask_data, min_area, nms_thresh):
202
203        if len(mask_data["rles"]) == 0:
204            return mask_data
205
206        # filter small disconnected regions and holes
207        new_masks = []
208        scores = []
209        for rle in mask_data["rles"]:
210            mask = amg_utils.rle_to_mask(rle)
211
212            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="holes")
213            unchanged = not changed
214            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="islands")
215            unchanged = unchanged and not changed
216
217            new_masks.append(torch.as_tensor(mask, dtype=torch.int).unsqueeze(0))
218            # give score=0 to changed masks and score=1 to unchanged masks
219            # so NMS will prefer ones that didn't need postprocessing
220            scores.append(float(unchanged))
221
222        # recalculate boxes and remove any new duplicates
223        masks = torch.cat(new_masks, dim=0)
224        boxes = batched_mask_to_box(masks.to(torch.bool))  # Casting this to boolean as we work with one-hot labels.
225        keep_by_nms = batched_nms(
226            boxes.float(),
227            torch.as_tensor(scores, dtype=torch.float),
228            torch.zeros_like(boxes[:, 0]),  # categories
229            iou_threshold=nms_thresh,
230        )
231
232        # only recalculate RLEs for masks that have changed
233        for i_mask in keep_by_nms:
234            if scores[i_mask] == 0.0:
235                mask_torch = masks[i_mask].unsqueeze(0)
236                # mask_data["rles"][i_mask] = amg_utils.mask_to_rle_pytorch(mask_torch)[0]
237                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
238                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
239        mask_data.filter(keep_by_nms)
240
241        return mask_data
242
243    def _postprocess_masks(self, mask_data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode):
244        # filter small disconnected regions and holes in masks
245        if min_mask_region_area > 0:
246            mask_data = self._postprocess_small_regions(
247                mask_data,
248                min_mask_region_area,
249                max(box_nms_thresh, crop_nms_thresh),
250            )
251
252        # encode masks
253        if output_mode == "coco_rle":
254            mask_data["segmentations"] = [amg_utils.coco_encode_rle(rle) for rle in mask_data["rles"]]
255        elif output_mode == "binary_mask":
256            mask_data["segmentations"] = [amg_utils.rle_to_mask(rle) for rle in mask_data["rles"]]
257        else:
258            mask_data["segmentations"] = mask_data["rles"]
259
260        # write mask records
261        curr_anns = []
262        for idx in range(len(mask_data["segmentations"])):
263            ann = {
264                "segmentation": mask_data["segmentations"][idx],
265                "area": amg_utils.area_from_rle(mask_data["rles"][idx]),
266                "bbox": amg_utils.box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
267                "predicted_iou": mask_data["iou_preds"][idx].item(),
268                "stability_score": mask_data["stability_score"][idx].item(),
269                "crop_box": amg_utils.box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
270            }
271            # the data from embedding based segmentation doesn't have the points
272            # so we skip if the corresponding key can't be found
273            try:
274                ann["point_coords"] = [mask_data["points"][idx].tolist()]
275            except KeyError:
276                pass
277            curr_anns.append(ann)
278
279        return curr_anns
280
281    def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None):
282        orig_h, orig_w = original_size
283
284        # serialize predictions and store in MaskData
285        data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1))
286        if points is not None:
287            data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0), dtype=torch.float)
288
289        del masks
290
291        # calculate the stability scores
292        data["stability_score"] = amg_utils.calculate_stability_score(
293            data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset
294        )
295
296        # threshold masks and calculate boxes
297        data["masks"] = data["masks"] > self._predictor.model.mask_threshold
298        data["masks"] = data["masks"].type(torch.bool)
299        data["boxes"] = batched_mask_to_box(data["masks"])
300
301        # compress to RLE
302        data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
303        # data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"])
304        data["rles"] = mask_to_rle_pytorch(data["masks"])
305        del data["masks"]
306
307        return data
308
309    def get_state(self) -> Dict[str, Any]:
310        """Get the initialized state of the mask generator.
311
312        Returns:
313            State of the mask generator.
314        """
315        if not self.is_initialized:
316            raise RuntimeError("The state has not been computed yet. Call initialize first.")
317
318        return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}
319
320    def set_state(self, state: Dict[str, Any]) -> None:
321        """Set the state of the mask generator.
322
323        Args:
324            state: The state of the mask generator, e.g. from serialized state.
325        """
326        self._crop_list = state["crop_list"]
327        self._crop_boxes = state["crop_boxes"]
328        self._original_size = state["original_size"]
329        self._is_initialized = True
330
331    def clear_state(self):
332        """Clear the state of the mask generator.
333        """
334        self._crop_list = None
335        self._crop_boxes = None
336        self._original_size = None
337        self._is_initialized = False

Base class for the automatic mask generators.

is_initialized
130    @property
131    def is_initialized(self):
132        """Whether the mask generator has already been initialized.
133        """
134        return self._is_initialized

Whether the mask generator has already been initialized.

crop_list
136    @property
137    def crop_list(self):
138        """The list of mask data after initialization.
139        """
140        return self._crop_list

The list of mask data after initialization.

crop_boxes
142    @property
143    def crop_boxes(self):
144        """The list of crop boxes.
145        """
146        return self._crop_boxes

The list of crop boxes.

original_size
148    @property
149    def original_size(self):
150        """The original image size.
151        """
152        return self._original_size

The original image size.

def get_state(self) -> Dict[str, Any]:
309    def get_state(self) -> Dict[str, Any]:
310        """Get the initialized state of the mask generator.
311
312        Returns:
313            State of the mask generator.
314        """
315        if not self.is_initialized:
316            raise RuntimeError("The state has not been computed yet. Call initialize first.")
317
318        return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}

Get the initialized state of the mask generator.

Returns:

State of the mask generator.

def set_state(self, state: Dict[str, Any]) -> None:
320    def set_state(self, state: Dict[str, Any]) -> None:
321        """Set the state of the mask generator.
322
323        Args:
324            state: The state of the mask generator, e.g. from serialized state.
325        """
326        self._crop_list = state["crop_list"]
327        self._crop_boxes = state["crop_boxes"]
328        self._original_size = state["original_size"]
329        self._is_initialized = True

Set the state of the mask generator.

Arguments:
  • state: The state of the mask generator, e.g. from serialized state.
def clear_state(self):
331    def clear_state(self):
332        """Clear the state of the mask generator.
333        """
334        self._crop_list = None
335        self._crop_boxes = None
336        self._original_size = None
337        self._is_initialized = False

Clear the state of the mask generator.
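
Since `initialize` is the expensive step, the state accessors above can be used to cache and restore it. A rough sketch, assuming `predictor` and an already initialized `amg` from the examples above; micro_sam's own caching utilities may be preferable to pickling by hand.

```python
import pickle

from micro_sam.instance_segmentation import AutomaticMaskGenerator

# Save the initialized state (crop_list, crop_boxes, original_size) ...
state = amg.get_state()
with open("amg_state.pkl", "wb") as f:  # hypothetical cache file
    pickle.dump(state, f)

# ... and restore it later to skip re-running 'initialize'.
amg_restored = AutomaticMaskGenerator(predictor)
with open("amg_state.pkl", "rb") as f:
    amg_restored.set_state(pickle.load(f))
masks = amg_restored.generate()

# clear_state resets the generator; 'generate' would raise again until
# 'initialize' or 'set_state' is called.
amg_restored.clear_state()
```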

class AutomaticMaskGenerator(AMGBase):
340class AutomaticMaskGenerator(AMGBase):
341    """Generates an instance segmentation without prompts, using a point grid.
342
343    This class implements the same logic as
344    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
345    It decouples the computationally expensive steps of generating masks from the cheap post-processing operation
346    to filter these masks to enable grid search and interactively changing the post-processing.
347
348    Use this class as follows:
349    ```python
350    amg = AutomaticMaskGenerator(predictor)
351    amg.initialize(image)  # Initialize the masks, this takes care of all expensive computations.
352    masks = amg.generate(pred_iou_thresh=0.8)  # Generate the masks. This is fast and enables testing parameters
353    ```
354
355    Args:
356        predictor: The segment anything predictor.
357        points_per_side: The number of points to be sampled along one side of the image.
358            If None, `point_grids` must provide explicit point sampling. By default, set to '32'.
359        points_per_batch: The number of points run simultaneously by the model.
360            Higher numbers may be faster but use more GPU memory.
361            By default, selects '64' for all devices except 'mps' (selects '16' for performance reasons).
362        crop_n_layers: If >0, the mask prediction will be run again on crops of the image.
363            By default, set to '0'.
364        crop_overlap_ratio: Sets the degree to which crops overlap. By default, set to '512 / 1500'.
365        crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops.
366            By default, set to '1'.
367        point_grids: A list over explicit grids of points used for sampling masks.
368            Normalized to [0, 1] with respect to the image coordinate system.
369        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
370            By default, set to '1.0'.
371    """
372    def __init__(
373        self,
374        predictor: SamPredictor,
375        points_per_side: Optional[int] = 32,
376        points_per_batch: Optional[int] = None,
377        crop_n_layers: int = 0,
378        crop_overlap_ratio: float = 512 / 1500,
379        crop_n_points_downscale_factor: int = 1,
380        point_grids: Optional[List[np.ndarray]] = None,
381        stability_score_offset: float = 1.0,
382    ):
383        super().__init__()
384
385        if points_per_side is not None:
386            self.point_grids = amg_utils.build_all_layer_point_grids(
387                points_per_side, crop_n_layers, crop_n_points_downscale_factor,
388            )
389        elif point_grids is not None:
390            self.point_grids = point_grids
391        else:
392            raise ValueError("One of 'points_per_side' or 'point_grids' must be provided.")
393
394        self._predictor = predictor
395        self._points_per_side = points_per_side
396
397        # we set the points per batch to 16 for mps for performance reasons
398        # and otherwise keep them at the default of 64
399        if points_per_batch is None:
400            points_per_batch = 16 if str(predictor.device) == "mps" else 64
401        self._points_per_batch = points_per_batch
402
403        self._crop_n_layers = crop_n_layers
404        self._crop_overlap_ratio = crop_overlap_ratio
405        self._crop_n_points_downscale_factor = crop_n_points_downscale_factor
406        self._stability_score_offset = stability_score_offset
407
408    def _process_batch(self, points, im_size, crop_box, original_size):
409        # run model on this batch
410        transformed_points = self._predictor.transform.apply_coords(points, im_size)
411        in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float)
412        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
413        masks, iou_preds, _ = self._predictor.predict_torch(
414            point_coords=in_points[:, None, :],
415            point_labels=in_labels[:, None],
416            multimask_output=True,
417            return_logits=True,
418        )
419        data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points)
420        del masks
421        return data
422
423    def _process_crop(self, image, crop_box, crop_layer_idx, precomputed_embeddings, pbar_init=None, pbar_update=None):
424        # Crop the image and calculate embeddings.
425        x0, y0, x1, y1 = crop_box
426        cropped_im = image[y0:y1, x0:x1, :]
427        cropped_im_size = cropped_im.shape[:2]
428
429        if not precomputed_embeddings:
430            self._predictor.set_image(cropped_im)
431
432        # Get the points for this crop.
433        points_scale = np.array(cropped_im_size)[None, ::-1]
434        points_for_image = self.point_grids[crop_layer_idx] * points_scale
435
436        # Generate masks for this crop in batches.
437        data = amg_utils.MaskData()
438        n_batches = len(points_for_image) // self._points_per_batch +\
439            int(len(points_for_image) % self._points_per_batch != 0)
440        if pbar_init is not None:
441            pbar_init(n_batches, "Predict masks for point grid prompts")
442
443        for (points,) in amg_utils.batch_iterator(self._points_per_batch, points_for_image):
444            batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size)
445            data.cat(batch_data)
446            del batch_data
447            if pbar_update is not None:
448                pbar_update(1)
449
450        if not precomputed_embeddings:
451            self._predictor.reset_image()
452
453        return data
454
455    @torch.no_grad()
456    def initialize(
457        self,
458        image: np.ndarray,
459        image_embeddings: Optional[util.ImageEmbeddings] = None,
460        i: Optional[int] = None,
461        verbose: bool = False,
462        pbar_init: Optional[callable] = None,
463        pbar_update: Optional[callable] = None,
464    ) -> None:
465        """Initialize image embeddings and masks for an image.
466
467        Args:
468            image: The input image, volume or timeseries.
469            image_embeddings: Optional precomputed image embeddings.
470                See `util.precompute_image_embeddings` for details.
471            i: Index for the image data. Required if `image` has three spatial dimensions
472                or a time dimension and two spatial dimensions.
473            verbose: Whether to print computation progress. By default, set to 'False'.
474            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
475            Can be used together with pbar_update to handle the napari progress bar in another thread.
476            This enables using this function within a threadworker.
477            pbar_update: Callback to update an external progress bar.
478        """
479        original_size = image.shape[:2]
480        self._original_size = original_size
481
482        crop_boxes, layer_idxs = amg_utils.generate_crop_boxes(
483            original_size, self._crop_n_layers, self._crop_overlap_ratio
484        )
485
486        # We can set fixed image embeddings if we only have a single crop box (the default setting).
487        # Otherwise we have to recompute the embeddings for each crop and can't precompute.
488        if len(crop_boxes) == 1:
489            if image_embeddings is None:
490                image_embeddings = util.precompute_image_embeddings(self._predictor, image)
491            util.set_precomputed(self._predictor, image_embeddings, i=i)
492            precomputed_embeddings = True
493        else:
494            precomputed_embeddings = False
495
496        # we need to cast to the image representation that is compatible with SAM
497        image = util._to_image(image)
498
499        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
500
501        crop_list = []
502        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
503            crop_data = self._process_crop(
504                image, crop_box, layer_idx,
505                precomputed_embeddings=precomputed_embeddings,
506                pbar_init=pbar_init, pbar_update=pbar_update,
507            )
508            crop_list.append(crop_data)
509        pbar_close()
510
511        self._is_initialized = True
512        self._crop_list = crop_list
513        self._crop_boxes = crop_boxes
514
515    @torch.no_grad()
516    def generate(
517        self,
518        pred_iou_thresh: float = 0.88,
519        stability_score_thresh: float = 0.95,
520        box_nms_thresh: float = 0.7,
521        crop_nms_thresh: float = 0.7,
522        min_mask_region_area: int = 0,
523        output_mode: str = "binary_mask",
524    ) -> List[Dict[str, Any]]:
525        """Generate instance segmentation for the currently initialized image.
526
527        Args:
528            pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
529                By default, set to '0.88'.
530            stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
531                under changes to the cutoff used to binarize the model prediction. By default, set to '0.95'.
532            box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
533                By default, set to '0.7'.
534            crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
535                By default, set to '0.7'.
536            min_mask_region_area: Minimal size for the predicted masks. By default, set to '0'.
537            output_mode: The form masks are returned in. By default, set to 'binary_mask'.
538
539        Returns:
540            The instance segmentation masks.
541        """
542        if not self.is_initialized:
543            raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.")
544
545        data = amg_utils.MaskData()
546        for data_, crop_box in zip(self.crop_list, self.crop_boxes):
547            crop_data = self._postprocess_batch(
548                data=deepcopy(data_),
549                crop_box=crop_box, original_size=self.original_size,
550                pred_iou_thresh=pred_iou_thresh,
551                stability_score_thresh=stability_score_thresh,
552                box_nms_thresh=box_nms_thresh
553            )
554            data.cat(crop_data)
555
556        if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0:
557            # Prefer masks from smaller crops
558            scores = 1 / box_area(data["crop_boxes"])
559            scores = scores.to(data["boxes"].device)
560            keep_by_nms = batched_nms(
561                data["boxes"].float(),
562                scores,
563                torch.zeros_like(data["boxes"][:, 0]),  # categories
564                iou_threshold=crop_nms_thresh,
565            )
566            data.filter(keep_by_nms)
567
568        data.to_numpy()
569        masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode)
570        return masks

Generates an instance segmentation without prompts, using a point grid.

This class implements the same logic as https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py It decouples the computationally expensive steps of generating masks from the cheap post-processing operation to filter these masks to enable grid search and interactively changing the post-processing.

Use this class as follows:

amg = AutomaticMaskGenerator(predictor)
amg.initialize(image)  # Initialize the masks, this takes care of all expensive computations.
masks = amg.generate(pred_iou_thresh=0.8)  # Generate the masks. This is fast and enables testing parameters
Arguments:
  • predictor: The segment anything predictor.
  • points_per_side: The number of points to be sampled along one side of the image. If None, point_grids must provide explicit point sampling. By default, set to '32'.
  • points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory. By default, selects '64' for all devices except 'mps' (selects '16' for performance reasons).
  • crop_n_layers: If >0, the mask prediction will be run again on crops of the image. By default, set to '0'.
  • crop_overlap_ratio: Sets the degree to which crops overlap. By default, set to '512 / 1500'.
  • crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops. By default, set to '1'.
  • point_grids: A list over explicit grids of points used for sampling masks. Normalized to [0, 1] with respect to the image coordinate system.
  • stability_score_offset: The amount to shift the cutoff when calculating the stability score. By default, set to '1.0'.
AutomaticMaskGenerator( predictor: segment_anything.predictor.SamPredictor, points_per_side: Optional[int] = 32, points_per_batch: Optional[int] = None, crop_n_layers: int = 0, crop_overlap_ratio: float = 0.3413333333333333, crop_n_points_downscale_factor: int = 1, point_grids: Optional[List[numpy.ndarray]] = None, stability_score_offset: float = 1.0)
372    def __init__(
373        self,
374        predictor: SamPredictor,
375        points_per_side: Optional[int] = 32,
376        points_per_batch: Optional[int] = None,
377        crop_n_layers: int = 0,
378        crop_overlap_ratio: float = 512 / 1500,
379        crop_n_points_downscale_factor: int = 1,
380        point_grids: Optional[List[np.ndarray]] = None,
381        stability_score_offset: float = 1.0,
382    ):
383        super().__init__()
384
385        if points_per_side is not None:
386            self.point_grids = amg_utils.build_all_layer_point_grids(
387                points_per_side, crop_n_layers, crop_n_points_downscale_factor,
388            )
389        elif point_grids is not None:
390            self.point_grids = point_grids
391        else:
392            raise ValueError("One of 'points_per_side' or 'point_grids' must be provided.")
393
394        self._predictor = predictor
395        self._points_per_side = points_per_side
396
397        # we set the points per batch to 16 for mps for performance reasons
398        # and otherwise keep them at the default of 64
399        if points_per_batch is None:
400            points_per_batch = 16 if str(predictor.device) == "mps" else 64
401        self._points_per_batch = points_per_batch
402
403        self._crop_n_layers = crop_n_layers
404        self._crop_overlap_ratio = crop_overlap_ratio
405        self._crop_n_points_downscale_factor = crop_n_points_downscale_factor
406        self._stability_score_offset = stability_score_offset
@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, verbose: bool = False, pbar_init: Optional[callable] = None, pbar_update: Optional[callable] = None) -> None:
455    @torch.no_grad()
456    def initialize(
457        self,
458        image: np.ndarray,
459        image_embeddings: Optional[util.ImageEmbeddings] = None,
460        i: Optional[int] = None,
461        verbose: bool = False,
462        pbar_init: Optional[callable] = None,
463        pbar_update: Optional[callable] = None,
464    ) -> None:
465        """Initialize image embeddings and masks for an image.
466
467        Args:
468            image: The input image, volume or timeseries.
469            image_embeddings: Optional precomputed image embeddings.
470                See `util.precompute_image_embeddings` for details.
471            i: Index for the image data. Required if `image` has three spatial dimensions
472                or a time dimension and two spatial dimensions.
473            verbose: Whether to print computation progress. By default, set to 'False'.
474            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
475                Can be used together with pbar_update to handle the napari progress bar in another thread.
476                This enables using this function within a threadworker.
477            pbar_update: Callback to update an external progress bar.
478        """
479        original_size = image.shape[:2]
480        self._original_size = original_size
481
482        crop_boxes, layer_idxs = amg_utils.generate_crop_boxes(
483            original_size, self._crop_n_layers, self._crop_overlap_ratio
484        )
485
486        # We can set fixed image embeddings if we only have a single crop box (the default setting).
487        # Otherwise we have to recompute the embeddings for each crop and can't precompute.
488        if len(crop_boxes) == 1:
489            if image_embeddings is None:
490                image_embeddings = util.precompute_image_embeddings(self._predictor, image)
491            util.set_precomputed(self._predictor, image_embeddings, i=i)
492            precomputed_embeddings = True
493        else:
494            precomputed_embeddings = False
495
496        # we need to cast to the image representation that is compatible with SAM
497        image = util._to_image(image)
498
499        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
500
501        crop_list = []
502        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
503            crop_data = self._process_crop(
504                image, crop_box, layer_idx,
505                precomputed_embeddings=precomputed_embeddings,
506                pbar_init=pbar_init, pbar_update=pbar_update,
507            )
508            crop_list.append(crop_data)
509        pbar_close()
510
511        self._is_initialized = True
512        self._crop_list = crop_list
513        self._crop_boxes = crop_boxes

Initialize image embeddings and masks for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • verbose: Whether to print computation progress. By default, set to 'False'.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
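
The `image_embeddings` argument above allows reusing embeddings that were computed beforehand (and possibly cached). A small sketch, assuming `util.get_sam_model` and a placeholder image:

```python
import numpy as np
from micro_sam import util
from micro_sam.instance_segmentation import AutomaticMaskGenerator

image = np.random.randint(0, 255, size=(512, 512), dtype="uint8")  # placeholder image
predictor = util.get_sam_model(model_type="vit_b")  # assumption about micro_sam.util

# Compute the embeddings once, e.g. to share them between interactive
# annotation and automatic segmentation ...
embeddings = util.precompute_image_embeddings(predictor, image)

# ... and pass them to 'initialize' so they are not recomputed.
amg = AutomaticMaskGenerator(predictor)
amg.initialize(image, image_embeddings=embeddings, verbose=True)
```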
@torch.no_grad()
def generate( self, pred_iou_thresh: float = 0.88, stability_score_thresh: float = 0.95, box_nms_thresh: float = 0.7, crop_nms_thresh: float = 0.7, min_mask_region_area: int = 0, output_mode: str = 'binary_mask') -> List[Dict[str, Any]]:
515    @torch.no_grad()
516    def generate(
517        self,
518        pred_iou_thresh: float = 0.88,
519        stability_score_thresh: float = 0.95,
520        box_nms_thresh: float = 0.7,
521        crop_nms_thresh: float = 0.7,
522        min_mask_region_area: int = 0,
523        output_mode: str = "binary_mask",
524    ) -> List[Dict[str, Any]]:
525        """Generate instance segmentation for the currently initialized image.
526
527        Args:
528            pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
529                By default, set to '0.88'.
530            stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
531                under changes to the cutoff used to binarize the model prediction. By default, set to '0.95'.
532            box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
533                By default, set to '0.7'.
534            crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
535                By default, set to '0.7'.
536            min_mask_region_area: Minimal size for the predicted masks. By default, set to '0'.
537            output_mode: The form masks are returned in. By default, set to 'binary_mask'.
538
539        Returns:
540            The instance segmentation masks.
541        """
542        if not self.is_initialized:
543            raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.")
544
545        data = amg_utils.MaskData()
546        for data_, crop_box in zip(self.crop_list, self.crop_boxes):
547            crop_data = self._postprocess_batch(
548                data=deepcopy(data_),
549                crop_box=crop_box, original_size=self.original_size,
550                pred_iou_thresh=pred_iou_thresh,
551                stability_score_thresh=stability_score_thresh,
552                box_nms_thresh=box_nms_thresh
553            )
554            data.cat(crop_data)
555
556        if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0:
557            # Prefer masks from smaller crops
558            scores = 1 / box_area(data["crop_boxes"])
559            scores = scores.to(data["boxes"].device)
560            keep_by_nms = batched_nms(
561                data["boxes"].float(),
562                scores,
563                torch.zeros_like(data["boxes"][:, 0]),  # categories
564                iou_threshold=crop_nms_thresh,
565            )
566            data.filter(keep_by_nms)
567
568        data.to_numpy()
569        masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode)
570        return masks

Generate instance segmentation for the currently initialized image.

Arguments:
  • pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model. By default, set to '0.88'.
  • stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask under changes to the cutoff used to binarize the model prediction. By default, set to '0.95'.
  • box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks. By default, set to '0.7'.
  • crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops. By default, set to '0.7'.
  • min_mask_region_area: Minimal size for the predicted masks. By default, set to '0'.
  • output_mode: The form masks are returned in. By default, set to 'binary_mask'.
Returns:

The instance segmentation masks.
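
Because all expensive computation happens in `initialize`, different post-processing settings can be explored cheaply on the same state. A short sketch, assuming `amg` has been initialized as in the examples above:

```python
# Stricter thresholds keep fewer, higher-confidence masks.
for iou_thresh in (0.75, 0.85, 0.88):
    masks = amg.generate(pred_iou_thresh=iou_thresh, stability_score_thresh=0.9)
    print(f"pred_iou_thresh={iou_thresh}: {len(masks)} masks")
```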

class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
600class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
601    """Generates an instance segmentation without prompts, using a point grid.
602
603    Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings.
604
605    Args:
606        predictor: The Segment Anything predictor.
607        points_per_side: The number of points to be sampled along one side of the image.
608            If None, `point_grids` must provide explicit point sampling. By default, set to '32'.
609        points_per_batch: The number of points run simultaneously by the model.
610            Higher numbers may be faster but use more GPU memory. By default, set to '64'.
611        point_grids: A list over explicit grids of points used for sampling masks.
612            Normalized to [0, 1] with respect to the image coordinate system.
613        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
614            By default, set to '1.0'.
615    """
616
617    # We only expose the arguments that make sense for the tiled mask generator.
618    # Anything related to crops doesn't make sense, because we re-use that functionality
619    # for tiling, so these parameters wouldn't have any effect.
620    def __init__(
621        self,
622        predictor: SamPredictor,
623        points_per_side: Optional[int] = 32,
624        points_per_batch: int = 64,
625        point_grids: Optional[List[np.ndarray]] = None,
626        stability_score_offset: float = 1.0,
627    ) -> None:
628        super().__init__(
629            predictor=predictor,
630            points_per_side=points_per_side,
631            points_per_batch=points_per_batch,
632            point_grids=point_grids,
633            stability_score_offset=stability_score_offset,
634        )
635
636    @torch.no_grad()
637    def initialize(
638        self,
639        image: np.ndarray,
640        image_embeddings: Optional[util.ImageEmbeddings] = None,
641        i: Optional[int] = None,
642        tile_shape: Optional[Tuple[int, int]] = None,
643        halo: Optional[Tuple[int, int]] = None,
644        verbose: bool = False,
645        pbar_init: Optional[callable] = None,
646        pbar_update: Optional[callable] = None,
647        batch_size: int = 1,
648    ) -> None:
649        """Initialize image embeddings and masks for an image.
650
651        Args:
652            image: The input image, volume or timeseries.
653            image_embeddings: Optional precomputed image embeddings.
654                See `util.precompute_image_embeddings` for details.
655            i: Index for the image data. Required if `image` has three spatial dimensions
656                or a time dimension and two spatial dimensions.
657            tile_shape: The tile shape for embedding prediction.
658        halo: The overlap between tiles.
659            verbose: Whether to print computation progress. By default, set to 'False'.
660            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
661                Can be used together with pbar_update to handle the napari progress bar in another thread.
662                This enables using this function within a threadworker.
663            pbar_update: Callback to update an external progress bar.
664            batch_size: The batch size for image embedding prediction. By default, set to '1'.
665        """
666        original_size = image.shape[:2]
667        self._original_size = original_size
668
669        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
670            self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, batch_size=batch_size
671        )
672
673        tiling = blocking([0, 0], original_size, tile_shape)
674        n_tiles = tiling.numberOfBlocks
675
676        # The crop box is always the full local tile.
677        tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)]
678        crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles]
679
680        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
681        pbar_init(n_tiles, "Compute masks for tile")
682
683        # We need to cast to the image representation that is compatible with SAM.
684        image = util._to_image(image)
685
686        mask_data = []
687        for tile_id in range(n_tiles):
688            # set the pre-computed embeddings for this tile
689            features = image_embeddings["features"][str(tile_id)]
690            tile_embeddings = {
691                "features": features,
692                "input_size": features.attrs["input_size"],
693                "original_size": features.attrs["original_size"],
694            }
695            util.set_precomputed(self._predictor, tile_embeddings, i)
696
697            # compute the mask data for this tile and append it
698            this_mask_data = self._process_crop(
699                image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True
700            )
701            mask_data.append(this_mask_data)
702            pbar_update(1)
703        pbar_close()
704
705        # set the initialized data
706        self._is_initialized = True
707        self._crop_list = mask_data
708        self._crop_boxes = crop_boxes

Generates an instance segmentation without prompts, using a point grid.

Implements the same functionality as AutomaticMaskGenerator but for tiled embeddings.

Arguments:
  • predictor: The Segment Anything predictor.
  • points_per_side: The number of points to be sampled along one side of the image. If None, point_grids must provide explicit point sampling. By default, set to '32'.
  • points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory. By default, set to '64'.
  • point_grids: A list over explicit grids of points used for sampling masks. Normalized to [0, 1] with respect to the image coordinate system.
  • stability_score_offset: The amount to shift the cutoff when calculating the stability score. By default, set to '1.0'.
TiledAutomaticMaskGenerator( predictor: segment_anything.predictor.SamPredictor, points_per_side: Optional[int] = 32, points_per_batch: int = 64, point_grids: Optional[List[numpy.ndarray]] = None, stability_score_offset: float = 1.0)
620    def __init__(
621        self,
622        predictor: SamPredictor,
623        points_per_side: Optional[int] = 32,
624        points_per_batch: int = 64,
625        point_grids: Optional[List[np.ndarray]] = None,
626        stability_score_offset: float = 1.0,
627    ) -> None:
628        super().__init__(
629            predictor=predictor,
630            points_per_side=points_per_side,
631            points_per_batch=points_per_batch,
632            point_grids=point_grids,
633            stability_score_offset=stability_score_offset,
634        )
@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = False, pbar_init: Optional[callable] = None, pbar_update: Optional[callable] = None, batch_size: int = 1) -> None:
636    @torch.no_grad()
637    def initialize(
638        self,
639        image: np.ndarray,
640        image_embeddings: Optional[util.ImageEmbeddings] = None,
641        i: Optional[int] = None,
642        tile_shape: Optional[Tuple[int, int]] = None,
643        halo: Optional[Tuple[int, int]] = None,
644        verbose: bool = False,
645        pbar_init: Optional[callable] = None,
646        pbar_update: Optional[callable] = None,
647        batch_size: int = 1,
648    ) -> None:
649        """Initialize image embeddings and masks for an image.
650
651        Args:
652            image: The input image, volume or timeseries.
653            image_embeddings: Optional precomputed image embeddings.
654                See `util.precompute_image_embeddings` for details.
655            i: Index for the image data. Required if `image` has three spatial dimensions
656                or a time dimension and two spatial dimensions.
657            tile_shape: The tile shape for embedding prediction.
658            halo: The overlap between tiles.
659            verbose: Whether to print computation progress. By default, set to 'False'.
660            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
661                Can be used together with pbar_update to handle the napari progress bar in another thread.
662                This enables using this function within a threadworker.
663            pbar_update: Callback to update an external progress bar.
664            batch_size: The batch size for image embedding prediction. By default, set to '1'.
665        """
666        original_size = image.shape[:2]
667        self._original_size = original_size
668
669        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
670            self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, batch_size=batch_size
671        )
672
673        tiling = blocking([0, 0], original_size, tile_shape)
674        n_tiles = tiling.numberOfBlocks
675
676        # The crop box is always the full local tile.
677        tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)]
678        crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles]
679
680        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
681        pbar_init(n_tiles, "Compute masks for tile")
682
683        # We need to cast to the image representation that is compatible with SAM.
684        image = util._to_image(image)
685
686        mask_data = []
687        for tile_id in range(n_tiles):
688            # set the pre-computed embeddings for this tile
689            features = image_embeddings["features"][str(tile_id)]
690            tile_embeddings = {
691                "features": features,
692                "input_size": features.attrs["input_size"],
693                "original_size": features.attrs["original_size"],
694            }
695            util.set_precomputed(self._predictor, tile_embeddings, i)
696
697            # compute the mask data for this tile and append it
698            this_mask_data = self._process_crop(
699                image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True
700            )
701            mask_data.append(this_mask_data)
702            pbar_update(1)
703        pbar_close()
704
705        # set the initialized data
706        self._is_initialized = True
707        self._crop_list = mask_data
708        self._crop_boxes = crop_boxes

Initialize image embeddings and masks for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • tile_shape: The tile shape for embedding prediction.
  • halo: The overlap between tiles.
  • verbose: Whether to print computation progress. By default, set to 'False'.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
  • batch_size: The batch size for image embedding prediction. By default, set to '1'.
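
A minimal sketch of tiled automatic segmentation. The image size, `tile_shape` and `halo` are illustrative values, and `util.get_sam_model` is assumed as in the examples above.

```python
import numpy as np
from micro_sam import util
from micro_sam.instance_segmentation import TiledAutomaticMaskGenerator, mask_data_to_segmentation

large_image = np.random.randint(0, 255, size=(2048, 2048), dtype="uint8")  # placeholder image
predictor = util.get_sam_model(model_type="vit_b")  # assumption about micro_sam.util

amg = TiledAutomaticMaskGenerator(predictor)
# Embeddings are computed per tile; the halo controls the overlap between neighboring tiles.
amg.initialize(large_image, tile_shape=(1024, 1024), halo=(256, 256), verbose=True)
masks = amg.generate()
segmentation = mask_data_to_segmentation(masks, with_background=True)
```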
class DecoderAdapter(torch.nn.modules.module.Module):
716class DecoderAdapter(torch.nn.Module):
717    """Adapter to contain the UNETR decoder in a single module.
718
719    To apply the decoder on top of pre-computed embeddings for the segmentation functionality.
720    See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py
721    """
722    def __init__(self, unetr: torch.nn.Module):
723        super().__init__()
724
725        self.base = unetr.base
726        self.out_conv = unetr.out_conv
727        self.deconv_out = unetr.deconv_out
728        self.decoder_head = unetr.decoder_head
729        self.final_activation = unetr.final_activation
730        self.postprocess_masks = unetr.postprocess_masks
731
732        self.decoder = unetr.decoder
733        self.deconv1 = unetr.deconv1
734        self.deconv2 = unetr.deconv2
735        self.deconv3 = unetr.deconv3
736        self.deconv4 = unetr.deconv4
737
738    def _forward_impl(self, input_):
739        z12 = input_
740
741        z9 = self.deconv1(z12)
742        z6 = self.deconv2(z9)
743        z3 = self.deconv3(z6)
744        z0 = self.deconv4(z3)
745
746        updated_from_encoder = [z9, z6, z3]
747
748        x = self.base(z12)
749        x = self.decoder(x, encoder_inputs=updated_from_encoder)
750        x = self.deconv_out(x)
751
752        x = torch.cat([x, z0], dim=1)
753        x = self.decoder_head(x)
754
755        x = self.out_conv(x)
756        if self.final_activation is not None:
757            x = self.final_activation(x)
758        return x
759
760    def forward(self, input_, input_shape, original_shape):
761        x = self._forward_impl(input_)
762        x = self.postprocess_masks(x, input_shape, original_shape)
763        return x

Adapter to contain the UNETR decoder in a single module.

To apply the decoder on top of pre-computed embeddings for the segmentation functionality. See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py
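
A rough sketch of applying the adapter to precomputed embeddings. The embedding layout of (256, 64, 64) for a 1024 x 1024 input and the use of `util.get_sam_model` are assumptions for illustration; the decoder here is randomly initialized because no decoder state is passed to `get_unetr`.

```python
import torch
from micro_sam import util
from micro_sam.instance_segmentation import DecoderAdapter, get_unetr

predictor = util.get_sam_model(model_type="vit_b")  # assumption about micro_sam.util
unetr = get_unetr(image_encoder=predictor.model.image_encoder, device="cpu")
decoder = DecoderAdapter(unetr)

embeddings = torch.randn(1, 256, 64, 64)  # stands in for precomputed SAM image embeddings
with torch.no_grad():
    output = decoder(embeddings, input_shape=(1024, 1024), original_shape=(768, 1024))
print(output.shape)  # expected (1, 3, 768, 1024): foreground, center and boundary distances
```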

DecoderAdapter(unetr: torch.nn.modules.module.Module)
722    def __init__(self, unetr: torch.nn.Module):
723        super().__init__()
724
725        self.base = unetr.base
726        self.out_conv = unetr.out_conv
727        self.deconv_out = unetr.deconv_out
728        self.decoder_head = unetr.decoder_head
729        self.final_activation = unetr.final_activation
730        self.postprocess_masks = unetr.postprocess_masks
731
732        self.decoder = unetr.decoder
733        self.deconv1 = unetr.deconv1
734        self.deconv2 = unetr.deconv2
735        self.deconv3 = unetr.deconv3
736        self.deconv4 = unetr.deconv4

Initialize internal Module state, shared by both nn.Module and ScriptModule.

base
out_conv
deconv_out
decoder_head
final_activation
postprocess_masks
decoder
deconv1
deconv2
deconv3
deconv4
def forward(self, input_, input_shape, original_shape):
760    def forward(self, input_, input_shape, original_shape):
761        x = self._forward_impl(input_)
762        x = self.postprocess_masks(x, input_shape, original_shape)
763        return x

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

def get_unetr( image_encoder: torch.nn.modules.module.Module, decoder_state: Optional[collections.OrderedDict[str, torch.Tensor]] = None, device: Union[str, torch.device, NoneType] = None, out_channels: int = 3, flexible_load_checkpoint: bool = False) -> torch.nn.modules.module.Module:
766def get_unetr(
767    image_encoder: torch.nn.Module,
768    decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None,
769    device: Optional[Union[str, torch.device]] = None,
770    out_channels: int = 3,
771    flexible_load_checkpoint: bool = False,
772) -> torch.nn.Module:
773    """Get UNETR model for automatic instance segmentation.
774
775    Args:
776        image_encoder: The image encoder of the SAM model.
777            This is used as encoder by the UNETR too.
778        decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
779        device: The device. By default, automatically chooses the best available device.
780        out_channels: The number of output channels. By default, set to '3'.
781        flexible_load_checkpoint: Whether to allow reinitialization of parameters
782            which could not be found in the provided decoder state. By default, set to 'False'.
783
784    Returns:
785        The UNETR model.
786    """
787    device = util.get_device(device)
788
789    if decoder_state is None:
790        use_conv_transpose = False  # By default, we use interpolation for upsampling.
791    else:
792        # From the provided pretrained 'decoder_state', we check whether it uses transposed convolutions.
793        # NOTE: Explanation of the logic below:
794        # - We do this by looking for parameter names that contain '.block.' within the "decoder.samplers"
795        #   submodules. This naming convention indicates that transposed convolutions are used,
796        #   wrapped inside a custom block module.
797        # - Otherwise '.conv.' appears. It indicates a standard `Conv2d` applied after interpolation for upsampling.
798        use_conv_transpose = any(".block." in k for k in decoder_state.keys() if k.startswith("decoder.samplers"))
799
800    unetr = UNETR(
801        backbone="sam",
802        encoder=image_encoder,
803        out_channels=out_channels,
804        use_sam_stats=True,
805        final_activation="Sigmoid",
806        use_skip_connection=False,
807        resize_input=True,
808        use_conv_transpose=use_conv_transpose,
809
810    )
811
812    if decoder_state is not None:
813        unetr_state_dict = unetr.state_dict()
814        for k, v in unetr_state_dict.items():
815            if not k.startswith("encoder"):
816                if flexible_load_checkpoint:  # Allow re-initialization of parameters not found in the decoder state.
817                    if k in decoder_state:  # First check whether the key is available in the provided decoder state.
818                        unetr_state_dict[k] = decoder_state[k]
819                    else:  # Otherwise, keep the freshly initialized parameters.
820                        warnings.warn(f"Could not find '{k}' in the pretrained state dict. Hence, we reinitialize it.")
821                        unetr_state_dict[k] = v
822
823                else:  # Be strict: every non-encoder parameter must be present in the decoder state.
824                    if k not in decoder_state:
825                        raise RuntimeError(f"The parameters for '{k}' could not be found.")
826                    unetr_state_dict[k] = decoder_state[k]
827
828        unetr.load_state_dict(unetr_state_dict)
829
830    unetr.to(device)
831    return unetr

Get UNETR model for automatic instance segmentation.

Arguments:
  • image_encoder: The image encoder of the SAM model. This is used as encoder by the UNETR too.
  • decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
  • device: The device. By default, automatically chooses the best available device.
  • out_channels: The number of output channels. By default, set to '3'.
  • flexible_load_checkpoint: Whether to allow reinitialization of parameters which could not be found in the provided decoder state. By default, set to 'False'.
Returns:

The UNETR model.
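
For orientation, here is a minimal usage sketch of `get_unetr` on top of a pretrained SAM image encoder. It assumes `micro_sam.util.get_sam_model` for loading the model; whether the returned state contains a `"decoder_state"` entry depends on the chosen model or checkpoint.

```python
# Minimal sketch: build the UNETR on top of a pretrained SAM image encoder.
# Assumes micro_sam.util.get_sam_model; whether the returned state contains a
# "decoder_state" entry depends on the chosen model / checkpoint.
from micro_sam import util
from micro_sam.instance_segmentation import get_unetr

predictor, state = util.get_sam_model(model_type="vit_b_lm", return_state=True)

# Without a decoder state the UNETR decoder is randomly initialized.
unetr = get_unetr(image_encoder=predictor.model.image_encoder, out_channels=3)

# With a decoder state the decoder weights are loaded from it.
if "decoder_state" in state:
    unetr = get_unetr(predictor.model.image_encoder, decoder_state=state["decoder_state"])
```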

def get_decoder( image_encoder: torch.nn.modules.module.Module, decoder_state: collections.OrderedDict[str, torch.Tensor], device: Union[str, torch.device, NoneType] = None) -> DecoderAdapter:
834def get_decoder(
835    image_encoder: torch.nn.Module,
836    decoder_state: OrderedDict[str, torch.Tensor],
837    device: Optional[Union[str, torch.device]] = None,
838) -> DecoderAdapter:
839    """Get the decoder to predict outputs for automatic instance segmentation.
840
841    Args:
842        image_encoder: The image encoder of the SAM model.
843        decoder_state: State to initialize the weights of the UNETR decoder.
844        device: The device. By default, automatically chooses the best available device.
845
846    Returns:
847        The decoder for instance segmentation.
848    """
849    unetr = get_unetr(image_encoder, decoder_state, device)
850    return DecoderAdapter(unetr)

Get the decoder to predict outputs for automatic instance segmentation.

Arguments:
  • image_encoder: The image encoder of the SAM model.
  • decoder_state: State to initialize the weights of the UNETR decoder.
  • device: The device. By default, automatically chooses the best available device.
Returns:

The decoder for instance segmentation.

def get_predictor_and_decoder( model_type: str, checkpoint_path: Union[str, os.PathLike], device: Union[str, torch.device, NoneType] = None, peft_kwargs: Optional[Dict] = None) -> Tuple[segment_anything.predictor.SamPredictor, DecoderAdapter]:
853def get_predictor_and_decoder(
854    model_type: str,
855    checkpoint_path: Union[str, os.PathLike],
856    device: Optional[Union[str, torch.device]] = None,
857    peft_kwargs: Optional[Dict] = None,
858) -> Tuple[SamPredictor, DecoderAdapter]:
859    """Load the SAM model (predictor) and instance segmentation decoder.
860
861    This requires a checkpoint that contains the state for both predictor
862    and decoder.
863
864    Args:
865        model_type: The type of the image encoder used in the SAM model.
866        checkpoint_path: Path to the checkpoint from which to load the data.
867        device: The device. By default, automatically chooses the best available device.
868        peft_kwargs: Keyword arguments for the PEFT wrapper class.
869
870    Returns:
871        The SAM predictor.
872        The decoder for instance segmentation.
873    """
874    device = util.get_device(device)
875    predictor, state = util.get_sam_model(
876        model_type=model_type,
877        checkpoint_path=checkpoint_path,
878        device=device,
879        return_state=True,
880        peft_kwargs=peft_kwargs,
881    )
882
883    if "decoder_state" not in state:
884        raise ValueError(
885            f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state"
886        )
887
888    decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device)
889    return predictor, decoder

Load the SAM model (predictor) and instance segmentation decoder.

This requires a checkpoint that contains the state for both predictor and decoder.

Arguments:
  • model_type: The type of the image encoder used in the SAM model.
  • checkpoint_path: Path to the checkpoint from which to load the data.
  • device: The device. By default, automatically chooses the best available device.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class.
Returns:

The SAM predictor. The decoder for instance segmentation.
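
A minimal usage sketch follows; the model type and checkpoint path are placeholders, and the checkpoint must contain a `"decoder_state"` entry (otherwise a ValueError is raised). The returned objects are passed to `InstanceSegmentationWithDecoder`, documented below.

```python
# Minimal sketch: load predictor and decoder from a finetuned micro_sam checkpoint.
# Model type and checkpoint path are placeholders.
from micro_sam.instance_segmentation import get_predictor_and_decoder

predictor, decoder = get_predictor_and_decoder(
    model_type="vit_b", checkpoint_path="path/to/finetuned_checkpoint.pt"
)
```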

class InstanceSegmentationWithDecoder:
 941class InstanceSegmentationWithDecoder:
 942    """Generates an instance segmentation without prompts, using a decoder.
 943
 944    Implements the same interface as `AutomaticMaskGenerator`.
 945
 946    Use this class as follows:
 947    ```python
 948    segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
 949    segmenter.initialize(image)   # Predict the image embeddings and decoder outputs.
 950    masks = segmenter.generate(center_distance_threshold=0.75)  # Generate the instance segmentation.
 951    ```
 952
 953    Args:
 954        predictor: The segment anything predictor.
 955        decoder: The decoder to predict intermediate representations
 956            for instance segmentation.
 957    """
 958    def __init__(self, predictor: SamPredictor, decoder: torch.nn.Module) -> None:
 959        self._predictor = predictor
 960        self._decoder = decoder
 961
 962        # The decoder outputs.
 963        self._foreground = None
 964        self._center_distances = None
 965        self._boundary_distances = None
 966
 967        self._is_initialized = False
 968
 969    @property
 970    def is_initialized(self):
 971        """Whether the mask generator has already been initialized.
 972        """
 973        return self._is_initialized
 974
 975    @torch.no_grad()
 976    def initialize(
 977        self,
 978        image: np.ndarray,
 979        image_embeddings: Optional[util.ImageEmbeddings] = None,
 980        i: Optional[int] = None,
 981        verbose: bool = False,
 982        pbar_init: Optional[callable] = None,
 983        pbar_update: Optional[callable] = None,
 984    ) -> None:
 985        """Initialize image embeddings and decoder predictions for an image.
 986
 987        Args:
 988            image: The input image, volume or timeseries.
 989            image_embeddings: Optional precomputed image embeddings.
 990                See `util.precompute_image_embeddings` for details.
 991            i: Index for the image data. Required if `image` has three spatial dimensions
 992                or a time dimension and two spatial dimensions.
 993            verbose: Whether to be verbose. By default, set to 'False'.
 994            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 995                Can be used together with pbar_update to handle the napari progress bar in another thread.
 996                This enables using this function within a threadworker.
 997            pbar_update: Callback to update an external progress bar.
 998        """
 999        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
1000        pbar_init(1, "Initialize instance segmentation with decoder")
1001
1002        if image_embeddings is None:
1003            image_embeddings = util.precompute_image_embeddings(self._predictor, image)
1004
1005        # Get the image embeddings from the predictor.
1006        self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i)
1007        embeddings = self._predictor.features
1008        input_shape = tuple(self._predictor.input_size)
1009        original_shape = tuple(self._predictor.original_size)
1010
1011        # Run prediction with the UNETR decoder.
1012        output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
1013        assert output.shape[0] == 3, f"{output.shape}"
1014        pbar_update(1)
1015        pbar_close()
1016
1017        # Set the state.
1018        self._foreground = output[0]
1019        self._center_distances = output[1]
1020        self._boundary_distances = output[2]
1021        self._is_initialized = True
1022
1023    def _to_masks(self, segmentation, output_mode):
1024        if output_mode != "binary_mask":
1025            raise NotImplementedError
1026
1027        props = regionprops(segmentation)
1028        ndim = segmentation.ndim
1029        assert ndim in (2, 3)
1030
1031        shape = segmentation.shape
1032        if ndim == 2:
1033            crop_box = [0, shape[1], 0, shape[0]]
1034        else:
1035            crop_box = [0, shape[2], 0, shape[1], 0, shape[0]]
1036
1037        # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h]
1038        def to_bbox_2d(bbox):
1039            y0, x0 = bbox[0], bbox[1]
1040            w = bbox[3] - x0
1041            h = bbox[2] - y0
1042            return [x0, w, y0, h]
1043
1044        def to_bbox_3d(bbox):
1045            z0, y0, x0 = bbox[0], bbox[1], bbox[2]
1046            w = bbox[5] - x0
1047            h = bbox[4] - y0
1048            d = bbox[3] - z0  # The depth is relative to z0.
1049            return [x0, w, y0, h, z0, d]
1050
1051        to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d
1052        masks = [
1053            {
1054                "segmentation": segmentation == prop.label,
1055                "area": prop.area,
1056                "bbox": to_bbox(prop.bbox),
1057                "crop_box": crop_box,
1058                "seg_id": prop.label,
1059            } for prop in props
1060        ]
1061        return masks
1062
1063    def generate(
1064        self,
1065        center_distance_threshold: float = 0.5,
1066        boundary_distance_threshold: float = 0.5,
1067        foreground_threshold: float = 0.5,
1068        foreground_smoothing: float = 1.0,
1069        distance_smoothing: float = 1.6,
1070        min_size: int = 0,
1071        output_mode: Optional[str] = "binary_mask",
1072        tile_shape: Optional[Tuple[int, int]] = None,
1073        halo: Optional[Tuple[int, int]] = None,
1074        n_threads: Optional[int] = None,
1075    ) -> List[Dict[str, Any]]:
1076        """Generate instance segmentation for the currently initialized image.
1077
1078        Args:
1079            center_distance_threshold: Center distance predictions below this value will be
1080                used to find seeds (intersected with thresholded boundary distance predictions).
1081                By default, set to '0.5'.
1082            boundary_distance_threshold: Boundary distance predictions below this value will be
1083                used to find seeds (intersected with thresholded center distance predictions).
1084                By default, set to '0.5'.
1085            foreground_threshold: Foreground predictions above this value will be used as foreground mask.
1086                By default, set to '0.5'.
1087            foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid
1088                checkerboard artifacts in the prediction. By default, set to '1.0'.
1089            distance_smoothing: Sigma value for smoothing the distance predictions.
1090            min_size: Minimal object size in the segmentation result. By default, set to '0'.
1091            output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
1092                By default, set to 'binary_mask'.
1093            tile_shape: Tile shape for parallelizing the instance segmentation post-processing.
1094                This parameter is independent from the tile shape for computing the embeddings.
1095                If not given then post-processing will not be parallelized.
1096            halo: Halo for parallel post-processing. See also `tile_shape`.
1097            n_threads: Number of threads for parallel post-processing. See also `tile_shape`.
1098
1099        Returns:
1100            The instance segmentation masks.
1101        """
1102        if not self.is_initialized:
1103            raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.")
1104
1105        if foreground_smoothing > 0:
1106            foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing)
1107        else:
1108            foreground = self._foreground
1109
1110        if tile_shape is None:
1111            segmentation = watershed_from_center_and_boundary_distances(
1112                center_distances=self._center_distances,
1113                boundary_distances=self._boundary_distances,
1114                foreground_map=foreground,
1115                center_distance_threshold=center_distance_threshold,
1116                boundary_distance_threshold=boundary_distance_threshold,
1117                foreground_threshold=foreground_threshold,
1118                distance_smoothing=distance_smoothing,
1119                min_size=min_size,
1120            )
1121        else:
1122            if halo is None:
1123                raise ValueError("You must pass a value for halo if tile_shape is given.")
1124            segmentation = _watershed_from_center_and_boundary_distances_parallel(
1125                center_distances=self._center_distances,
1126                boundary_distances=self._boundary_distances,
1127                foreground_map=foreground,
1128                center_distance_threshold=center_distance_threshold,
1129                boundary_distance_threshold=boundary_distance_threshold,
1130                foreground_threshold=foreground_threshold,
1131                distance_smoothing=distance_smoothing,
1132                min_size=min_size,
1133                tile_shape=tile_shape,
1134                halo=halo,
1135                n_threads=n_threads,
1136                verbose=False,
1137            )
1138
1139        if output_mode is not None:
1140            segmentation = self._to_masks(segmentation, output_mode)
1141        return segmentation
1142
1143    def get_state(self) -> Dict[str, Any]:
1144        """Get the initialized state of the instance segmenter.
1145
1146        Returns:
1147            Instance segmentation state.
1148        """
1149        if not self.is_initialized:
1150            raise RuntimeError("The state has not been computed yet. Call initialize first.")
1151
1152        return {
1153            "foreground": self._foreground,
1154            "center_distances": self._center_distances,
1155            "boundary_distances": self._boundary_distances,
1156        }
1157
1158    def set_state(self, state: Dict[str, Any]) -> None:
1159        """Set the state of the instance segmenter.
1160
1161        Args:
1162            state: The instance segmentation state
1163        """
1164        self._foreground = state["foreground"]
1165        self._center_distances = state["center_distances"]
1166        self._boundary_distances = state["boundary_distances"]
1167        self._is_initialized = True
1168
1169    def clear_state(self):
1170        """Clear the state of the instance segmenter.
1171        """
1172        self._foreground = None
1173        self._center_distances = None
1174        self._boundary_distances = None
1175        self._is_initialized = False

Generates an instance segmentation without prompts, using a decoder.

Implements the same interface as AutomaticMaskGenerator.

Use this class as follows:

segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image)   # Predict the image embeddings and decoder outputs.
masks = segmenter.generate(center_distance_threshold=0.75)  # Generate the instance segmentation.
Arguments:
  • predictor: The segment anything predictor.
  • decoder: The decoder to predict intermediate representations for instance segmentation.
InstanceSegmentationWithDecoder( predictor: segment_anything.predictor.SamPredictor, decoder: torch.nn.modules.module.Module)
958    def __init__(self, predictor: SamPredictor, decoder: torch.nn.Module) -> None:
959        self._predictor = predictor
960        self._decoder = decoder
961
962        # The decoder outputs.
963        self._foreground = None
964        self._center_distances = None
965        self._boundary_distances = None
966
967        self._is_initialized = False
is_initialized
969    @property
970    def is_initialized(self):
971        """Whether the mask generator has already been initialized.
972        """
973        return self._is_initialized

Whether the mask generator has already been initialized.

@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, verbose: bool = False, pbar_init: Optional[<built-in function callable>] = None, pbar_update: Optional[<built-in function callable>] = None) -> None:
 975    @torch.no_grad()
 976    def initialize(
 977        self,
 978        image: np.ndarray,
 979        image_embeddings: Optional[util.ImageEmbeddings] = None,
 980        i: Optional[int] = None,
 981        verbose: bool = False,
 982        pbar_init: Optional[callable] = None,
 983        pbar_update: Optional[callable] = None,
 984    ) -> None:
 985        """Initialize image embeddings and decoder predictions for an image.
 986
 987        Args:
 988            image: The input image, volume or timeseries.
 989            image_embeddings: Optional precomputed image embeddings.
 990                See `util.precompute_image_embeddings` for details.
 991            i: Index for the image data. Required if `image` has three spatial dimensions
 992                or a time dimension and two spatial dimensions.
 993            verbose: Whether to be verbose. By default, set to 'False'.
 994            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 995                Can be used together with pbar_update to handle the napari progress bar in another thread.
 996                This enables using this function within a threadworker.
 997            pbar_update: Callback to update an external progress bar.
 998        """
 999        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
1000        pbar_init(1, "Initialize instance segmentation with decoder")
1001
1002        if image_embeddings is None:
1003            image_embeddings = util.precompute_image_embeddings(self._predictor, image)
1004
1005        # Get the image embeddings from the predictor.
1006        self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i)
1007        embeddings = self._predictor.features
1008        input_shape = tuple(self._predictor.input_size)
1009        original_shape = tuple(self._predictor.original_size)
1010
1011        # Run prediction with the UNETR decoder.
1012        output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
1013        assert output.shape[0] == 3, f"{output.shape}"
1014        pbar_update(1)
1015        pbar_close()
1016
1017        # Set the state.
1018        self._foreground = output[0]
1019        self._center_distances = output[1]
1020        self._boundary_distances = output[2]
1021        self._is_initialized = True

Initialize image embeddings and decoder predictions for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • verbose: Whether to be verbose. By default, set to 'False'.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
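
A minimal sketch of initializing with precomputed embeddings; `predictor` and `decoder` are assumed to come from `get_predictor_and_decoder` as above, the image path is a placeholder, and caching the embeddings to a zarr file via `save_path` is an assumption about `util.precompute_image_embeddings`.

```python
# Minimal sketch: initialize the segmenter from precomputed image embeddings.
# `predictor` and `decoder` are assumed to come from get_predictor_and_decoder;
# the save_path argument (caching embeddings to a zarr file) is an assumption.
import imageio.v3 as imageio
from micro_sam import util
from micro_sam.instance_segmentation import InstanceSegmentationWithDecoder

image = imageio.imread("path/to/image.tif")
segmenter = InstanceSegmentationWithDecoder(predictor, decoder)

embeddings = util.precompute_image_embeddings(predictor, image, save_path="embeddings.zarr")
segmenter.initialize(image, image_embeddings=embeddings, verbose=True)
```
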
def generate( self, center_distance_threshold: float = 0.5, boundary_distance_threshold: float = 0.5, foreground_threshold: float = 0.5, foreground_smoothing: float = 1.0, distance_smoothing: float = 1.6, min_size: int = 0, output_mode: Optional[str] = 'binary_mask', tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, n_threads: Optional[int] = None) -> List[Dict[str, Any]]:
1063    def generate(
1064        self,
1065        center_distance_threshold: float = 0.5,
1066        boundary_distance_threshold: float = 0.5,
1067        foreground_threshold: float = 0.5,
1068        foreground_smoothing: float = 1.0,
1069        distance_smoothing: float = 1.6,
1070        min_size: int = 0,
1071        output_mode: Optional[str] = "binary_mask",
1072        tile_shape: Optional[Tuple[int, int]] = None,
1073        halo: Optional[Tuple[int, int]] = None,
1074        n_threads: Optional[int] = None,
1075    ) -> List[Dict[str, Any]]:
1076        """Generate instance segmentation for the currently initialized image.
1077
1078        Args:
1079            center_distance_threshold: Center distance predictions below this value will be
1080                used to find seeds (intersected with thresholded boundary distance predictions).
1081                By default, set to '0.5'.
1082            boundary_distance_threshold: Boundary distance predictions below this value will be
1083                used to find seeds (intersected with thresholded center distance predictions).
1084                By default, set to '0.5'.
1085            foreground_threshold: Foreground predictions above this value will be used as foreground mask.
1086                By default, set to '0.5'.
1087            foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid
1088                checkerboard artifacts in the prediction. By default, set to '1.0'.
1089            distance_smoothing: Sigma value for smoothing the distance predictions.
1090            min_size: Minimal object size in the segmentation result. By default, set to '0'.
1091            output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
1092                By default, set to 'binary_mask'.
1093            tile_shape: Tile shape for parallelizing the instance segmentation post-processing.
1094                This parameter is independent from the tile shape for computing the embeddings.
1095                If not given then post-processing will not be parallelized.
1096            halo: Halo for parallel post-processing. See also `tile_shape`.
1097            n_threads: Number of threads for parallel post-processing. See also `tile_shape`.
1098
1099        Returns:
1100            The instance segmentation masks.
1101        """
1102        if not self.is_initialized:
1103            raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.")
1104
1105        if foreground_smoothing > 0:
1106            foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing)
1107        else:
1108            foreground = self._foreground
1109
1110        if tile_shape is None:
1111            segmentation = watershed_from_center_and_boundary_distances(
1112                center_distances=self._center_distances,
1113                boundary_distances=self._boundary_distances,
1114                foreground_map=foreground,
1115                center_distance_threshold=center_distance_threshold,
1116                boundary_distance_threshold=boundary_distance_threshold,
1117                foreground_threshold=foreground_threshold,
1118                distance_smoothing=distance_smoothing,
1119                min_size=min_size,
1120            )
1121        else:
1122            if halo is None:
1123                raise ValueError("You must pass a value for halo if tile_shape is given.")
1124            segmentation = _watershed_from_center_and_boundary_distances_parallel(
1125                center_distances=self._center_distances,
1126                boundary_distances=self._boundary_distances,
1127                foreground_map=foreground,
1128                center_distance_threshold=center_distance_threshold,
1129                boundary_distance_threshold=boundary_distance_threshold,
1130                foreground_threshold=foreground_threshold,
1131                distance_smoothing=distance_smoothing,
1132                min_size=min_size,
1133                tile_shape=tile_shape,
1134                halo=halo,
1135                n_threads=n_threads,
1136                verbose=False,
1137            )
1138
1139        if output_mode is not None:
1140            segmentation = self._to_masks(segmentation, output_mode)
1141        return segmentation

Generate instance segmentation for the currently initialized image.

Arguments:
  • center_distance_threshold: Center distance predictions below this value will be used to find seeds (intersected with thresholded boundary distance predictions). By default, set to '0.5'.
  • boundary_distance_threshold: Boundary distance predictions below this value will be used to find seeds (intersected with thresholded center distance predictions). By default, set to '0.5'.
  • foreground_threshold: Foreground predictions above this value will be used as foreground mask. By default, set to '0.5'.
  • foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid checkerboard artifacts in the prediction. By default, set to '1.0'.
  • distance_smoothing: Sigma value for smoothing the distance predictions.
  • min_size: Minimal object size in the segmentation result. By default, set to '0'.
  • output_mode: The form masks are returned in. Pass None to directly return the instance segmentation. By default, set to 'binary_mask'.
  • tile_shape: Tile shape for parallelizing the instance segmentation post-processing. This parameter is independent from the tile shape for computing the embeddings. If not given then post-processing will not be parallelized.
  • halo: Halo for parallel post-processing. See also tile_shape.
  • n_threads: Number of threads for parallel post-processing. See also tile_shape.
Returns:

The instance segmentation masks.
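
A few usage variants of `generate`, assuming `segmenter` has been initialized as in the sketches above; the threshold and tile values are illustrative, not recommendations.

```python
# Sketch: re-run generate with different post-processing parameters.
# `segmenter` is assumed to be initialized; all values are illustrative.
from micro_sam.instance_segmentation import mask_data_to_segmentation

# SAM-style mask dictionaries, then converted into a label image.
masks = segmenter.generate(center_distance_threshold=0.4, boundary_distance_threshold=0.6, min_size=50)
segmentation = mask_data_to_segmentation(masks, with_background=True, min_object_size=50)

# Or return the label image directly by disabling the mask conversion.
segmentation = segmenter.generate(min_size=50, output_mode=None)

# For large images the watershed post-processing can be parallelized over tiles.
segmentation = segmenter.generate(
    output_mode=None, tile_shape=(1024, 1024), halo=(128, 128), n_threads=8
)
```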

def get_state(self) -> Dict[str, Any]:
1143    def get_state(self) -> Dict[str, Any]:
1144        """Get the initialized state of the instance segmenter.
1145
1146        Returns:
1147            Instance segmentation state.
1148        """
1149        if not self.is_initialized:
1150            raise RuntimeError("The state has not been computed yet. Call initialize first.")
1151
1152        return {
1153            "foreground": self._foreground,
1154            "center_distances": self._center_distances,
1155            "boundary_distances": self._boundary_distances,
1156        }

Get the initialized state of the instance segmenter.

Returns:

Instance segmentation state.

def set_state(self, state: Dict[str, Any]) -> None:
1158    def set_state(self, state: Dict[str, Any]) -> None:
1159        """Set the state of the instance segmenter.
1160
1161        Args:
1162            state: The instance segmentation state.
1163        """
1164        self._foreground = state["foreground"]
1165        self._center_distances = state["center_distances"]
1166        self._boundary_distances = state["boundary_distances"]
1167        self._is_initialized = True

Set the state of the instance segmenter.

Arguments:
  • state: The instance segmentation state.
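
Since `generate` only needs the decoder predictions, `get_state` and `set_state` can be used to cache them and re-run the post-processing later with different parameters. A minimal sketch, using numpy for storage (the file name is a placeholder):

```python
# Sketch: cache the decoder predictions and restore them later.
# `segmenter` is assumed to be initialized; the file name is a placeholder.
import numpy as np

state = segmenter.get_state()
np.savez("decoder_predictions.npz", **state)

# Later: restore the cached state and generate with different parameters.
segmenter.set_state(dict(np.load("decoder_predictions.npz")))
masks = segmenter.generate(center_distance_threshold=0.6)
```
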
def clear_state(self):
1169    def clear_state(self):
1170        """Clear the state of the instance segmenter.
1171        """
1172        self._foreground = None
1173        self._center_distances = None
1174        self._boundary_distances = None
1175        self._is_initialized = False

Clear the state of the instance segmenter.

class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder):
1178class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder):
1179    """Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings.
1180    """
1181
1182    # Apply the decoder in a batched fashion, and then perform the resizing independently per output.
1183    # This is necessary, because the individual tiles may have different tile shapes due to border tiles.
1184    def _predict_decoder(self, batched_embeddings, input_shapes, original_shapes):
1185        batched_embeddings = torch.cat(batched_embeddings)
1186        output = self._decoder._forward_impl(batched_embeddings)
1187
1188        batched_output = []
1189        for x, input_shape, original_shape in zip(output, input_shapes, original_shapes):
1190            x = self._decoder.postprocess_masks(x.unsqueeze(0), input_shape, original_shape).squeeze(0)
1191            batched_output.append(x.cpu().numpy())
1192        return batched_output
1193
1194    @torch.no_grad()
1195    def initialize(
1196        self,
1197        image: np.ndarray,
1198        image_embeddings: Optional[util.ImageEmbeddings] = None,
1199        i: Optional[int] = None,
1200        tile_shape: Optional[Tuple[int, int]] = None,
1201        halo: Optional[Tuple[int, int]] = None,
1202        verbose: bool = False,
1203        pbar_init: Optional[callable] = None,
1204        pbar_update: Optional[callable] = None,
1205        batch_size: int = 1,
1206    ) -> None:
1207        """Initialize image embeddings and decoder predictions for an image.
1208
1209        Args:
1210            image: The input image, volume or timeseries.
1211            image_embeddings: Optional precomputed image embeddings.
1212                See `util.precompute_image_embeddings` for details.
1213            i: Index for the image data. Required if `image` has three spatial dimensions
1214                or a time dimension and two spatial dimensions.
1215            tile_shape: Shape of the tiles for precomputing image embeddings.
1216            halo: Overlap of the tiles for tiled precomputation of image embeddings.
1217            verbose: Whether to be verbose. By default, set to 'False'.
1218            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
1219                Can be used together with pbar_update to handle the napari progress bar in another thread.
1220                This enables using this function within a threadworker.
1221            pbar_update: Callback to update an external progress bar.
1222            batch_size: The batch size for image embedding computation and segmentation decoder prediction.
1223        """
1224        original_size = image.shape[:2]
1225        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
1226            self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, batch_size=batch_size
1227        )
1228        tiling = blocking([0, 0], original_size, tile_shape)
1229
1230        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
1231        pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder")
1232
1233        foreground = np.zeros(original_size, dtype="float32")
1234        center_distances = np.zeros(original_size, dtype="float32")
1235        boundary_distances = np.zeros(original_size, dtype="float32")
1236
1237        n_tiles = tiling.numberOfBlocks
1238        n_batches = int(np.ceil(n_tiles / batch_size))
1239
1240        for batch_id in range(n_batches):
1241            tile_start = batch_id * batch_size
1242            tile_stop = min(tile_start + batch_size, n_tiles)
1243
1244            batched_embeddings, input_shapes, original_shapes = [], [], []
1245            for tile_id in range(tile_start, tile_stop):
1246                # Get the image embeddings from the predictor for this tile.
1247                self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id)
1248
1249                batched_embeddings.append(self._predictor.features)
1250                input_shapes.append(tuple(self._predictor.input_size))
1251                original_shapes.append(tuple(self._predictor.original_size))
1252
1253            batched_output = self._predict_decoder(batched_embeddings, input_shapes, original_shapes)
1254
1255            for output_id, tile_id in enumerate(range(tile_start, tile_stop)):
1256                output = batched_output[output_id]
1257                assert output.shape[0] == 3
1258
1259                # Set the predictions in the output for this tile.
1260                block = tiling.getBlockWithHalo(tile_id, halo=list(halo))
1261                local_bb = tuple(
1262                    slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end)
1263                )
1264                inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end))
1265
1266                foreground[inner_bb] = output[0][local_bb]
1267                center_distances[inner_bb] = output[1][local_bb]
1268                boundary_distances[inner_bb] = output[2][local_bb]
1269                pbar_update(1)
1270
1271        pbar_close()
1272
1273        # Set the state.
1274        self._foreground = foreground
1275        self._center_distances = center_distances
1276        self._boundary_distances = boundary_distances
1277        self._is_initialized = True

Same as InstanceSegmentationWithDecoder but for tiled image embeddings.

@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = False, pbar_init: Optional[<built-in function callable>] = None, pbar_update: Optional[<built-in function callable>] = None, batch_size: int = 1) -> None:
1194    @torch.no_grad()
1195    def initialize(
1196        self,
1197        image: np.ndarray,
1198        image_embeddings: Optional[util.ImageEmbeddings] = None,
1199        i: Optional[int] = None,
1200        tile_shape: Optional[Tuple[int, int]] = None,
1201        halo: Optional[Tuple[int, int]] = None,
1202        verbose: bool = False,
1203        pbar_init: Optional[callable] = None,
1204        pbar_update: Optional[callable] = None,
1205        batch_size: int = 1,
1206    ) -> None:
1207        """Initialize image embeddings and decoder predictions for an image.
1208
1209        Args:
1210            image: The input image, volume or timeseries.
1211            image_embeddings: Optional precomputed image embeddings.
1212                See `util.precompute_image_embeddings` for details.
1213            i: Index for the image data. Required if `image` has three spatial dimensions
1214                or a time dimension and two spatial dimensions.
1215            tile_shape: Shape of the tiles for precomputing image embeddings.
1216            halo: Overlap of the tiles for tiled precomputation of image embeddings.
1217            verbose: Whether to be verbose. By default, set to 'False'.
1218            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
1219                Can be used together with pbar_update to handle the napari progress bar in another thread.
1220                This enables using this function within a threadworker.
1221            pbar_update: Callback to update an external progress bar.
1222            batch_size: The batch size for image embedding computation and segmentation decoder prediction.
1223        """
1224        original_size = image.shape[:2]
1225        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
1226            self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, batch_size=batch_size
1227        )
1228        tiling = blocking([0, 0], original_size, tile_shape)
1229
1230        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
1231        pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder")
1232
1233        foreground = np.zeros(original_size, dtype="float32")
1234        center_distances = np.zeros(original_size, dtype="float32")
1235        boundary_distances = np.zeros(original_size, dtype="float32")
1236
1237        n_tiles = tiling.numberOfBlocks
1238        n_batches = int(np.ceil(n_tiles / batch_size))
1239
1240        for batch_id in range(n_batches):
1241            tile_start = batch_id * batch_size
1242            tile_stop = min(tile_start + batch_size, n_tiles)
1243
1244            batched_embeddings, input_shapes, original_shapes = [], [], []
1245            for tile_id in range(tile_start, tile_stop):
1246                # Get the image embeddings from the predictor for this tile.
1247                self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id)
1248
1249                batched_embeddings.append(self._predictor.features)
1250                input_shapes.append(tuple(self._predictor.input_size))
1251                original_shapes.append(tuple(self._predictor.original_size))
1252
1253            batched_output = self._predict_decoder(batched_embeddings, input_shapes, original_shapes)
1254
1255            for output_id, tile_id in enumerate(range(tile_start, tile_stop)):
1256                output = batched_output[output_id]
1257                assert output.shape[0] == 3
1258
1259                # Set the predictions in the output for this tile.
1260                block = tiling.getBlockWithHalo(tile_id, halo=list(halo))
1261                local_bb = tuple(
1262                    slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end)
1263                )
1264                inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end))
1265
1266                foreground[inner_bb] = output[0][local_bb]
1267                center_distances[inner_bb] = output[1][local_bb]
1268                boundary_distances[inner_bb] = output[2][local_bb]
1269                pbar_update(1)
1270
1271        pbar_close()
1272
1273        # Set the state.
1274        self._foreground = foreground
1275        self._center_distances = center_distances
1276        self._boundary_distances = boundary_distances
1277        self._is_initialized = True

Initialize image embeddings and decoder predictions for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • tile_shape: Shape of the tiles for precomputing image embeddings.
  • halo: Overlap of the tiles for tiled precomputation of image embeddings.
  • verbose: Whether to be verbose. By default, set to 'False'.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
  • batch_size: The batch size for image embedding computation and segmentation decoder prediction.
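
A minimal sketch for tiled decoder-based segmentation of a large image; tile shape, halo and batch size are illustrative, paths and model type are placeholders.

```python
# Sketch: tiled decoder-based instance segmentation for a large image.
# Tile shape, halo and batch size are illustrative; paths are placeholders.
import imageio.v3 as imageio
from micro_sam.instance_segmentation import get_predictor_and_decoder, TiledInstanceSegmentationWithDecoder

predictor, decoder = get_predictor_and_decoder("vit_b", "path/to/finetuned_checkpoint.pt")
image = imageio.imread("path/to/large_image.tif")

segmenter = TiledInstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image, tile_shape=(1024, 1024), halo=(256, 256), batch_size=4)
segmentation = segmenter.generate(output_mode=None)
```
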
def get_amg( predictor: segment_anything.predictor.SamPredictor, is_tiled: bool, decoder: Optional[torch.nn.modules.module.Module] = None, **kwargs) -> Union[AMGBase, InstanceSegmentationWithDecoder]:
1280def get_amg(
1281    predictor: SamPredictor, is_tiled: bool, decoder: Optional[torch.nn.Module] = None, **kwargs,
1282) -> Union[AMGBase, InstanceSegmentationWithDecoder]:
1283    """Get the automatic mask generator class.
1284
1285    Args:
1286        predictor: The segment anything predictor.
1287        is_tiled: Whether tiled embeddings are used.
1288        decoder: Decoder to predict instance segmentation.
1289        kwargs: The keyword arguments for the amg class.
1290
1291    Returns:
1292        The automatic mask generator.
1293    """
1294    if decoder is None:
1295        segmenter_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator
1296        segmenter = segmenter_class(predictor, **kwargs)
1297    else:
1298        segmenter_class = TiledInstanceSegmentationWithDecoder if is_tiled else InstanceSegmentationWithDecoder
1299        segmenter = segmenter_class(predictor, decoder, **kwargs)
1300
1301    return segmenter

Get the automatic mask generator class.

Arguments:
  • predictor: The segment anything predictor.
  • is_tiled: Whether tiled embeddings are used.
  • decoder: Decoder to predict instance segmentation.
  • kwargs: The keyword arguments for the amg class.
Returns:

The automatic mask generator.
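
A minimal sketch of selecting the segmenter via `get_amg`; whether a decoder is available depends on the loaded checkpoint, and the model type is a placeholder.

```python
# Sketch: let get_amg select the segmenter class. Whether a decoder is available
# depends on the loaded checkpoint; the model type is a placeholder.
from micro_sam import util
from micro_sam.instance_segmentation import get_amg, get_decoder

predictor, state = util.get_sam_model(model_type="vit_b_lm", return_state=True)
decoder_state = state.get("decoder_state", None)
decoder = None if decoder_state is None else get_decoder(predictor.model.image_encoder, decoder_state)

# decoder=None -> AutomaticMaskGenerator (TiledAutomaticMaskGenerator if is_tiled=True),
# otherwise -> InstanceSegmentationWithDecoder (or its tiled variant).
segmenter = get_amg(predictor, is_tiled=False, decoder=decoder)
```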