micro_sam.instance_segmentation

Automated instance segmentation functionality. The classes implemented here extend the automatic instance segmentation from Segment Anything: https://computational-cell-analytics.github.io/micro-sam/micro_sam.html
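The module provides two workflows: grid-prompt based mask generation (`AutomaticMaskGenerator` and its tiled variant) and decoder-based segmentation (`InstanceSegmentationWithDecoder`). As a quick orientation, here is a minimal sketch of the grid-prompt workflow; it is not part of the module source, the image path and `vit_b` model type are placeholders, and loading the image with imageio is only illustrative.

```python
import imageio.v3 as imageio

from micro_sam.util import get_sam_model
from micro_sam.instance_segmentation import AutomaticMaskGenerator, mask_data_to_segmentation

# Load the input image; the path is a placeholder.
image = imageio.imread("my_image.tif")

# Load a SAM predictor; the model type is a placeholder for any supported type.
predictor = get_sam_model(model_type="vit_b")

amg = AutomaticMaskGenerator(predictor)
amg.initialize(image)                        # expensive: embeddings and mask prediction
masks = amg.generate(pred_iou_thresh=0.88)   # cheap: filtering and post-processing
segmentation = mask_data_to_segmentation(masks, with_background=True, min_object_size=25)
```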

   1"""Automated instance segmentation functionality.
   2The classes implemented here extend the automatic instance segmentation from Segment Anything:
   3https://computational-cell-analytics.github.io/micro-sam/micro_sam.html
   4"""
   5
   6import os
   7import warnings
   8from abc import ABC
   9from copy import deepcopy
  10from collections import OrderedDict
  11from typing import Any, Dict, List, Optional, Tuple, Union
  12
  13import vigra
  14import numpy as np
  15from skimage.measure import label, regionprops
  16from skimage.segmentation import relabel_sequential
  17
  18import torch
  19from torchvision.ops.boxes import batched_nms, box_area
  20
  21from torch_em.model import UNETR
  22from torch_em.util.segmentation import watershed_from_center_and_boundary_distances
  23
  24import elf.parallel as parallel
  25from elf.parallel.filters import apply_filter
  26
  27from nifty.tools import blocking
  28
  29import segment_anything.utils.amg as amg_utils
  30from segment_anything.predictor import SamPredictor
  31
  32from . import util
  33from ._vendored import batched_mask_to_box, mask_to_rle_pytorch
  34
  35#
  36# Utility Functionality
  37#
  38
  39
  40class _FakeInput:
  41    def __init__(self, shape):
  42        self.shape = shape
  43
  44    def __getitem__(self, index):
  45        block_shape = tuple(ind.stop - ind.start for ind in index)
  46        return np.zeros(block_shape, dtype="float32")
  47
  48
  49def mask_data_to_segmentation(
  50    masks: List[Dict[str, Any]],
  51    with_background: bool,
  52    min_object_size: int = 0,
  53    max_object_size: Optional[int] = None,
  54    label_masks: bool = True,
  55) -> np.ndarray:
  56    """Convert the output of the automatic mask generation to an instance segmentation.
  57
  58    Args:
  59        masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator.
  60            Only supports output_mode=binary_mask.
  61        with_background: Whether the segmentation has background. If yes, this function ensures that the largest
  62            object in the output will be mapped to zero (the background value).
  63        min_object_size: The minimal size of an object in pixels.
  64        max_object_size: The maximal size of an object in pixels.
  65        label_masks: Whether to apply connected components to the result before removing small objects.
  66
  67    Returns:
  68        The instance segmentation.
  69    """
  70
  71    masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
  72    # we could also get the shape from the crop box
  73    shape = next(iter(masks))["segmentation"].shape
  74    segmentation = np.zeros(shape, dtype="uint32")
  75
  76    def require_numpy(mask):
  77        return mask.cpu().numpy() if torch.is_tensor(mask) else mask
  78
  79    seg_id = 1
  80    for mask in masks:
  81        if mask["area"] < min_object_size:
  82            continue
  83        if max_object_size is not None and mask["area"] > max_object_size:
  84            continue
  85
  86        this_seg_id = mask.get("seg_id", seg_id)
  87        segmentation[require_numpy(mask["segmentation"])] = this_seg_id
  88        seg_id = this_seg_id + 1
  89
  90    if label_masks:
  91        segmentation = label(segmentation).astype(segmentation.dtype)
  92
  93    seg_ids, sizes = np.unique(segmentation, return_counts=True)
  94
  95    # In some cases objects may be smaller than previously calculated,
  96    # since they are covered by other objects. We ensure these also get
  97    # filtered out here.
  98    filter_ids = seg_ids[sizes < min_object_size]
  99
 100    # If we run segmentation with background we also map the largest segment
 101    # (the most likely background object) to zero. This is often zero already,
 102    # but it does not hurt to reset that to zero either.
 103    if with_background:
 104        bg_id = seg_ids[np.argmax(sizes)]
 105        filter_ids = np.concatenate([filter_ids, [bg_id]])
 106
 107    segmentation[np.isin(segmentation, filter_ids)] = 0
 108    segmentation = relabel_sequential(segmentation)[0]
 109
 110    return segmentation
 111
 112
 113#
 114# Classes for automatic instance segmentation
 115#
 116
 117
 118class AMGBase(ABC):
 119    """Base class for the automatic mask generators.
 120    """
 121    def __init__(self):
 122        # the state that has to be computed by the 'initialize' method of the child classes
 123        self._is_initialized = False
 124        self._crop_list = None
 125        self._crop_boxes = None
 126        self._original_size = None
 127
 128    @property
 129    def is_initialized(self):
 130        """Whether the mask generator has already been initialized.
 131        """
 132        return self._is_initialized
 133
 134    @property
 135    def crop_list(self):
 136        """The list of mask data after initialization.
 137        """
 138        return self._crop_list
 139
 140    @property
 141    def crop_boxes(self):
 142        """The list of crop boxes.
 143        """
 144        return self._crop_boxes
 145
 146    @property
 147    def original_size(self):
 148        """The original image size.
 149        """
 150        return self._original_size
 151
 152    def _postprocess_batch(
 153        self,
 154        data,
 155        crop_box,
 156        original_size,
 157        pred_iou_thresh,
 158        stability_score_thresh,
 159        box_nms_thresh,
 160    ):
 161        orig_h, orig_w = original_size
 162
 163        # filter by predicted IoU
 164        if pred_iou_thresh > 0.0:
 165            keep_mask = data["iou_preds"] > pred_iou_thresh
 166            data.filter(keep_mask)
 167
 168        # filter by stability score
 169        if stability_score_thresh > 0.0:
 170            keep_mask = data["stability_score"] >= stability_score_thresh
 171            data.filter(keep_mask)
 172
 173        # filter boxes that touch crop boundaries
 174        keep_mask = ~amg_utils.is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
 175        if not torch.all(keep_mask):
 176            data.filter(keep_mask)
 177
 178        # remove duplicates within this crop.
 179        keep_by_nms = batched_nms(
 180            data["boxes"].float(),
 181            data["iou_preds"],
 182            torch.zeros_like(data["boxes"][:, 0]),  # categories
 183            iou_threshold=box_nms_thresh,
 184        )
 185        data.filter(keep_by_nms)
 186
 187        # return to the original image frame
 188        data["boxes"] = amg_utils.uncrop_boxes_xyxy(data["boxes"], crop_box)
 189        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
 190        # the data from embedding based segmentation doesn't have the points
 191        # so we skip if the corresponding key can't be found
 192        try:
 193            data["points"] = amg_utils.uncrop_points(data["points"], crop_box)
 194        except KeyError:
 195            pass
 196
 197        return data
 198
 199    def _postprocess_small_regions(self, mask_data, min_area, nms_thresh):
 200
 201        if len(mask_data["rles"]) == 0:
 202            return mask_data
 203
 204        # filter small disconnected regions and holes
 205        new_masks = []
 206        scores = []
 207        for rle in mask_data["rles"]:
 208            mask = amg_utils.rle_to_mask(rle)
 209
 210            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="holes")
 211            unchanged = not changed
 212            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="islands")
 213            unchanged = unchanged and not changed
 214
 215            new_masks.append(torch.as_tensor(mask, dtype=torch.int).unsqueeze(0))
 216            # give score=0 to changed masks and score=1 to unchanged masks
 217            # so NMS will prefer ones that didn't need postprocessing
 218            scores.append(float(unchanged))
 219
 220        # recalculate boxes and remove any new duplicates
 221        masks = torch.cat(new_masks, dim=0)
 222        boxes = batched_mask_to_box(masks.to(torch.bool))  # Casting this to boolean as we work with one-hot labels.
 223        keep_by_nms = batched_nms(
 224            boxes.float(),
 225            torch.as_tensor(scores, dtype=torch.float),
 226            torch.zeros_like(boxes[:, 0]),  # categories
 227            iou_threshold=nms_thresh,
 228        )
 229
 230        # only recalculate RLEs for masks that have changed
 231        for i_mask in keep_by_nms:
 232            if scores[i_mask] == 0.0:
 233                mask_torch = masks[i_mask].unsqueeze(0)
 234                # mask_data["rles"][i_mask] = amg_utils.mask_to_rle_pytorch(mask_torch)[0]
 235                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
 236                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
 237        mask_data.filter(keep_by_nms)
 238
 239        return mask_data
 240
 241    def _postprocess_masks(self, mask_data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode):
 242        # filter small disconnected regions and holes in masks
 243        if min_mask_region_area > 0:
 244            mask_data = self._postprocess_small_regions(
 245                mask_data,
 246                min_mask_region_area,
 247                max(box_nms_thresh, crop_nms_thresh),
 248            )
 249
 250        # encode masks
 251        if output_mode == "coco_rle":
 252            mask_data["segmentations"] = [amg_utils.coco_encode_rle(rle) for rle in mask_data["rles"]]
 253        elif output_mode == "binary_mask":
 254            mask_data["segmentations"] = [amg_utils.rle_to_mask(rle) for rle in mask_data["rles"]]
 255        else:
 256            mask_data["segmentations"] = mask_data["rles"]
 257
 258        # write mask records
 259        curr_anns = []
 260        for idx in range(len(mask_data["segmentations"])):
 261            ann = {
 262                "segmentation": mask_data["segmentations"][idx],
 263                "area": amg_utils.area_from_rle(mask_data["rles"][idx]),
 264                "bbox": amg_utils.box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
 265                "predicted_iou": mask_data["iou_preds"][idx].item(),
 266                "stability_score": mask_data["stability_score"][idx].item(),
 267                "crop_box": amg_utils.box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
 268            }
 269            # the data from embedding based segmentation doesn't have the points
 270            # so we skip if the corresponding key can't be found
 271            try:
 272                ann["point_coords"] = [mask_data["points"][idx].tolist()]
 273            except KeyError:
 274                pass
 275            curr_anns.append(ann)
 276
 277        return curr_anns
 278
 279    def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None):
 280        orig_h, orig_w = original_size
 281
 282        # serialize predictions and store in MaskData
 283        data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1))
 284        if points is not None:
 285            data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0), dtype=torch.float)
 286
 287        del masks
 288
 289        # calculate the stability scores
 290        data["stability_score"] = amg_utils.calculate_stability_score(
 291            data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset
 292        )
 293
 294        # threshold masks and calculate boxes
 295        data["masks"] = data["masks"] > self._predictor.model.mask_threshold
 296        data["masks"] = data["masks"].type(torch.bool)
 297        data["boxes"] = batched_mask_to_box(data["masks"])
 298
 299        # compress to RLE
 300        data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
 301        # data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"])
 302        data["rles"] = mask_to_rle_pytorch(data["masks"])
 303        del data["masks"]
 304
 305        return data
 306
 307    def get_state(self) -> Dict[str, Any]:
 308        """Get the initialized state of the mask generator.
 309
 310        Returns:
 311            State of the mask generator.
 312        """
 313        if not self.is_initialized:
 314            raise RuntimeError("The state has not been computed yet. Call initialize first.")
 315
 316        return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}
 317
 318    def set_state(self, state: Dict[str, Any]) -> None:
 319        """Set the state of the mask generator.
 320
 321        Args:
 322            state: The state of the mask generator, e.g. from serialized state.
 323        """
 324        self._crop_list = state["crop_list"]
 325        self._crop_boxes = state["crop_boxes"]
 326        self._original_size = state["original_size"]
 327        self._is_initialized = True
 328
 329    def clear_state(self):
 330        """Clear the state of the mask generator.
 331        """
 332        self._crop_list = None
 333        self._crop_boxes = None
 334        self._original_size = None
 335        self._is_initialized = False
 336
 337
 338class AutomaticMaskGenerator(AMGBase):
 339    """Generates an instance segmentation without prompts, using a point grid.
 340
 341    This class implements the same logic as
 342    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
 343    It decouples the computationally expensive steps of generating masks from the cheap post-processing operations
 344    that filter these masks, which enables grid search and interactive changes to the post-processing.
 345
 346    Use this class as follows:
 347    ```python
 348    amg = AutomaticMaskGenerator(predictor)
 349    amg.initialize(image)  # Initialize the masks, this takes care of all expensive computations.
 350    masks = amg.generate(pred_iou_thresh=0.8)  # Generate the masks. This is fast and enables testing parameters
 351    ```
 352
 353    Args:
 354        predictor: The segment anything predictor.
 355        points_per_side: The number of points to be sampled along one side of the image.
 356            If None, `point_grids` must provide explicit point sampling.
 357        points_per_batch: The number of points run simultaneously by the model.
 358            Higher numbers may be faster but use more GPU memory.
 359        crop_n_layers: If >0, the mask prediction will be run again on crops of the image.
 360        crop_overlap_ratio: Sets the degree to which crops overlap.
 361        crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops.
 362        point_grids: A list of explicit grids of points used for sampling masks.
 363            Normalized to [0, 1] with respect to the image coordinate system.
 364        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
 365    """
 366    def __init__(
 367        self,
 368        predictor: SamPredictor,
 369        points_per_side: Optional[int] = 32,
 370        points_per_batch: Optional[int] = None,
 371        crop_n_layers: int = 0,
 372        crop_overlap_ratio: float = 512 / 1500,
 373        crop_n_points_downscale_factor: int = 1,
 374        point_grids: Optional[List[np.ndarray]] = None,
 375        stability_score_offset: float = 1.0,
 376    ):
 377        super().__init__()
 378
 379        if points_per_side is not None:
 380            self.point_grids = amg_utils.build_all_layer_point_grids(
 381                points_per_side, crop_n_layers, crop_n_points_downscale_factor,
 382            )
 383        elif point_grids is not None:
 384            self.point_grids = point_grids
 385        else:
 386            raise ValueError("Can't have both points_per_side and point_grid be None or not None.")
 387
 388        self._predictor = predictor
 389        self._points_per_side = points_per_side
 390
 391        # we set the points per batch to 16 for mps for performance reasons
 392        # and otherwise keep them at the default of 64
 393        if points_per_batch is None:
 394            points_per_batch = 16 if str(predictor.device) == "mps" else 64
 395        self._points_per_batch = points_per_batch
 396
 397        self._crop_n_layers = crop_n_layers
 398        self._crop_overlap_ratio = crop_overlap_ratio
 399        self._crop_n_points_downscale_factor = crop_n_points_downscale_factor
 400        self._stability_score_offset = stability_score_offset
 401
 402    def _process_batch(self, points, im_size, crop_box, original_size):
 403        # run model on this batch
 404        transformed_points = self._predictor.transform.apply_coords(points, im_size)
 405        in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float)
 406        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
 407        masks, iou_preds, _ = self._predictor.predict_torch(
 408            point_coords=in_points[:, None, :],
 409            point_labels=in_labels[:, None],
 410            multimask_output=True,
 411            return_logits=True,
 412        )
 413        data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points)
 414        del masks
 415        return data
 416
 417    def _process_crop(self, image, crop_box, crop_layer_idx, precomputed_embeddings, pbar_init=None, pbar_update=None):
 418        # Crop the image and calculate embeddings.
 419        x0, y0, x1, y1 = crop_box
 420        cropped_im = image[y0:y1, x0:x1, :]
 421        cropped_im_size = cropped_im.shape[:2]
 422
 423        if not precomputed_embeddings:
 424            self._predictor.set_image(cropped_im)
 425
 426        # Get the points for this crop.
 427        points_scale = np.array(cropped_im_size)[None, ::-1]
 428        points_for_image = self.point_grids[crop_layer_idx] * points_scale
 429
 430        # Generate masks for this crop in batches.
 431        data = amg_utils.MaskData()
 432        n_batches = len(points_for_image) // self._points_per_batch +\
 433            int(len(points_for_image) % self._points_per_batch != 0)
 434        if pbar_init is not None:
 435            pbar_init(n_batches, "Predict masks for point grid prompts")
 436
 437        for (points,) in amg_utils.batch_iterator(self._points_per_batch, points_for_image):
 438            batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size)
 439            data.cat(batch_data)
 440            del batch_data
 441            if pbar_update is not None:
 442                pbar_update(1)
 443
 444        if not precomputed_embeddings:
 445            self._predictor.reset_image()
 446
 447        return data
 448
 449    @torch.no_grad()
 450    def initialize(
 451        self,
 452        image: np.ndarray,
 453        image_embeddings: Optional[util.ImageEmbeddings] = None,
 454        i: Optional[int] = None,
 455        verbose: bool = False,
 456        pbar_init: Optional[callable] = None,
 457        pbar_update: Optional[callable] = None,
 458    ) -> None:
 459        """Initialize image embeddings and masks for an image.
 460
 461        Args:
 462            image: The input image, volume or timeseries.
 463            image_embeddings: Optional precomputed image embeddings.
 464                See `util.precompute_image_embeddings` for details.
 465            i: Index for the image data. Required if `image` has three spatial dimensions
 466                or a time dimension and two spatial dimensions.
 467            verbose: Whether to print computation progress.
 468            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 469                Can be used together with pbar_update to handle the napari progress bar in another thread,
 470                which enables using this function within a threadworker.
 471            pbar_update: Callback to update an external progress bar.
 472        """
 473        original_size = image.shape[:2]
 474        self._original_size = original_size
 475
 476        crop_boxes, layer_idxs = amg_utils.generate_crop_boxes(
 477            original_size, self._crop_n_layers, self._crop_overlap_ratio
 478        )
 479
 480        # We can set fixed image embeddings if we only have a single crop box (the default setting).
 481        # Otherwise we have to recompute the embeddings for each crop and can't precompute.
 482        if len(crop_boxes) == 1:
 483            if image_embeddings is None:
 484                image_embeddings = util.precompute_image_embeddings(self._predictor, image)
 485            util.set_precomputed(self._predictor, image_embeddings, i=i)
 486            precomputed_embeddings = True
 487        else:
 488            precomputed_embeddings = False
 489
 490        # we need to cast to the image representation that is compatible with SAM
 491        image = util._to_image(image)
 492
 493        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
 494
 495        crop_list = []
 496        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
 497            crop_data = self._process_crop(
 498                image, crop_box, layer_idx,
 499                precomputed_embeddings=precomputed_embeddings,
 500                pbar_init=pbar_init, pbar_update=pbar_update,
 501            )
 502            crop_list.append(crop_data)
 503        pbar_close()
 504
 505        self._is_initialized = True
 506        self._crop_list = crop_list
 507        self._crop_boxes = crop_boxes
 508
 509    @torch.no_grad()
 510    def generate(
 511        self,
 512        pred_iou_thresh: float = 0.88,
 513        stability_score_thresh: float = 0.95,
 514        box_nms_thresh: float = 0.7,
 515        crop_nms_thresh: float = 0.7,
 516        min_mask_region_area: int = 0,
 517        output_mode: str = "binary_mask",
 518    ) -> List[Dict[str, Any]]:
 519        """Generate instance segmentation for the currently initialized image.
 520
 521        Args:
 522            pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
 523            stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
 524                under changes to the cutoff used to binarize the model prediction.
 525            box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
 526            crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
 527            min_mask_region_area: Minimal size for the predicted masks.
 528            output_mode: The form masks are returned in.
 529
 530        Returns:
 531            The instance segmentation masks.
 532        """
 533        if not self.is_initialized:
 534            raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.")
 535
 536        data = amg_utils.MaskData()
 537        for data_, crop_box in zip(self.crop_list, self.crop_boxes):
 538            crop_data = self._postprocess_batch(
 539                data=deepcopy(data_),
 540                crop_box=crop_box, original_size=self.original_size,
 541                pred_iou_thresh=pred_iou_thresh,
 542                stability_score_thresh=stability_score_thresh,
 543                box_nms_thresh=box_nms_thresh
 544            )
 545            data.cat(crop_data)
 546
 547        if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0:
 548            # Prefer masks from smaller crops
 549            scores = 1 / box_area(data["crop_boxes"])
 550            scores = scores.to(data["boxes"].device)
 551            keep_by_nms = batched_nms(
 552                data["boxes"].float(),
 553                scores,
 554                torch.zeros_like(data["boxes"][:, 0]),  # categories
 555                iou_threshold=crop_nms_thresh,
 556            )
 557            data.filter(keep_by_nms)
 558
 559        data.to_numpy()
 560        masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode)
 561        return masks
 562
 563
 564# Helper function for tiled embedding computation and checking consistent state.
 565def _process_tiled_embeddings(predictor, image, image_embeddings, tile_shape, halo, verbose):
 566    if image_embeddings is None:
 567        if tile_shape is None or halo is None:
 568            raise ValueError("To compute tiled embeddings the parameters tile_shape and halo have to be passed.")
 569        image_embeddings = util.precompute_image_embeddings(
 570            predictor, image, tile_shape=tile_shape, halo=halo, verbose=verbose
 571        )
 572
 573    # Use tile shape and halo from the precomputed embeddings if not given.
 574    # Otherwise check that they are consistent.
 575    feats = image_embeddings["features"]
 576    tile_shape_, halo_ = tuple(feats.attrs["tile_shape"]), tuple(feats.attrs["halo"])
 577    if tile_shape is None:
 578        tile_shape = tile_shape_
 579    elif tile_shape != tile_shape_:
 580        raise ValueError(
 581            f"Inconsistent tile_shape parameter {tile_shape} with precomputed embeedings: {tile_shape_}."
 582        )
 583    if halo is None:
 584        halo = halo_
 585    elif halo != halo_:
 586        raise ValueError(f"Inconsistent halo parameter {halo} with precomputed embeedings: {halo_}.")
 587
 588    return image_embeddings, tile_shape, halo
 589
 590
 591class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
 592    """Generates an instance segmentation without prompts, using a point grid.
 593
 594    Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings.
 595
 596    Args:
 597        predictor: The segment anything predictor.
 598        points_per_side: The number of points to be sampled along one side of the image.
 599            If None, `point_grids` must provide explicit point sampling.
 600        points_per_batch: The number of points run simultaneously by the model.
 601            Higher numbers may be faster but use more GPU memory.
 602        point_grids: A list of explicit grids of points used for sampling masks.
 603            Normalized to [0, 1] with respect to the image coordinate system.
 604        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
 605    """
 606
 607    # We only expose the arguments that make sense for the tiled mask generator.
 608    # Anything related to crops doesn't make sense, because we re-use that functionality
 609    # for tiling, so these parameters wouldn't have any effect.
 610    def __init__(
 611        self,
 612        predictor: SamPredictor,
 613        points_per_side: Optional[int] = 32,
 614        points_per_batch: int = 64,
 615        point_grids: Optional[List[np.ndarray]] = None,
 616        stability_score_offset: float = 1.0,
 617    ) -> None:
 618        super().__init__(
 619            predictor=predictor,
 620            points_per_side=points_per_side,
 621            points_per_batch=points_per_batch,
 622            point_grids=point_grids,
 623            stability_score_offset=stability_score_offset,
 624        )
 625
 626    @torch.no_grad()
 627    def initialize(
 628        self,
 629        image: np.ndarray,
 630        image_embeddings: Optional[util.ImageEmbeddings] = None,
 631        i: Optional[int] = None,
 632        tile_shape: Optional[Tuple[int, int]] = None,
 633        halo: Optional[Tuple[int, int]] = None,
 634        verbose: bool = False,
 635        pbar_init: Optional[callable] = None,
 636        pbar_update: Optional[callable] = None,
 637    ) -> None:
 638        """Initialize image embeddings and masks for an image.
 639
 640        Args:
 641            image: The input image, volume or timeseries.
 642            image_embeddings: Optional precomputed image embeddings.
 643                See `util.precompute_image_embeddings` for details.
 644            i: Index for the image data. Required if `image` has three spatial dimensions
 645                or a time dimension and two spatial dimensions.
 646            tile_shape: The tile shape for embedding prediction.
 647            halo: The overlap between tiles.
 648            verbose: Whether to print computation progress.
 649            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 650                Can be used together with pbar_update to handle the napari progress bar in another thread,
 651                which enables using this function within a threadworker.
 652            pbar_update: Callback to update an external progress bar.
 653        """
 654        original_size = image.shape[:2]
 655        self._original_size = original_size
 656
 657        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
 658            self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose,
 659        )
 660
 661        tiling = blocking([0, 0], original_size, tile_shape)
 662        n_tiles = tiling.numberOfBlocks
 663
 664        # The crop box is always the full local tile.
 665        tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)]
 666        crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles]
 667
 668        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
 669        pbar_init(n_tiles, "Compute masks for tile")
 670
 671        # We need to cast to the image representation that is compatible with SAM.
 672        image = util._to_image(image)
 673
 674        mask_data = []
 675        for tile_id in range(n_tiles):
 676            # set the pre-computed embeddings for this tile
 677            features = image_embeddings["features"][tile_id]
 678            tile_embeddings = {
 679                "features": features,
 680                "input_size": features.attrs["input_size"],
 681                "original_size": features.attrs["original_size"],
 682            }
 683            util.set_precomputed(self._predictor, tile_embeddings, i)
 684
 685            # compute the mask data for this tile and append it
 686            this_mask_data = self._process_crop(
 687                image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True
 688            )
 689            mask_data.append(this_mask_data)
 690            pbar_update(1)
 691        pbar_close()
 692
 693        # set the initialized data
 694        self._is_initialized = True
 695        self._crop_list = mask_data
 696        self._crop_boxes = crop_boxes
 697
 698
 699#
 700# Instance segmentation functionality based on fine-tuned decoder
 701#
 702
 703
 704class DecoderAdapter(torch.nn.Module):
 705    """Adapter to contain the UNETR decoder in a single module.
 706
 707    To apply the decoder on top of pre-computed embeddings for the segmentation functionality.
 708    See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py
 709    """
 710    def __init__(self, unetr):
 711        super().__init__()
 712
 713        self.base = unetr.base
 714        self.out_conv = unetr.out_conv
 715        self.deconv_out = unetr.deconv_out
 716        self.decoder_head = unetr.decoder_head
 717        self.final_activation = unetr.final_activation
 718        self.postprocess_masks = unetr.postprocess_masks
 719
 720        self.decoder = unetr.decoder
 721        self.deconv1 = unetr.deconv1
 722        self.deconv2 = unetr.deconv2
 723        self.deconv3 = unetr.deconv3
 724        self.deconv4 = unetr.deconv4
 725
 726    def forward(self, input_, input_shape, original_shape):
 727        z12 = input_
 728
 729        z9 = self.deconv1(z12)
 730        z6 = self.deconv2(z9)
 731        z3 = self.deconv3(z6)
 732        z0 = self.deconv4(z3)
 733
 734        updated_from_encoder = [z9, z6, z3]
 735
 736        x = self.base(z12)
 737        x = self.decoder(x, encoder_inputs=updated_from_encoder)
 738        x = self.deconv_out(x)
 739
 740        x = torch.cat([x, z0], dim=1)
 741        x = self.decoder_head(x)
 742
 743        x = self.out_conv(x)
 744        if self.final_activation is not None:
 745            x = self.final_activation(x)
 746
 747        x = self.postprocess_masks(x, input_shape, original_shape)
 748        return x
 749
 750
 751def get_unetr(
 752    image_encoder: torch.nn.Module,
 753    decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None,
 754    device: Optional[Union[str, torch.device]] = None,
 755    out_channels: int = 3,
 756    flexible_load_checkpoint: bool = False,
 757) -> torch.nn.Module:
 758    """Get UNETR model for automatic instance segmentation.
 759
 760    Args:
 761        image_encoder: The image encoder of the SAM model.
 762            This is used as encoder by the UNETR too.
 763        decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
 764        device: The device.
 765        out_channels: The number of output channels.
 766        flexible_load_checkpoint: Whether to allow reinitialization of parameters
 767            which could not be found in the provided decoder state.
 768
 769    Returns:
 770        The UNETR model.
 771    """
 772    device = util.get_device(device)
 773
 774    unetr = UNETR(
 775        backbone="sam",
 776        encoder=image_encoder,
 777        out_channels=out_channels,
 778        use_sam_stats=True,
 779        final_activation="Sigmoid",
 780        use_skip_connection=False,
 781        resize_input=True,
 782    )
 783    if decoder_state is not None:
 784        unetr_state_dict = unetr.state_dict()
 785        for k, v in unetr_state_dict.items():
 786            if not k.startswith("encoder"):
 787                if flexible_load_checkpoint:  # Whether to allow reinitialization of params if not found.
 788                    if k in decoder_state:  # First check whether the key is available in the provided decoder state.
 789                        unetr_state_dict[k] = decoder_state[k]
 790                    else:  # Otherwise, keep the freshly initialized value.
 791                        warnings.warn(f"Could not find '{k}' in the pretrained state dict. Hence, we reinitialize it.")
 792                        unetr_state_dict[k] = v
 793
 794                else:  # Otherwise, be strict about finding the parameter in the decoder state.
 795                    if k not in decoder_state:
 796                        raise RuntimeError(f"The parameters for '{k}' could not be found.")
 797                    unetr_state_dict[k] = decoder_state[k]
 798
 799        unetr.load_state_dict(unetr_state_dict)
 800
 801    unetr.to(device)
 802    return unetr
 803
 804
 805def get_decoder(
 806    image_encoder: torch.nn.Module,
 807    decoder_state: OrderedDict[str, torch.Tensor],
 808    device: Optional[Union[str, torch.device]] = None,
 809) -> DecoderAdapter:
 810    """Get decoder to predict outputs for automatic instance segmentation
 811
 812    Args:
 813        image_encoder: The image encoder of the SAM model.
 814        decoder_state: State to initialize the weights of the UNETR decoder.
 815        device: The device.
 816
 817    Returns:
 818        The decoder for instance segmentation.
 819    """
 820    unetr = get_unetr(image_encoder, decoder_state, device)
 821    return DecoderAdapter(unetr)
 822
 823
 824def get_predictor_and_decoder(
 825    model_type: str,
 826    checkpoint_path: Union[str, os.PathLike],
 827    device: Optional[Union[str, torch.device]] = None,
 828    peft_kwargs: Optional[Dict] = None,
 829) -> Tuple[SamPredictor, DecoderAdapter]:
 830    """Load the SAM model (predictor) and instance segmentation decoder.
 831
 832    This requires a checkpoint that contains the state for both predictor
 833    and decoder.
 834
 835    Args:
 836        model_type: The type of the image encoder used in the SAM model.
 837        checkpoint_path: Path to the checkpoint from which to load the data.
 838        device: The device.
 839        peft_kwargs: Keyword arguments for the PEFT wrapper class.
 840
 841    Returns:
 842        The SAM predictor.
 843        The decoder for instance segmentation.
 844    """
 845    device = util.get_device(device)
 846    predictor, state = util.get_sam_model(
 847        model_type=model_type,
 848        checkpoint_path=checkpoint_path,
 849        device=device,
 850        return_state=True,
 851        peft_kwargs=peft_kwargs,
 852    )
 853    if "decoder_state" not in state:
 854        raise ValueError(
 855            f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state"
 856        )
 857    decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device)
 858    return predictor, decoder
 859
 860
 861def _watershed_from_center_and_boundary_distances_parallel(
 862    center_distances,
 863    boundary_distances,
 864    foreground_map,
 865    center_distance_threshold,
 866    boundary_distance_threshold,
 867    foreground_threshold,
 868    distance_smoothing,
 869    min_size,
 870    tile_shape,
 871    halo,
 872    n_threads,
 873    verbose=False,
 874):
 875    center_distances = apply_filter(
 876        center_distances, "gaussianSmoothing", sigma=distance_smoothing,
 877        block_shape=tile_shape, n_threads=n_threads
 878    )
 879    boundary_distances = apply_filter(
 880        boundary_distances, "gaussianSmoothing", sigma=distance_smoothing,
 881        block_shape=tile_shape, n_threads=n_threads
 882    )
 883
 884    fg_mask = foreground_map > foreground_threshold
 885
 886    marker_map = np.logical_and(
 887        center_distances < center_distance_threshold, boundary_distances < boundary_distance_threshold
 888    )
 889    marker_map[~fg_mask] = 0
 890
 891    markers = np.zeros(marker_map.shape, dtype="uint64")
 892    markers = parallel.label(
 893        marker_map, out=markers, block_shape=tile_shape, n_threads=n_threads, verbose=verbose,
 894    )
 895
 896    seg = np.zeros_like(markers, dtype="uint64")
 897    seg = parallel.seeded_watershed(
 898        boundary_distances, seeds=markers, out=seg, block_shape=tile_shape,
 899        halo=halo, n_threads=n_threads, verbose=verbose, mask=fg_mask,
 900    )
 901
 902    out = np.zeros_like(seg, dtype="uint64")
 903    out = parallel.size_filter(
 904        seg, out=out, min_size=min_size, block_shape=tile_shape, n_threads=n_threads, verbose=verbose
 905    )
 906
 907    return out
 908
 909
 910class InstanceSegmentationWithDecoder:
 911    """Generates an instance segmentation without prompts, using a decoder.
 912
 913    Implements the same interface as `AutomaticMaskGenerator`.
 914
 915    Use this class as follows:
 916    ```python
 917    segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
 918    segmenter.initialize(image)   # Predict the image embeddings and decoder outputs.
 919    masks = segmenter.generate(center_distance_threshold=0.75)  # Generate the instance segmentation.
 920    ```
 921
 922    Args:
 923        predictor: The segment anything predictor.
 924        decoder: The decoder to predict intermediate representations
 925            for instance segmentation.
 926    """
 927    def __init__(
 928        self,
 929        predictor: SamPredictor,
 930        decoder: torch.nn.Module,
 931    ) -> None:
 932        self._predictor = predictor
 933        self._decoder = decoder
 934
 935        # The decoder outputs.
 936        self._foreground = None
 937        self._center_distances = None
 938        self._boundary_distances = None
 939
 940        self._is_initialized = False
 941
 942    @property
 943    def is_initialized(self):
 944        """Whether the mask generator has already been initialized.
 945        """
 946        return self._is_initialized
 947
 948    @torch.no_grad()
 949    def initialize(
 950        self,
 951        image: np.ndarray,
 952        image_embeddings: Optional[util.ImageEmbeddings] = None,
 953        i: Optional[int] = None,
 954        verbose: bool = False,
 955        pbar_init: Optional[callable] = None,
 956        pbar_update: Optional[callable] = None,
 957    ) -> None:
 958        """Initialize image embeddings and decoder predictions for an image.
 959
 960        Args:
 961            image: The input image, volume or timeseries.
 962            image_embeddings: Optional precomputed image embeddings.
 963                See `util.precompute_image_embeddings` for details.
 964            i: Index for the image data. Required if `image` has three spatial dimensions
 965                or a time dimension and two spatial dimensions.
 966            verbose: Whether to be verbose.
 967            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 968                Can be used together with pbar_update to handle the napari progress bar in another thread,
 969                which enables using this function within a threadworker.
 970            pbar_update: Callback to update an external progress bar.
 971        """
 972        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
 973        pbar_init(1, "Initialize instance segmentation with decoder")
 974
 975        if image_embeddings is None:
 976            image_embeddings = util.precompute_image_embeddings(self._predictor, image)
 977
 978        # Get the image embeddings from the predictor.
 979        self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i)
 980        embeddings = self._predictor.features
 981        input_shape = tuple(self._predictor.input_size)
 982        original_shape = tuple(self._predictor.original_size)
 983
 984        # Run prediction with the UNETR decoder.
 985        output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
 986        assert output.shape[0] == 3, f"{output.shape}"
 987        pbar_update(1)
 988        pbar_close()
 989
 990        # Set the state.
 991        self._foreground = output[0]
 992        self._center_distances = output[1]
 993        self._boundary_distances = output[2]
 994        self._is_initialized = True
 995
 996    def _to_masks(self, segmentation, output_mode):
 997        if output_mode != "binary_mask":
 998            raise NotImplementedError
 999
1000        props = regionprops(segmentation)
1001        ndim = segmentation.ndim
1002        assert ndim in (2, 3)
1003
1004        shape = segmentation.shape
1005        if ndim == 2:
1006            crop_box = [0, shape[1], 0, shape[0]]
1007        else:
1008            crop_box = [0, shape[2], 0, shape[1], 0, shape[0]]
1009
1010        # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h]
1011        def to_bbox_2d(bbox):
1012            y0, x0 = bbox[0], bbox[1]
1013            w = bbox[3] - x0
1014            h = bbox[2] - y0
1015            return [x0, w, y0, h]
1016
1017        def to_bbox_3d(bbox):
1018            z0, y0, x0 = bbox[0], bbox[1], bbox[2]
1019            w = bbox[5] - x0
1020            h = bbox[4] - y0
1021            d = bbox[3] - z0  # depth; the skimage 3D bbox is (z0, y0, x0, z1, y1, x1)
1022            return [x0, w, y0, h, z0, d]
1023
1024        to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d
1025        masks = [
1026            {
1027                "segmentation": segmentation == prop.label,
1028                "area": prop.area,
1029                "bbox": to_bbox(prop.bbox),
1030                "crop_box": crop_box,
1031                "seg_id": prop.label,
1032            } for prop in props
1033        ]
1034        return masks
1035
1036    def generate(
1037        self,
1038        center_distance_threshold: float = 0.5,
1039        boundary_distance_threshold: float = 0.5,
1040        foreground_threshold: float = 0.5,
1041        foreground_smoothing: float = 1.0,
1042        distance_smoothing: float = 1.6,
1043        min_size: int = 0,
1044        output_mode: Optional[str] = "binary_mask",
1045        tile_shape: Optional[Tuple[int, int]] = None,
1046        halo: Optional[Tuple[int, int]] = None,
1047        n_threads: Optional[int] = None,
1048    ) -> List[Dict[str, Any]]:
1049        """Generate instance segmentation for the currently initialized image.
1050
1051        Args:
1052            center_distance_threshold: Center distance predictions below this value will be
1053                used to find seeds (intersected with thresholded boundary distance predictions).
1054            boundary_distance_threshold: Boundary distance predictions below this value will be
1055                used to find seeds (intersected with thresholded center distance predictions).
1056            foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid
1057                checkerboard artifacts in the prediction.
1058            foreground_threshold: Foreground predictions above this value will be used as foreground mask.
1059            distance_smoothing: Sigma value for smoothing the distance predictions.
1060            min_size: Minimal object size in the segmentation result.
1061            output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
1062            tile_shape: Tile shape for parallelizing the instance segmentation post-processing.
1063                This parameter is independent from the tile shape for computing the embeddings.
1064                If not given then post-processing will not be parallelized.
1065            halo: Halo for parallel post-processing. See also `tile_shape`.
1066            n_threads: Number of threads for parallel post-processing. See also `tile_shape`.
1067
1068        Returns:
1069            The instance segmentation masks.
1070        """
1071        if not self.is_initialized:
1072            raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.")
1073
1074        if foreground_smoothing > 0:
1075            foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing)
1076        else:
1077            foreground = self._foreground
1078
1079        if tile_shape is None:
1080            segmentation = watershed_from_center_and_boundary_distances(
1081                center_distances=self._center_distances,
1082                boundary_distances=self._boundary_distances,
1083                foreground_map=foreground,
1084                center_distance_threshold=center_distance_threshold,
1085                boundary_distance_threshold=boundary_distance_threshold,
1086                foreground_threshold=foreground_threshold,
1087                distance_smoothing=distance_smoothing,
1088                min_size=min_size,
1089            )
1090        else:
1091            if halo is None:
1092                raise ValueError("You must pass a value for halo if tile_shape is given.")
1093            segmentation = _watershed_from_center_and_boundary_distances_parallel(
1094                center_distances=self._center_distances,
1095                boundary_distances=self._boundary_distances,
1096                foreground_map=foreground,
1097                center_distance_threshold=center_distance_threshold,
1098                boundary_distance_threshold=boundary_distance_threshold,
1099                foreground_threshold=foreground_threshold,
1100                distance_smoothing=distance_smoothing,
1101                min_size=min_size,
1102                tile_shape=tile_shape,
1103                halo=halo,
1104                n_threads=n_threads,
1105                verbose=False,
1106            )
1107
1108        if output_mode is not None:
1109            segmentation = self._to_masks(segmentation, output_mode)
1110        return segmentation
1111
1112    def get_state(self) -> Dict[str, Any]:
1113        """Get the initialized state of the instance segmenter.
1114
1115        Returns:
1116            Instance segmentation state.
1117        """
1118        if not self.is_initialized:
1119            raise RuntimeError("The state has not been computed yet. Call initialize first.")
1120
1121        return {
1122            "foreground": self._foreground,
1123            "center_distances": self._center_distances,
1124            "boundary_distances": self._boundary_distances,
1125        }
1126
1127    def set_state(self, state: Dict[str, Any]) -> None:
1128        """Set the state of the instance segmenter.
1129
1130        Args:
1131            state: The instance segmentation state
1132        """
1133        self._foreground = state["foreground"]
1134        self._center_distances = state["center_distances"]
1135        self._boundary_distances = state["boundary_distances"]
1136        self._is_initialized = True
1137
1138    def clear_state(self):
1139        """Clear the state of the instance segmenter.
1140        """
1141        self._foreground = None
1142        self._center_distances = None
1143        self._boundary_distances = None
1144        self._is_initialized = False
1145
1146
1147class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder):
1148    """Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings.
1149    """
1150
1151    @torch.no_grad()
1152    def initialize(
1153        self,
1154        image: np.ndarray,
1155        image_embeddings: Optional[util.ImageEmbeddings] = None,
1156        i: Optional[int] = None,
1157        tile_shape: Optional[Tuple[int, int]] = None,
1158        halo: Optional[Tuple[int, int]] = None,
1159        verbose: bool = False,
1160        pbar_init: Optional[callable] = None,
1161        pbar_update: Optional[callable] = None,
1162    ) -> None:
1163        """Initialize image embeddings and decoder predictions for an image.
1164
1165        Args:
1166            image: The input image, volume or timeseries.
1167            image_embeddings: Optional precomputed image embeddings.
1168                See `util.precompute_image_embeddings` for details.
1169            i: Index for the image data. Required if `image` has three spatial dimensions
1170                or a time dimension and two spatial dimensions.
1171            tile_shape: Shape of the tiles for precomputing image embeddings.
1172            halo: Overlap of the tiles for tiled precomputation of image embeddings.
1173            verbose: Dummy input to be compatible with other function signatures.
1174            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
1175                Can be used together with pbar_update to handle the napari progress bar in another thread,
1176                which enables using this function within a threadworker.
1177            pbar_update: Callback to update an external progress bar.
1178        """
1179        original_size = image.shape[:2]
1180        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
1181            self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose,
1182        )
1183        tiling = blocking([0, 0], original_size, tile_shape)
1184
1185        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
1186        pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder")
1187
1188        foreground = np.zeros(original_size, dtype="float32")
1189        center_distances = np.zeros(original_size, dtype="float32")
1190        boundary_distances = np.zeros(original_size, dtype="float32")
1191
1192        for tile_id in range(tiling.numberOfBlocks):
1193
1194            # Get the image embeddings from the predictor for this tile.
1195            self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id)
1196            embeddings = self._predictor.features
1197            input_shape = tuple(self._predictor.input_size)
1198            original_shape = tuple(self._predictor.original_size)
1199
1200            # Predict with the UNETR decoder for this tile.
1201            output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
1202            assert output.shape[0] == 3, f"{output.shape}"
1203
1204            # Set the predictions in the output for this tile.
1205            block = tiling.getBlockWithHalo(tile_id, halo=list(halo))
1206            local_bb = tuple(
1207                slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end)
1208            )
1209            inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end))
1210
1211            foreground[inner_bb] = output[0][local_bb]
1212            center_distances[inner_bb] = output[1][local_bb]
1213            boundary_distances[inner_bb] = output[2][local_bb]
1214            pbar_update(1)
1215
1216        pbar_close()
1217
1218        # Set the state.
1219        self._foreground = foreground
1220        self._center_distances = center_distances
1221        self._boundary_distances = boundary_distances
1222        self._is_initialized = True
1223
1224
1225def get_amg(
1226    predictor: SamPredictor, is_tiled: bool, decoder: Optional[torch.nn.Module] = None, **kwargs,
1227) -> Union[AMGBase, InstanceSegmentationWithDecoder]:
1228    """Get the automatic mask generator class.
1229
1230    Args:
1231        predictor: The segment anything predictor.
1232        is_tiled: Whether tiled embeddings are used.
1233        decoder: Decoder to predict the instance segmentation.
1234        kwargs: The keyword arguments for the amg class.
1235
1236    Returns:
1237        The automatic mask generator.
1238    """
1239    if decoder is None:
1240        segmenter_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator
1241        segmenter = segmenter_class(predictor, **kwargs)
1242    else:
1243        segmenter_class = TiledInstanceSegmentationWithDecoder if is_tiled else InstanceSegmentationWithDecoder
1244        segmenter = segmenter_class(predictor, decoder, **kwargs)
1245
1246    return segmenter
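To illustrate how these entry points compose, here is a hedged sketch of the decoder-based workflow via `get_predictor_and_decoder` and `get_amg`; it is not part of the module source, and the checkpoint path and `vit_b_lm` model type are placeholders.

```python
import imageio.v3 as imageio

from micro_sam.instance_segmentation import (
    get_predictor_and_decoder, get_amg, mask_data_to_segmentation
)

image = imageio.imread("cells.tif")  # placeholder input image

# Load a predictor and decoder from a checkpoint that stores both states.
# The model type and checkpoint path are placeholders.
predictor, decoder = get_predictor_and_decoder(
    model_type="vit_b_lm", checkpoint_path="path/to/checkpoint.pt"
)

# With a decoder, get_amg returns an InstanceSegmentationWithDecoder
# (or TiledInstanceSegmentationWithDecoder for is_tiled=True).
segmenter = get_amg(predictor, is_tiled=False, decoder=decoder)
segmenter.initialize(image)  # predict image embeddings and decoder outputs
masks = segmenter.generate(center_distance_threshold=0.5, boundary_distance_threshold=0.5)
segmentation = mask_data_to_segmentation(masks, with_background=True)
```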
def mask_data_to_segmentation(masks: List[Dict[str, Any]], with_background: bool, min_object_size: int = 0, max_object_size: Optional[int] = None, label_masks: bool = True) -> numpy.ndarray:
 50def mask_data_to_segmentation(
 51    masks: List[Dict[str, Any]],
 52    with_background: bool,
 53    min_object_size: int = 0,
 54    max_object_size: Optional[int] = None,
 55    label_masks: bool = True,
 56) -> np.ndarray:
 57    """Convert the output of the automatic mask generation to an instance segmentation.
 58
 59    Args:
 60        masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator.
 61            Only supports output_mode=binary_mask.
 62        with_background: Whether the segmentation has background. If yes this function assures that the largest
 63            object in the output will be mapped to zero (the background value).
 64        min_object_size: The minimal size of an object in pixels.
 65        max_object_size: The maximal size of an object in pixels.
 66        label_masks: Whether to apply connected components to the result before removing small objects.
 67
 68    Returns:
 69        The instance segmentation.
 70    """
 71
 72    masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
 73    # we could also get the shape from the crop box
 74    shape = next(iter(masks))["segmentation"].shape
 75    segmentation = np.zeros(shape, dtype="uint32")
 76
 77    def require_numpy(mask):
 78        return mask.cpu().numpy() if torch.is_tensor(mask) else mask
 79
 80    seg_id = 1
 81    for mask in masks:
 82        if mask["area"] < min_object_size:
 83            continue
 84        if max_object_size is not None and mask["area"] > max_object_size:
 85            continue
 86
 87        this_seg_id = mask.get("seg_id", seg_id)
 88        segmentation[require_numpy(mask["segmentation"])] = this_seg_id
 89        seg_id = this_seg_id + 1
 90
 91    if label_masks:
 92        segmentation = label(segmentation).astype(segmentation.dtype)
 93
 94    seg_ids, sizes = np.unique(segmentation, return_counts=True)
 95
 96    # In some cases objects may be smaller than previously calculated,
 97    # since they are covered by other objects. We ensure these also get
 98    # filtered out here.
 99    filter_ids = seg_ids[sizes < min_object_size]
100
101    # If we run segmentation with background we also map the largest segment
102    # (the most likely background object) to zero. This is often zero already,
103    # but it does not hurt to reset that to zero either.
104    if with_background:
105        bg_id = seg_ids[np.argmax(sizes)]
106        filter_ids = np.concatenate([filter_ids, [bg_id]])
107
108    segmentation[np.isin(segmentation, filter_ids)] = 0
109    segmentation = relabel_sequential(segmentation)[0]
110
111    return segmentation

Convert the output of the automatic mask generation to an instance segmentation.

Arguments:
  • masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator. Only supports output_mode=binary_mask.
  • with_background: Whether the segmentation has background. If yes this function assures that the largest object in the output will be mapped to zero (the background value).
  • min_object_size: The minimal size of an object in pixels.
  • max_object_size: The maximal size of an object in pixels.
  • label_masks: Whether to apply connected components to the result before removing small objects.
Returns:

The instance segmentation.
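
To illustrate the filtering behavior, here is a small self-contained sketch with hand-made mask records instead of real model output; each record only needs a boolean `segmentation` array and its pixel `area`:

```python
# Hand-made mask records illustrating background mapping and size filtering.
import numpy as np
from micro_sam.instance_segmentation import mask_data_to_segmentation

background = np.ones((8, 8), dtype=bool)  # largest mask, becomes background
cell = np.zeros((8, 8), dtype=bool)
cell[2:5, 2:5] = True                     # 9 pixel object
speck = np.zeros((8, 8), dtype=bool)
speck[7, 7] = True                        # 1 pixel object, below min_object_size

masks = [
    {"segmentation": background, "area": int(background.sum())},
    {"segmentation": cell, "area": int(cell.sum())},
    {"segmentation": speck, "area": int(speck.sum())},
]
seg = mask_data_to_segmentation(masks, with_background=True, min_object_size=5)
# The largest mask is mapped to 0, the 1-pixel speck is filtered out,
# and the remaining object receives a sequential id.
print(np.unique(seg))  # -> [0 1]
```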

class AMGBase(abc.ABC):
119class AMGBase(ABC):
120    """Base class for the automatic mask generators.
121    """
122    def __init__(self):
123        # the state that has to be computed by the 'initialize' method of the child classes
124        self._is_initialized = False
125        self._crop_list = None
126        self._crop_boxes = None
127        self._original_size = None
128
129    @property
130    def is_initialized(self):
131        """Whether the mask generator has already been initialized.
132        """
133        return self._is_initialized
134
135    @property
136    def crop_list(self):
137        """The list of mask data after initialization.
138        """
139        return self._crop_list
140
141    @property
142    def crop_boxes(self):
143        """The list of crop boxes.
144        """
145        return self._crop_boxes
146
147    @property
148    def original_size(self):
149        """The original image size.
150        """
151        return self._original_size
152
153    def _postprocess_batch(
154        self,
155        data,
156        crop_box,
157        original_size,
158        pred_iou_thresh,
159        stability_score_thresh,
160        box_nms_thresh,
161    ):
162        orig_h, orig_w = original_size
163
164        # filter by predicted IoU
165        if pred_iou_thresh > 0.0:
166            keep_mask = data["iou_preds"] > pred_iou_thresh
167            data.filter(keep_mask)
168
169        # filter by stability score
170        if stability_score_thresh > 0.0:
171            keep_mask = data["stability_score"] >= stability_score_thresh
172            data.filter(keep_mask)
173
174        # filter boxes that touch crop boundaries
175        keep_mask = ~amg_utils.is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
176        if not torch.all(keep_mask):
177            data.filter(keep_mask)
178
179        # remove duplicates within this crop.
180        keep_by_nms = batched_nms(
181            data["boxes"].float(),
182            data["iou_preds"],
183            torch.zeros_like(data["boxes"][:, 0]),  # categories
184            iou_threshold=box_nms_thresh,
185        )
186        data.filter(keep_by_nms)
187
188        # return to the original image frame
189        data["boxes"] = amg_utils.uncrop_boxes_xyxy(data["boxes"], crop_box)
190        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
191        # the data from embedding based segmentation doesn't have the points
192        # so we skip if the corresponding key can't be found
193        try:
194            data["points"] = amg_utils.uncrop_points(data["points"], crop_box)
195        except KeyError:
196            pass
197
198        return data
199
200    def _postprocess_small_regions(self, mask_data, min_area, nms_thresh):
201
202        if len(mask_data["rles"]) == 0:
203            return mask_data
204
205        # filter small disconnected regions and holes
206        new_masks = []
207        scores = []
208        for rle in mask_data["rles"]:
209            mask = amg_utils.rle_to_mask(rle)
210
211            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="holes")
212            unchanged = not changed
213            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="islands")
214            unchanged = unchanged and not changed
215
216            new_masks.append(torch.as_tensor(mask, dtype=torch.int).unsqueeze(0))
217            # give score=0 to changed masks and score=1 to unchanged masks
218            # so NMS will prefer ones that didn't need postprocessing
219            scores.append(float(unchanged))
220
221        # recalculate boxes and remove any new duplicates
222        masks = torch.cat(new_masks, dim=0)
223        boxes = batched_mask_to_box(masks.to(torch.bool))  # Casting this to boolean as we work with one-hot labels.
224        keep_by_nms = batched_nms(
225            boxes.float(),
226            torch.as_tensor(scores, dtype=torch.float),
227            torch.zeros_like(boxes[:, 0]),  # categories
228            iou_threshold=nms_thresh,
229        )
230
231        # only recalculate RLEs for masks that have changed
232        for i_mask in keep_by_nms:
233            if scores[i_mask] == 0.0:
234                mask_torch = masks[i_mask].unsqueeze(0)
235                # mask_data["rles"][i_mask] = amg_utils.mask_to_rle_pytorch(mask_torch)[0]
236                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
237                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
238        mask_data.filter(keep_by_nms)
239
240        return mask_data
241
242    def _postprocess_masks(self, mask_data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode):
243        # filter small disconnected regions and holes in masks
244        if min_mask_region_area > 0:
245            mask_data = self._postprocess_small_regions(
246                mask_data,
247                min_mask_region_area,
248                max(box_nms_thresh, crop_nms_thresh),
249            )
250
251        # encode masks
252        if output_mode == "coco_rle":
253            mask_data["segmentations"] = [amg_utils.coco_encode_rle(rle) for rle in mask_data["rles"]]
254        elif output_mode == "binary_mask":
255            mask_data["segmentations"] = [amg_utils.rle_to_mask(rle) for rle in mask_data["rles"]]
256        else:
257            mask_data["segmentations"] = mask_data["rles"]
258
259        # write mask records
260        curr_anns = []
261        for idx in range(len(mask_data["segmentations"])):
262            ann = {
263                "segmentation": mask_data["segmentations"][idx],
264                "area": amg_utils.area_from_rle(mask_data["rles"][idx]),
265                "bbox": amg_utils.box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
266                "predicted_iou": mask_data["iou_preds"][idx].item(),
267                "stability_score": mask_data["stability_score"][idx].item(),
268                "crop_box": amg_utils.box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
269            }
270            # the data from embedding based segmentation doesn't have the points
271            # so we skip if the corresponding key can't be found
272            try:
273                ann["point_coords"] = [mask_data["points"][idx].tolist()]
274            except KeyError:
275                pass
276            curr_anns.append(ann)
277
278        return curr_anns
279
280    def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None):
281        orig_h, orig_w = original_size
282
283        # serialize predictions and store in MaskData
284        data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1))
285        if points is not None:
286            data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0), dtype=torch.float)
287
288        del masks
289
290        # calculate the stability scores
291        data["stability_score"] = amg_utils.calculate_stability_score(
292            data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset
293        )
294
295        # threshold masks and calculate boxes
296        data["masks"] = data["masks"] > self._predictor.model.mask_threshold
297        data["masks"] = data["masks"].type(torch.bool)
298        data["boxes"] = batched_mask_to_box(data["masks"])
299
300        # compress to RLE
301        data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
302        # data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"])
303        data["rles"] = mask_to_rle_pytorch(data["masks"])
304        del data["masks"]
305
306        return data
307
308    def get_state(self) -> Dict[str, Any]:
309        """Get the initialized state of the mask generator.
310
311        Returns:
312            State of the mask generator.
313        """
314        if not self.is_initialized:
315            raise RuntimeError("The state has not been computed yet. Call initialize first.")
316
317        return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}
318
319    def set_state(self, state: Dict[str, Any]) -> None:
320        """Set the state of the mask generator.
321
322        Args:
323            state: The state of the mask generator, e.g. from serialized state.
324        """
325        self._crop_list = state["crop_list"]
326        self._crop_boxes = state["crop_boxes"]
327        self._original_size = state["original_size"]
328        self._is_initialized = True
329
330    def clear_state(self):
331        """Clear the state of the mask generator.
332        """
333        self._crop_list = None
334        self._crop_boxes = None
335        self._original_size = None
336        self._is_initialized = False

Base class for the automatic mask generators.

is_initialized
129    @property
130    def is_initialized(self):
131        """Whether the mask generator has already been initialized.
132        """
133        return self._is_initialized

Whether the mask generator has already been initialized.

crop_list
135    @property
136    def crop_list(self):
137        """The list of mask data after initialization.
138        """
139        return self._crop_list

The list of mask data after initialization.

crop_boxes
141    @property
142    def crop_boxes(self):
143        """The list of crop boxes.
144        """
145        return self._crop_boxes

The list of crop boxes.

original_size
147    @property
148    def original_size(self):
149        """The original image size.
150        """
151        return self._original_size

The original image size.

def get_state(self) -> Dict[str, Any]:
308    def get_state(self) -> Dict[str, Any]:
309        """Get the initialized state of the mask generator.
310
311        Returns:
312            State of the mask generator.
313        """
314        if not self.is_initialized:
315            raise RuntimeError("The state has not been computed yet. Call initialize first.")
316
317        return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}

Get the initialized state of the mask generator.

Returns:

State of the mask generator.

def set_state(self, state: Dict[str, Any]) -> None:
319    def set_state(self, state: Dict[str, Any]) -> None:
320        """Set the state of the mask generator.
321
322        Args:
323            state: The state of the mask generator, e.g. from serialized state.
324        """
325        self._crop_list = state["crop_list"]
326        self._crop_boxes = state["crop_boxes"]
327        self._original_size = state["original_size"]
328        self._is_initialized = True

Set the state of the mask generator.

Arguments:
  • state: The state of the mask generator, e.g. from serialized state.
def clear_state(self):
330    def clear_state(self):
331        """Clear the state of the mask generator.
332        """
333        self._crop_list = None
334        self._crop_boxes = None
335        self._original_size = None
336        self._is_initialized = False

Clear the state of the mask generator.
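
A short sketch of the state handling: the expensive result of `initialize` can be retrieved with `get_state` and later restored with `set_state`, so `generate` can be re-run without recomputation. The model type and the random placeholder image are assumptions.

```python
import numpy as np
from micro_sam import util
from micro_sam.instance_segmentation import AutomaticMaskGenerator

predictor = util.get_sam_model(model_type="vit_b")  # assumption: returns a SamPredictor
image = np.random.rand(512, 512).astype("float32")  # placeholder image

amg = AutomaticMaskGenerator(predictor)
amg.initialize(image)
state = amg.get_state()  # {"crop_list": ..., "crop_boxes": ..., "original_size": ...}

# Restore the cached state in a fresh generator instead of re-running initialize.
amg2 = AutomaticMaskGenerator(predictor)
amg2.set_state(state)
masks = amg2.generate(pred_iou_thresh=0.9)

amg2.clear_state()  # drop the cached crops; is_initialized becomes False again
```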

class AutomaticMaskGenerator(AMGBase):
339class AutomaticMaskGenerator(AMGBase):
340    """Generates an instance segmentation without prompts, using a point grid.
341
342    This class implements the same logic as
343    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
344    It decouples the computationally expensive steps of generating masks from the cheap post-processing
345    operations that filter these masks, to enable grid search and interactive tuning of the post-processing.
346
347    Use this class as follows:
348    ```python
349    amg = AutomaticMaskGenerator(predictor)
350    amg.initialize(image)  # Initialize the masks, this takes care of all expensive computations.
351    masks = amg.generate(pred_iou_thresh=0.8)  # Generate the masks. This is fast and enables testing parameters
352    ```
353
354    Args:
355        predictor: The segment anything predictor.
356        points_per_side: The number of points to be sampled along one side of the image.
357            If None, `point_grids` must provide explicit point sampling.
358        points_per_batch: The number of points run simultaneously by the model.
359            Higher numbers may be faster but use more GPU memory.
360        crop_n_layers: If >0, the mask prediction will be run again on crops of the image.
361        crop_overlap_ratio: Sets the degree to which crops overlap.
362        crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops.
363        point_grids: A list of explicit grids of points used for sampling masks.
364            Normalized to [0, 1] with respect to the image coordinate system.
365        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
366    """
367    def __init__(
368        self,
369        predictor: SamPredictor,
370        points_per_side: Optional[int] = 32,
371        points_per_batch: Optional[int] = None,
372        crop_n_layers: int = 0,
373        crop_overlap_ratio: float = 512 / 1500,
374        crop_n_points_downscale_factor: int = 1,
375        point_grids: Optional[List[np.ndarray]] = None,
376        stability_score_offset: float = 1.0,
377    ):
378        super().__init__()
379
380        if points_per_side is not None:
381            self.point_grids = amg_utils.build_all_layer_point_grids(
382                points_per_side, crop_n_layers, crop_n_points_downscale_factor,
383            )
384        elif point_grids is not None:
385            self.point_grids = point_grids
386        else:
387            raise ValueError("Can't have both points_per_side and point_grid be None or not None.")
388
389        self._predictor = predictor
390        self._points_per_side = points_per_side
391
392        # we set the points per batch to 16 for mps for performance reasons
393        # and otherwise keep them at the default of 64
394        if points_per_batch is None:
395            points_per_batch = 16 if str(predictor.device) == "mps" else 64
396        self._points_per_batch = points_per_batch
397
398        self._crop_n_layers = crop_n_layers
399        self._crop_overlap_ratio = crop_overlap_ratio
400        self._crop_n_points_downscale_factor = crop_n_points_downscale_factor
401        self._stability_score_offset = stability_score_offset
402
403    def _process_batch(self, points, im_size, crop_box, original_size):
404        # run model on this batch
405        transformed_points = self._predictor.transform.apply_coords(points, im_size)
406        in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float)
407        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
408        masks, iou_preds, _ = self._predictor.predict_torch(
409            point_coords=in_points[:, None, :],
410            point_labels=in_labels[:, None],
411            multimask_output=True,
412            return_logits=True,
413        )
414        data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points)
415        del masks
416        return data
417
418    def _process_crop(self, image, crop_box, crop_layer_idx, precomputed_embeddings, pbar_init=None, pbar_update=None):
419        # Crop the image and calculate embeddings.
420        x0, y0, x1, y1 = crop_box
421        cropped_im = image[y0:y1, x0:x1, :]
422        cropped_im_size = cropped_im.shape[:2]
423
424        if not precomputed_embeddings:
425            self._predictor.set_image(cropped_im)
426
427        # Get the points for this crop.
428        points_scale = np.array(cropped_im_size)[None, ::-1]
429        points_for_image = self.point_grids[crop_layer_idx] * points_scale
430
431        # Generate masks for this crop in batches.
432        data = amg_utils.MaskData()
433        n_batches = len(points_for_image) // self._points_per_batch +\
434            int(len(points_for_image) % self._points_per_batch != 0)
435        if pbar_init is not None:
436            pbar_init(n_batches, "Predict masks for point grid prompts")
437
438        for (points,) in amg_utils.batch_iterator(self._points_per_batch, points_for_image):
439            batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size)
440            data.cat(batch_data)
441            del batch_data
442            if pbar_update is not None:
443                pbar_update(1)
444
445        if not precomputed_embeddings:
446            self._predictor.reset_image()
447
448        return data
449
450    @torch.no_grad()
451    def initialize(
452        self,
453        image: np.ndarray,
454        image_embeddings: Optional[util.ImageEmbeddings] = None,
455        i: Optional[int] = None,
456        verbose: bool = False,
457        pbar_init: Optional[callable] = None,
458        pbar_update: Optional[callable] = None,
459    ) -> None:
460        """Initialize image embeddings and masks for an image.
461
462        Args:
463            image: The input image, volume or timeseries.
464            image_embeddings: Optional precomputed image embeddings.
465                See `util.precompute_image_embeddings` for details.
466            i: Index for the image data. Required if `image` has three spatial dimensions
467                or a time dimension and two spatial dimensions.
468            verbose: Whether to print computation progress.
469            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
470                Can be used together with pbar_update to handle the napari progress bar in another thread.
471                This enables using this function within a threadworker.
472            pbar_update: Callback to update an external progress bar.
473        """
474        original_size = image.shape[:2]
475        self._original_size = original_size
476
477        crop_boxes, layer_idxs = amg_utils.generate_crop_boxes(
478            original_size, self._crop_n_layers, self._crop_overlap_ratio
479        )
480
481        # We can set fixed image embeddings if we only have a single crop box (the default setting).
482        # Otherwise we have to recompute the embeddings for each crop and can't precompute.
483        if len(crop_boxes) == 1:
484            if image_embeddings is None:
485                image_embeddings = util.precompute_image_embeddings(self._predictor, image)
486            util.set_precomputed(self._predictor, image_embeddings, i=i)
487            precomputed_embeddings = True
488        else:
489            precomputed_embeddings = False
490
491        # we need to cast to the image representation that is compatible with SAM
492        image = util._to_image(image)
493
494        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
495
496        crop_list = []
497        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
498            crop_data = self._process_crop(
499                image, crop_box, layer_idx,
500                precomputed_embeddings=precomputed_embeddings,
501                pbar_init=pbar_init, pbar_update=pbar_update,
502            )
503            crop_list.append(crop_data)
504        pbar_close()
505
506        self._is_initialized = True
507        self._crop_list = crop_list
508        self._crop_boxes = crop_boxes
509
510    @torch.no_grad()
511    def generate(
512        self,
513        pred_iou_thresh: float = 0.88,
514        stability_score_thresh: float = 0.95,
515        box_nms_thresh: float = 0.7,
516        crop_nms_thresh: float = 0.7,
517        min_mask_region_area: int = 0,
518        output_mode: str = "binary_mask",
519    ) -> List[Dict[str, Any]]:
520        """Generate instance segmentation for the currently initialized image.
521
522        Args:
523            pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
524            stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
525                under changes to the cutoff used to binarize the model prediction.
526            box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
527            crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
528            min_mask_region_area: Minimal size for the predicted masks.
529            output_mode: The form masks are returned in.
530
531        Returns:
532            The instance segmentation masks.
533        """
534        if not self.is_initialized:
535            raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.")
536
537        data = amg_utils.MaskData()
538        for data_, crop_box in zip(self.crop_list, self.crop_boxes):
539            crop_data = self._postprocess_batch(
540                data=deepcopy(data_),
541                crop_box=crop_box, original_size=self.original_size,
542                pred_iou_thresh=pred_iou_thresh,
543                stability_score_thresh=stability_score_thresh,
544                box_nms_thresh=box_nms_thresh
545            )
546            data.cat(crop_data)
547
548        if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0:
549            # Prefer masks from smaller crops
550            scores = 1 / box_area(data["crop_boxes"])
551            scores = scores.to(data["boxes"].device)
552            keep_by_nms = batched_nms(
553                data["boxes"].float(),
554                scores,
555                torch.zeros_like(data["boxes"][:, 0]),  # categories
556                iou_threshold=crop_nms_thresh,
557            )
558            data.filter(keep_by_nms)
559
560        data.to_numpy()
561        masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode)
562        return masks

Generates an instance segmentation without prompts, using a point grid.

This class implements the same logic as https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py. It decouples the computationally expensive steps of generating masks from the cheap post-processing operations that filter these masks, which enables grid search and interactive tuning of the post-processing.

Use this class as follows:

amg = AutomaticMaskGenerator(predictor)
amg.initialize(image)  # Initialize the masks, this takes care of all expensive computations.
masks = amg.generate(pred_iou_thresh=0.8)  # Generate the masks. This is fast and enables testing parameters
Arguments:
  • predictor: The segment anything predictor.
  • points_per_side: The number of points to be sampled along one side of the image. If None, point_grids must provide explicit point sampling.
  • points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory.
  • crop_n_layers: If >0, the mask prediction will be run again on crops of the image.
  • crop_overlap_ratio: Sets the degree to which crops overlap.
  • crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops.
  • point_grids: A list of explicit grids of points used for sampling masks. Normalized to [0, 1] with respect to the image coordinate system.
  • stability_score_offset: The amount to shift the cutoff when calculating the stability score.
AutomaticMaskGenerator( predictor: segment_anything.predictor.SamPredictor, points_per_side: Optional[int] = 32, points_per_batch: Optional[int] = None, crop_n_layers: int = 0, crop_overlap_ratio: float = 0.3413333333333333, crop_n_points_downscale_factor: int = 1, point_grids: Optional[List[numpy.ndarray]] = None, stability_score_offset: float = 1.0)
367    def __init__(
368        self,
369        predictor: SamPredictor,
370        points_per_side: Optional[int] = 32,
371        points_per_batch: Optional[int] = None,
372        crop_n_layers: int = 0,
373        crop_overlap_ratio: float = 512 / 1500,
374        crop_n_points_downscale_factor: int = 1,
375        point_grids: Optional[List[np.ndarray]] = None,
376        stability_score_offset: float = 1.0,
377    ):
378        super().__init__()
379
380        if points_per_side is not None:
381            self.point_grids = amg_utils.build_all_layer_point_grids(
382                points_per_side, crop_n_layers, crop_n_points_downscale_factor,
383            )
384        elif point_grids is not None:
385            self.point_grids = point_grids
386        else:
387            raise ValueError("Can't have both points_per_side and point_grid be None or not None.")
388
389        self._predictor = predictor
390        self._points_per_side = points_per_side
391
392        # we set the points per batch to 16 for mps for performance reasons
393        # and otherwise keep them at the default of 64
394        if points_per_batch is None:
395            points_per_batch = 16 if str(predictor.device) == "mps" else 64
396        self._points_per_batch = points_per_batch
397
398        self._crop_n_layers = crop_n_layers
399        self._crop_overlap_ratio = crop_overlap_ratio
400        self._crop_n_points_downscale_factor = crop_n_points_downscale_factor
401        self._stability_score_offset = stability_score_offset
@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, verbose: bool = False, pbar_init: Optional[<built-in function callable>] = None, pbar_update: Optional[<built-in function callable>] = None) -> None:
450    @torch.no_grad()
451    def initialize(
452        self,
453        image: np.ndarray,
454        image_embeddings: Optional[util.ImageEmbeddings] = None,
455        i: Optional[int] = None,
456        verbose: bool = False,
457        pbar_init: Optional[callable] = None,
458        pbar_update: Optional[callable] = None,
459    ) -> None:
460        """Initialize image embeddings and masks for an image.
461
462        Args:
463            image: The input image, volume or timeseries.
464            image_embeddings: Optional precomputed image embeddings.
465                See `util.precompute_image_embeddings` for details.
466            i: Index for the image data. Required if `image` has three spatial dimensions
467                or a time dimension and two spatial dimensions.
468            verbose: Whether to print computation progress.
469            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
470                Can be used together with pbar_update to handle the napari progress bar in another thread.
471                This enables using this function within a threadworker.
472            pbar_update: Callback to update an external progress bar.
473        """
474        original_size = image.shape[:2]
475        self._original_size = original_size
476
477        crop_boxes, layer_idxs = amg_utils.generate_crop_boxes(
478            original_size, self._crop_n_layers, self._crop_overlap_ratio
479        )
480
481        # We can set fixed image embeddings if we only have a single crop box (the default setting).
482        # Otherwise we have to recompute the embeddings for each crop and can't precompute.
483        if len(crop_boxes) == 1:
484            if image_embeddings is None:
485                image_embeddings = util.precompute_image_embeddings(self._predictor, image)
486            util.set_precomputed(self._predictor, image_embeddings, i=i)
487            precomputed_embeddings = True
488        else:
489            precomputed_embeddings = False
490
491        # we need to cast to the image representation that is compatible with SAM
492        image = util._to_image(image)
493
494        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
495
496        crop_list = []
497        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
498            crop_data = self._process_crop(
499                image, crop_box, layer_idx,
500                precomputed_embeddings=precomputed_embeddings,
501                pbar_init=pbar_init, pbar_update=pbar_update,
502            )
503            crop_list.append(crop_data)
504        pbar_close()
505
506        self._is_initialized = True
507        self._crop_list = crop_list
508        self._crop_boxes = crop_boxes

Initialize image embeddings and masks for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • verbose: Whether to print computation progress.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
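
If the embeddings have already been computed, e.g. for interactive annotation, they can be passed to `initialize` to avoid recomputation. A minimal sketch, assuming a `SamPredictor` from `util.get_sam_model` and a placeholder image:

```python
import numpy as np
from micro_sam import util
from micro_sam.instance_segmentation import AutomaticMaskGenerator

predictor = util.get_sam_model(model_type="vit_b")  # assumption: returns a SamPredictor
image = np.random.rand(512, 512).astype("float32")  # placeholder image

# Precompute the embeddings once and reuse them for mask generation.
image_embeddings = util.precompute_image_embeddings(predictor, image)

amg = AutomaticMaskGenerator(predictor)
amg.initialize(image, image_embeddings=image_embeddings)
masks = amg.generate()
```
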
@torch.no_grad()
def generate( self, pred_iou_thresh: float = 0.88, stability_score_thresh: float = 0.95, box_nms_thresh: float = 0.7, crop_nms_thresh: float = 0.7, min_mask_region_area: int = 0, output_mode: str = 'binary_mask') -> List[Dict[str, Any]]:
510    @torch.no_grad()
511    def generate(
512        self,
513        pred_iou_thresh: float = 0.88,
514        stability_score_thresh: float = 0.95,
515        box_nms_thresh: float = 0.7,
516        crop_nms_thresh: float = 0.7,
517        min_mask_region_area: int = 0,
518        output_mode: str = "binary_mask",
519    ) -> List[Dict[str, Any]]:
520        """Generate instance segmentation for the currently initialized image.
521
522        Args:
523            pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
524            stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
525                under changes to the cutoff used to binarize the model prediction.
526            box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
527            crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
528            min_mask_region_area: Minimal size for the predicted masks.
529            output_mode: The form masks are returned in.
530
531        Returns:
532            The instance segmentation masks.
533        """
534        if not self.is_initialized:
535            raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.")
536
537        data = amg_utils.MaskData()
538        for data_, crop_box in zip(self.crop_list, self.crop_boxes):
539            crop_data = self._postprocess_batch(
540                data=deepcopy(data_),
541                crop_box=crop_box, original_size=self.original_size,
542                pred_iou_thresh=pred_iou_thresh,
543                stability_score_thresh=stability_score_thresh,
544                box_nms_thresh=box_nms_thresh
545            )
546            data.cat(crop_data)
547
548        if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0:
549            # Prefer masks from smaller crops
550            scores = 1 / box_area(data["crop_boxes"])
551            scores = scores.to(data["boxes"].device)
552            keep_by_nms = batched_nms(
553                data["boxes"].float(),
554                scores,
555                torch.zeros_like(data["boxes"][:, 0]),  # categories
556                iou_threshold=crop_nms_thresh,
557            )
558            data.filter(keep_by_nms)
559
560        data.to_numpy()
561        masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode)
562        return masks

Generate instance segmentation for the currently initialized image.

Arguments:
  • pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
  • stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask under changes to the cutoff used to binarize the model prediction.
  • box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
  • crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
  • min_mask_region_area: Minimal size for the predicted masks.
  • output_mode: The form masks are returned in.
Returns:

The instance segmentation masks.
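
Because `initialize` caches the raw mask data, `generate` can be called repeatedly with different post-processing settings, e.g. for a small grid search. The threshold values in this sketch are arbitrary examples and the image is a placeholder.

```python
import numpy as np
from micro_sam import util
from micro_sam.instance_segmentation import AutomaticMaskGenerator

predictor = util.get_sam_model(model_type="vit_b")  # assumption: returns a SamPredictor
image = np.random.rand(512, 512).astype("float32")  # placeholder image

amg = AutomaticMaskGenerator(predictor)
amg.initialize(image)  # run the expensive mask prediction only once

# Sweep the filtering thresholds without recomputing embeddings or masks.
for iou_thresh in (0.75, 0.85, 0.95):
    masks = amg.generate(pred_iou_thresh=iou_thresh, stability_score_thresh=0.9)
    print(f"pred_iou_thresh={iou_thresh}: {len(masks)} masks")
```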

class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
592class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
593    """Generates an instance segmentation without prompts, using a point grid.
594
595    Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings.
596
597    Args:
598        predictor: The segment anything predictor.
599        points_per_side: The number of points to be sampled along one side of the image.
600            If None, `point_grids` must provide explicit point sampling.
601        points_per_batch: The number of points run simultaneously by the model.
602            Higher numbers may be faster but use more GPU memory.
603        point_grids: A list of explicit grids of points used for sampling masks.
604            Normalized to [0, 1] with respect to the image coordinate system.
605        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
606    """
607
608    # We only expose the arguments that make sense for the tiled mask generator.
609    # Anything related to crops doesn't make sense, because we re-use that functionality
610    # for tiling, so these parameters wouldn't have any effect.
611    def __init__(
612        self,
613        predictor: SamPredictor,
614        points_per_side: Optional[int] = 32,
615        points_per_batch: int = 64,
616        point_grids: Optional[List[np.ndarray]] = None,
617        stability_score_offset: float = 1.0,
618    ) -> None:
619        super().__init__(
620            predictor=predictor,
621            points_per_side=points_per_side,
622            points_per_batch=points_per_batch,
623            point_grids=point_grids,
624            stability_score_offset=stability_score_offset,
625        )
626
627    @torch.no_grad()
628    def initialize(
629        self,
630        image: np.ndarray,
631        image_embeddings: Optional[util.ImageEmbeddings] = None,
632        i: Optional[int] = None,
633        tile_shape: Optional[Tuple[int, int]] = None,
634        halo: Optional[Tuple[int, int]] = None,
635        verbose: bool = False,
636        pbar_init: Optional[callable] = None,
637        pbar_update: Optional[callable] = None,
638    ) -> None:
639        """Initialize image embeddings and masks for an image.
640
641        Args:
642            image: The input image, volume or timeseries.
643            image_embeddings: Optional precomputed image embeddings.
644                See `util.precompute_image_embeddings` for details.
645            i: Index for the image data. Required if `image` has three spatial dimensions
646                or a time dimension and two spatial dimensions.
647            tile_shape: The tile shape for embedding prediction.
648        halo: The overlap between tiles.
649            verbose: Whether to print computation progress.
650            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
651                Can be used together with pbar_update to handle the napari progress bar in another thread.
652                This enables using this function within a threadworker.
653            pbar_update: Callback to update an external progress bar.
654        """
655        original_size = image.shape[:2]
656        self._original_size = original_size
657
658        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
659            self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose,
660        )
661
662        tiling = blocking([0, 0], original_size, tile_shape)
663        n_tiles = tiling.numberOfBlocks
664
665        # The crop box is always the full local tile.
666        tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)]
667        crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles]
668
669        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
670        pbar_init(n_tiles, "Compute masks for tile")
671
672        # We need to cast to the image representation that is compatible with SAM.
673        image = util._to_image(image)
674
675        mask_data = []
676        for tile_id in range(n_tiles):
677            # set the pre-computed embeddings for this tile
678            features = image_embeddings["features"][tile_id]
679            tile_embeddings = {
680                "features": features,
681                "input_size": features.attrs["input_size"],
682                "original_size": features.attrs["original_size"],
683            }
684            util.set_precomputed(self._predictor, tile_embeddings, i)
685
686            # compute the mask data for this tile and append it
687            this_mask_data = self._process_crop(
688                image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True
689            )
690            mask_data.append(this_mask_data)
691            pbar_update(1)
692        pbar_close()
693
694        # set the initialized data
695        self._is_initialized = True
696        self._crop_list = mask_data
697        self._crop_boxes = crop_boxes

Generates an instance segmentation without prompts, using a point grid.

Implements the same functionality as AutomaticMaskGenerator but for tiled embeddings.

Arguments:
  • predictor: The segment anything predictor.
  • points_per_side: The number of points to be sampled along one side of the image. If None, point_grids must provide explicit point sampling.
  • points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory.
  • point_grids: A list of explicit grids of points used for sampling masks. Normalized to [0, 1] with respect to the image coordinate system.
  • stability_score_offset: The amount to shift the cutoff when calculating the stability score.
TiledAutomaticMaskGenerator( predictor: segment_anything.predictor.SamPredictor, points_per_side: Optional[int] = 32, points_per_batch: int = 64, point_grids: Optional[List[numpy.ndarray]] = None, stability_score_offset: float = 1.0)
611    def __init__(
612        self,
613        predictor: SamPredictor,
614        points_per_side: Optional[int] = 32,
615        points_per_batch: int = 64,
616        point_grids: Optional[List[np.ndarray]] = None,
617        stability_score_offset: float = 1.0,
618    ) -> None:
619        super().__init__(
620            predictor=predictor,
621            points_per_side=points_per_side,
622            points_per_batch=points_per_batch,
623            point_grids=point_grids,
624            stability_score_offset=stability_score_offset,
625        )
@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = False, pbar_init: Optional[<built-in function callable>] = None, pbar_update: Optional[<built-in function callable>] = None) -> None:
627    @torch.no_grad()
628    def initialize(
629        self,
630        image: np.ndarray,
631        image_embeddings: Optional[util.ImageEmbeddings] = None,
632        i: Optional[int] = None,
633        tile_shape: Optional[Tuple[int, int]] = None,
634        halo: Optional[Tuple[int, int]] = None,
635        verbose: bool = False,
636        pbar_init: Optional[callable] = None,
637        pbar_update: Optional[callable] = None,
638    ) -> None:
639        """Initialize image embeddings and masks for an image.
640
641        Args:
642            image: The input image, volume or timeseries.
643            image_embeddings: Optional precomputed image embeddings.
644                See `util.precompute_image_embeddings` for details.
645            i: Index for the image data. Required if `image` has three spatial dimensions
646                or a time dimension and two spatial dimensions.
647            tile_shape: The tile shape for embedding prediction.
648        halo: The overlap between tiles.
649            verbose: Whether to print computation progress.
650            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
651                Can be used together with pbar_update to handle the napari progress bar in another thread.
652                This enables using this function within a threadworker.
653            pbar_update: Callback to update an external progress bar.
654        """
655        original_size = image.shape[:2]
656        self._original_size = original_size
657
658        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
659            self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose,
660        )
661
662        tiling = blocking([0, 0], original_size, tile_shape)
663        n_tiles = tiling.numberOfBlocks
664
665        # The crop box is always the full local tile.
666        tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)]
667        crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles]
668
669        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
670        pbar_init(n_tiles, "Compute masks for tile")
671
672        # We need to cast to the image representation that is compatible with SAM.
673        image = util._to_image(image)
674
675        mask_data = []
676        for tile_id in range(n_tiles):
677            # set the pre-computed embeddings for this tile
678            features = image_embeddings["features"][tile_id]
679            tile_embeddings = {
680                "features": features,
681                "input_size": features.attrs["input_size"],
682                "original_size": features.attrs["original_size"],
683            }
684            util.set_precomputed(self._predictor, tile_embeddings, i)
685
686            # compute the mask data for this tile and append it
687            this_mask_data = self._process_crop(
688                image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True
689            )
690            mask_data.append(this_mask_data)
691            pbar_update(1)
692        pbar_close()
693
694        # set the initialized data
695        self._is_initialized = True
696        self._crop_list = mask_data
697        self._crop_boxes = crop_boxes

Initialize image embeddings and masks for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • tile_shape: The tile shape for embedding prediction.
  • halo: The overlap between tiles.
  • verbose: Whether to print computation progress.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
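
A sketch of tiled mask generation for an image that is too large for a single embedding: tile shape and halo are passed to `initialize`, which computes embeddings per tile. The concrete tile and halo sizes as well as the placeholder image are assumptions.

```python
import numpy as np
from micro_sam import util
from micro_sam.instance_segmentation import TiledAutomaticMaskGenerator, mask_data_to_segmentation

predictor = util.get_sam_model(model_type="vit_b")    # assumption: returns a SamPredictor
image = np.random.rand(2048, 2048).astype("float32")  # placeholder for a large image

amg = TiledAutomaticMaskGenerator(predictor)
amg.initialize(image, tile_shape=(512, 512), halo=(64, 64))  # per-tile embeddings and masks
masks = amg.generate(pred_iou_thresh=0.88)
segmentation = mask_data_to_segmentation(masks, with_background=True)
```
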
class DecoderAdapter(torch.nn.modules.module.Module):
705class DecoderAdapter(torch.nn.Module):
706    """Adapter to contain the UNETR decoder in a single module.
707
708    To apply the decoder on top of pre-computed embeddings for the segmentation functionality.
709    See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py
710    """
711    def __init__(self, unetr):
712        super().__init__()
713
714        self.base = unetr.base
715        self.out_conv = unetr.out_conv
716        self.deconv_out = unetr.deconv_out
717        self.decoder_head = unetr.decoder_head
718        self.final_activation = unetr.final_activation
719        self.postprocess_masks = unetr.postprocess_masks
720
721        self.decoder = unetr.decoder
722        self.deconv1 = unetr.deconv1
723        self.deconv2 = unetr.deconv2
724        self.deconv3 = unetr.deconv3
725        self.deconv4 = unetr.deconv4
726
727    def forward(self, input_, input_shape, original_shape):
728        z12 = input_
729
730        z9 = self.deconv1(z12)
731        z6 = self.deconv2(z9)
732        z3 = self.deconv3(z6)
733        z0 = self.deconv4(z3)
734
735        updated_from_encoder = [z9, z6, z3]
736
737        x = self.base(z12)
738        x = self.decoder(x, encoder_inputs=updated_from_encoder)
739        x = self.deconv_out(x)
740
741        x = torch.cat([x, z0], dim=1)
742        x = self.decoder_head(x)
743
744        x = self.out_conv(x)
745        if self.final_activation is not None:
746            x = self.final_activation(x)
747
748        x = self.postprocess_masks(x, input_shape, original_shape)
749        return x

Adapter to contain the UNETR decoder in a single module.

To apply the decoder on top of pre-computed embeddings for the segmentation functionality. See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py

DecoderAdapter(unetr)
711    def __init__(self, unetr):
712        super().__init__()
713
714        self.base = unetr.base
715        self.out_conv = unetr.out_conv
716        self.deconv_out = unetr.deconv_out
717        self.decoder_head = unetr.decoder_head
718        self.final_activation = unetr.final_activation
719        self.postprocess_masks = unetr.postprocess_masks
720
721        self.decoder = unetr.decoder
722        self.deconv1 = unetr.deconv1
723        self.deconv2 = unetr.deconv2
724        self.deconv3 = unetr.deconv3
725        self.deconv4 = unetr.deconv4

Initialize internal Module state, shared by both nn.Module and ScriptModule.

base
out_conv
deconv_out
decoder_head
final_activation
postprocess_masks
decoder
deconv1
deconv2
deconv3
deconv4
def forward(self, input_, input_shape, original_shape):
727    def forward(self, input_, input_shape, original_shape):
728        z12 = input_
729
730        z9 = self.deconv1(z12)
731        z6 = self.deconv2(z9)
732        z3 = self.deconv3(z6)
733        z0 = self.deconv4(z3)
734
735        updated_from_encoder = [z9, z6, z3]
736
737        x = self.base(z12)
738        x = self.decoder(x, encoder_inputs=updated_from_encoder)
739        x = self.deconv_out(x)
740
741        x = torch.cat([x, z0], dim=1)
742        x = self.decoder_head(x)
743
744        x = self.out_conv(x)
745        if self.final_activation is not None:
746            x = self.final_activation(x)
747
748        x = self.postprocess_masks(x, input_shape, original_shape)
749        return x

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Inherited Members
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
def get_unetr( image_encoder: torch.nn.modules.module.Module, decoder_state: Optional[collections.OrderedDict[str, torch.Tensor]] = None, device: Union[str, torch.device, NoneType] = None, out_channels: int = 3, flexible_load_checkpoint: bool = False) -> torch.nn.modules.module.Module:
752def get_unetr(
753    image_encoder: torch.nn.Module,
754    decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None,
755    device: Optional[Union[str, torch.device]] = None,
756    out_channels: int = 3,
757    flexible_load_checkpoint: bool = False,
758) -> torch.nn.Module:
759    """Get UNETR model for automatic instance segmentation.
760
761    Args:
762        image_encoder: The image encoder of the SAM model.
763            This is used as encoder by the UNETR too.
764        decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
765        device: The device.
766        out_channels: The number of output channels.
767        flexible_load_checkpoint: Whether to allow reinitialization of parameters
768            which could not be found in the provided decoder state.
769
770    Returns:
771        The UNETR model.
772    """
773    device = util.get_device(device)
774
775    unetr = UNETR(
776        backbone="sam",
777        encoder=image_encoder,
778        out_channels=out_channels,
779        use_sam_stats=True,
780        final_activation="Sigmoid",
781        use_skip_connection=False,
782        resize_input=True,
783    )
784    if decoder_state is not None:
785        unetr_state_dict = unetr.state_dict()
786        for k, v in unetr_state_dict.items():
787            if not k.startswith("encoder"):
788                if flexible_load_checkpoint:  # Whether to allow reinitialization of params, if not found.
789                    if k in decoder_state:  # First check whether the key is available in the provided decoder state.
790                        unetr_state_dict[k] = decoder_state[k]
791                    else:  # Otherwise, keep the default initialization for this parameter.
792                        warnings.warn(f"Could not find '{k}' in the pretrained state dict. Hence, we reinitialize it.")
793                        unetr_state_dict[k] = v
794
795                else:  # Otherwise, be strict about finding the parameter in the decoder state.
796                    if k not in decoder_state:
797                        raise RuntimeError(f"The parameters for '{k}' could not be found.")
798                    unetr_state_dict[k] = decoder_state[k]
799
800        unetr.load_state_dict(unetr_state_dict)
801
802    unetr.to(device)
803    return unetr

Get UNETR model for automatic instance segmentation.

Arguments:
  • image_encoder: The image encoder of the SAM model. This is used as encoder by the UNETR too.
  • decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
  • device: The device.
  • out_channels: The number of output channels.
  • flexible_load_checkpoint: Whether to allow reinitialization of parameters which could not be found in the provided decoder state.
Returns:

The UNETR model.
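
A short sketch of building the UNETR from a SAM image encoder. Without a `decoder_state` the decoder weights are freshly initialized, so this mainly serves as a starting point for training; the model type is a placeholder assumption.

```python
from micro_sam import util
from micro_sam.instance_segmentation import get_unetr

predictor = util.get_sam_model(model_type="vit_b")  # assumption: returns a SamPredictor
# Reuse the SAM image encoder as the UNETR encoder; the decoder weights are randomly initialized here.
unetr = get_unetr(image_encoder=predictor.model.image_encoder, out_channels=3)
print(sum(p.numel() for p in unetr.parameters()))
```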

def get_decoder( image_encoder: torch.nn.modules.module.Module, decoder_state: collections.OrderedDict[str, torch.Tensor], device: Union[str, torch.device, NoneType] = None) -> DecoderAdapter:
806def get_decoder(
807    image_encoder: torch.nn.Module,
808    decoder_state: OrderedDict[str, torch.Tensor],
809    device: Optional[Union[str, torch.device]] = None,
810) -> DecoderAdapter:
811    """Get decoder to predict outputs for automatic instance segmentation.
812
813    Args:
814        image_encoder: The image encoder of the SAM model.
815        decoder_state: State to initialize the weights of the UNETR decoder.
816        device: The device.
817
818    Returns:
819        The decoder for instance segmentation.
820    """
821    unetr = get_unetr(image_encoder, decoder_state, device)
822    return DecoderAdapter(unetr)

Get decoder to predict outputs for automatic instance segmentation.

Arguments:
  • image_encoder: The image encoder of the SAM model.
  • decoder_state: State to initialize the weights of the UNETR decoder.
  • device: The device.
Returns:

The decoder for instance segmentation.

def get_predictor_and_decoder( model_type: str, checkpoint_path: Union[str, os.PathLike], device: Union[str, torch.device, NoneType] = None, peft_kwargs: Optional[Dict] = None) -> Tuple[segment_anything.predictor.SamPredictor, DecoderAdapter]:
825def get_predictor_and_decoder(
826    model_type: str,
827    checkpoint_path: Union[str, os.PathLike],
828    device: Optional[Union[str, torch.device]] = None,
829    peft_kwargs: Optional[Dict] = None,
830) -> Tuple[SamPredictor, DecoderAdapter]:
831    """Load the SAM model (predictor) and instance segmentation decoder.
832
833    This requires a checkpoint that contains the state for both predictor
834    and decoder.
835
836    Args:
837        model_type: The type of the image encoder used in the SAM model.
838        checkpoint_path: Path to the checkpoint from which to load the data.
839        device: The device.
840        peft_kwargs: Keyword arguments for the PEFT wrapper class.
841
842    Returns:
843        The SAM predictor.
844        The decoder for instance segmentation.
845    """
846    device = util.get_device(device)
847    predictor, state = util.get_sam_model(
848        model_type=model_type,
849        checkpoint_path=checkpoint_path,
850        device=device,
851        return_state=True,
852        peft_kwargs=peft_kwargs,
853    )
854    if "decoder_state" not in state:
855        raise ValueError(
856            f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state"
857        )
858    decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device)
859    return predictor, decoder

Load the SAM model (predictor) and instance segmentation decoder.

This requires a checkpoint that contains the state for both predictor and decoder.

Arguments:
  • model_type: The type of the image encoder used in the SAM model.
  • checkpoint_path: Path to the checkpoint from which to load the data.
  • device: The device.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class.
Returns:

The SAM predictor. The decoder for instance segmentation.
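
A usage sketch for decoder-based segmentation with a checkpoint that contains a decoder state. The model type, checkpoint path and threshold are placeholder assumptions; checkpoints without a `decoder_state` entry will raise the `ValueError` shown above.

```python
import numpy as np
from micro_sam.instance_segmentation import (
    InstanceSegmentationWithDecoder, get_predictor_and_decoder, mask_data_to_segmentation,
)

# Placeholder model type and checkpoint path; the checkpoint must contain a decoder state.
predictor, decoder = get_predictor_and_decoder(model_type="vit_b_lm", checkpoint_path="/path/to/checkpoint.pt")
image = np.random.rand(512, 512).astype("float32")  # placeholder image

segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image)  # predicts foreground and distance maps with the decoder
masks = segmenter.generate(center_distance_threshold=0.75)
segmentation = mask_data_to_segmentation(masks, with_background=True, min_object_size=25)
```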

class InstanceSegmentationWithDecoder:
 911class InstanceSegmentationWithDecoder:
 912    """Generates an instance segmentation without prompts, using a decoder.
 913
 914    Implements the same interface as `AutomaticMaskGenerator`.
 915
 916    Use this class as follows:
 917    ```python
 918    segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
 919    segmenter.initialize(image)   # Predict the image embeddings and decoder outputs.
 920    masks = segmenter.generate(center_distance_threshold=0.75)  # Generate the instance segmentation.
 921    ```
 922
 923    Args:
 924        predictor: The segment anything predictor.
 925        decoder: The decoder to predict intermediate representations
 926            for instance segmentation.
 927    """
 928    def __init__(
 929        self,
 930        predictor: SamPredictor,
 931        decoder: torch.nn.Module,
 932    ) -> None:
 933        self._predictor = predictor
 934        self._decoder = decoder
 935
 936        # The decoder outputs.
 937        self._foreground = None
 938        self._center_distances = None
 939        self._boundary_distances = None
 940
 941        self._is_initialized = False
 942
 943    @property
 944    def is_initialized(self):
 945        """Whether the mask generator has already been initialized.
 946        """
 947        return self._is_initialized
 948
 949    @torch.no_grad()
 950    def initialize(
 951        self,
 952        image: np.ndarray,
 953        image_embeddings: Optional[util.ImageEmbeddings] = None,
 954        i: Optional[int] = None,
 955        verbose: bool = False,
 956        pbar_init: Optional[callable] = None,
 957        pbar_update: Optional[callable] = None,
 958    ) -> None:
 959        """Initialize image embeddings and decoder predictions for an image.
 960
 961        Args:
 962            image: The input image, volume or timeseries.
 963            image_embeddings: Optional precomputed image embeddings.
 964                See `util.precompute_image_embeddings` for details.
 965            i: Index for the image data. Required if `image` has three spatial dimensions
 966                or a time dimension and two spatial dimensions.
 967            verbose: Whether to be verbose.
 968            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 969                Can be used together with pbar_update to handle the napari progress bar in another thread.
 970                This enables using this function within a threadworker.
 971            pbar_update: Callback to update an external progress bar.
 972        """
 973        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
 974        pbar_init(1, "Initialize instance segmentation with decoder")
 975
 976        if image_embeddings is None:
 977            image_embeddings = util.precompute_image_embeddings(self._predictor, image)
 978
 979        # Get the image embeddings from the predictor.
 980        self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i)
 981        embeddings = self._predictor.features
 982        input_shape = tuple(self._predictor.input_size)
 983        original_shape = tuple(self._predictor.original_size)
 984
 985        # Run prediction with the UNETR decoder.
 986        output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
 987        assert output.shape[0] == 3, f"{output.shape}"
 988        pbar_update(1)
 989        pbar_close()
 990
 991        # Set the state.
 992        self._foreground = output[0]
 993        self._center_distances = output[1]
 994        self._boundary_distances = output[2]
 995        self._is_initialized = True
 996
 997    def _to_masks(self, segmentation, output_mode):
 998        if output_mode != "binary_mask":
 999            raise NotImplementedError
1000
1001        props = regionprops(segmentation)
1002        ndim = segmentation.ndim
1003        assert ndim in (2, 3)
1004
1005        shape = segmentation.shape
1006        if ndim == 2:
1007            crop_box = [0, shape[1], 0, shape[0]]
1008        else:
1009            crop_box = [0, shape[2], 0, shape[1], 0, shape[0]]
1010
1011        # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h]
1012        def to_bbox_2d(bbox):
1013            y0, x0 = bbox[0], bbox[1]
1014            w = bbox[3] - x0
1015            h = bbox[2] - y0
1016            return [x0, w, y0, h]
1017
1018        def to_bbox_3d(bbox):
1019            z0, y0, x0 = bbox[0], bbox[1], bbox[2]
1020            w = bbox[5] - x0
1021            h = bbox[4] - y0
1022            d = bbox[3] - z0
1023            return [x0, w, y0, h, z0, d]
1024
1025        to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d
1026        masks = [
1027            {
1028                "segmentation": segmentation == prop.label,
1029                "area": prop.area,
1030                "bbox": to_bbox(prop.bbox),
1031                "crop_box": crop_box,
1032                "seg_id": prop.label,
1033            } for prop in props
1034        ]
1035        return masks
1036
1037    def generate(
1038        self,
1039        center_distance_threshold: float = 0.5,
1040        boundary_distance_threshold: float = 0.5,
1041        foreground_threshold: float = 0.5,
1042        foreground_smoothing: float = 1.0,
1043        distance_smoothing: float = 1.6,
1044        min_size: int = 0,
1045        output_mode: Optional[str] = "binary_mask",
1046        tile_shape: Optional[Tuple[int, int]] = None,
1047        halo: Optional[Tuple[int, int]] = None,
1048        n_threads: Optional[int] = None,
1049    ) -> List[Dict[str, Any]]:
1050        """Generate instance segmentation for the currently initialized image.
1051
1052        Args:
1053            center_distance_threshold: Center distance predictions below this value will be
1054                used to find seeds (intersected with thresholded boundary distance predictions).
1055            boundary_distance_threshold: Boundary distance predictions below this value will be
1056                used to find seeds (intersected with thresholded center distance predictions).
1057            foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid
1058                checkerboard artifacts in the prediction.
1059            foreground_threshold: Foreground predictions above this value will be used as foreground mask.
1060            distance_smoothing: Sigma value for smoothing the distance predictions.
1061            min_size: Minimal object size in the segmentation result.
1062            output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
1063            tile_shape: Tile shape for parallelizing the instance segmentation post-processing.
1064                This parameter is independent from the tile shape for computing the embeddings.
1065                If not given then post-processing will not be parallelized.
1066            halo: Halo for parallel post-processing. See also `tile_shape`.
1067            n_threads: Number of threads for parallel post-processing. See also `tile_shape`.
1068
1069        Returns:
1070            The instance segmentation masks.
1071        """
1072        if not self.is_initialized:
1073            raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.")
1074
1075        if foreground_smoothing > 0:
1076            foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing)
1077        else:
1078            foreground = self._foreground
1079
1080        if tile_shape is None:
1081            segmentation = watershed_from_center_and_boundary_distances(
1082                center_distances=self._center_distances,
1083                boundary_distances=self._boundary_distances,
1084                foreground_map=foreground,
1085                center_distance_threshold=center_distance_threshold,
1086                boundary_distance_threshold=boundary_distance_threshold,
1087                foreground_threshold=foreground_threshold,
1088                distance_smoothing=distance_smoothing,
1089                min_size=min_size,
1090            )
1091        else:
1092            if halo is None:
1093                raise ValueError("You must pass a value for halo if tile_shape is given.")
1094            segmentation = _watershed_from_center_and_boundary_distances_parallel(
1095                center_distances=self._center_distances,
1096                boundary_distances=self._boundary_distances,
1097                foreground_map=foreground,
1098                center_distance_threshold=center_distance_threshold,
1099                boundary_distance_threshold=boundary_distance_threshold,
1100                foreground_threshold=foreground_threshold,
1101                distance_smoothing=distance_smoothing,
1102                min_size=min_size,
1103                tile_shape=tile_shape,
1104                halo=halo,
1105                n_threads=n_threads,
1106                verbose=False,
1107            )
1108
1109        if output_mode is not None:
1110            segmentation = self._to_masks(segmentation, output_mode)
1111        return segmentation
1112
1113    def get_state(self) -> Dict[str, Any]:
1114        """Get the initialized state of the instance segmenter.
1115
1116        Returns:
1117            Instance segmentation state.
1118        """
1119        if not self.is_initialized:
1120            raise RuntimeError("The state has not been computed yet. Call initialize first.")
1121
1122        return {
1123            "foreground": self._foreground,
1124            "center_distances": self._center_distances,
1125            "boundary_distances": self._boundary_distances,
1126        }
1127
1128    def set_state(self, state: Dict[str, Any]) -> None:
1129        """Set the state of the instance segmenter.
1130
1131        Args:
1132            state: The instance segmentation state
1133        """
1134        self._foreground = state["foreground"]
1135        self._center_distances = state["center_distances"]
1136        self._boundary_distances = state["boundary_distances"]
1137        self._is_initialized = True
1138
1139    def clear_state(self):
1140        """Clear the state of the instance segmenter.
1141        """
1142        self._foreground = None
1143        self._center_distances = None
1144        self._boundary_distances = None
1145        self._is_initialized = False

Generates an instance segmentation without prompts, using a decoder.

Implements the same interface as AutomaticMaskGenerator.

Use this class as follows:

segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image)   # Predict the image embeddings and decoder outputs.
masks = segmenter.generate(center_distance_threshold=0.75)  # Generate the instance segmentation.
Arguments:
  • predictor: The segment anything predictor.
  • decoder: The decoder to predict intermediate representations for instance segmentation.
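
For orientation, a minimal end-to-end sketch that also converts the generated masks into a label image with mask_data_to_segmentation; the image file, model type and checkpoint path are placeholders.

import imageio.v3 as imageio

from micro_sam.instance_segmentation import (
    InstanceSegmentationWithDecoder,
    get_predictor_and_decoder,
    mask_data_to_segmentation,
)

image = imageio.imread("cells.tif")  # placeholder 2d image
predictor, decoder = get_predictor_and_decoder(
    model_type="vit_b_lm", checkpoint_path="finetuned_model.pt",  # placeholders
)
segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image)
masks = segmenter.generate(center_distance_threshold=0.75)
# Convert the list of mask dicts into a single instance segmentation (label image).
segmentation = mask_data_to_segmentation(masks, with_background=True, min_object_size=50)
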
InstanceSegmentationWithDecoder( predictor: segment_anything.predictor.SamPredictor, decoder: torch.nn.modules.module.Module)
928    def __init__(
929        self,
930        predictor: SamPredictor,
931        decoder: torch.nn.Module,
932    ) -> None:
933        self._predictor = predictor
934        self._decoder = decoder
935
936        # The decoder outputs.
937        self._foreground = None
938        self._center_distances = None
939        self._boundary_distances = None
940
941        self._is_initialized = False
is_initialized
943    @property
944    def is_initialized(self):
945        """Whether the mask generator has already been initialized.
946        """
947        return self._is_initialized

Whether the mask generator has already been initialized.

@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, verbose: bool = False, pbar_init: Optional[callable] = None, pbar_update: Optional[callable] = None) -> None:
949    @torch.no_grad()
950    def initialize(
951        self,
952        image: np.ndarray,
953        image_embeddings: Optional[util.ImageEmbeddings] = None,
954        i: Optional[int] = None,
955        verbose: bool = False,
956        pbar_init: Optional[callable] = None,
957        pbar_update: Optional[callable] = None,
958    ) -> None:
959        """Initialize image embeddings and decoder predictions for an image.
960
961        Args:
962            image: The input image, volume or timeseries.
963            image_embeddings: Optional precomputed image embeddings.
964                See `util.precompute_image_embeddings` for details.
965            i: Index for the image data. Required if `image` has three spatial dimensions
966                or a time dimension and two spatial dimensions.
967            verbose: Whether to be verbose.
968            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
969                Can be used together with pbar_update to handle the napari progress bar in another thread.
970                This enables using this function within a threadworker.
971            pbar_update: Callback to update an external progress bar.
972        """
973        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
974        pbar_init(1, "Initialize instance segmentation with decoder")
975
976        if image_embeddings is None:
977            image_embeddings = util.precompute_image_embeddings(self._predictor, image)
978
979        # Get the image embeddings from the predictor.
980        self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i)
981        embeddings = self._predictor.features
982        input_shape = tuple(self._predictor.input_size)
983        original_shape = tuple(self._predictor.original_size)
984
985        # Run prediction with the UNETR decoder.
986        output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
987        assert output.shape[0] == 3, f"{output.shape}"
988        pbar_update(1)
989        pbar_close()
990
991        # Set the state.
992        self._foreground = output[0]
993        self._center_distances = output[1]
994        self._boundary_distances = output[2]
995        self._is_initialized = True

Initialize image embeddings and decoder predictions for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • verbose: Whether to be verbose.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
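
A small sketch of reusing precomputed embeddings, so that several generate calls with different thresholds only pay the embedding cost once; the predictor, segmenter and image are assumed to be set up as in the class example above.

from micro_sam import util

# Compute the image embeddings once ...
image_embeddings = util.precompute_image_embeddings(predictor, image)
# ... and pass them to initialize instead of recomputing them internally.
segmenter.initialize(image, image_embeddings=image_embeddings, verbose=True)
masks_strict = segmenter.generate(center_distance_threshold=0.5)
masks_relaxed = segmenter.generate(center_distance_threshold=0.75)
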
def generate( self, center_distance_threshold: float = 0.5, boundary_distance_threshold: float = 0.5, foreground_threshold: float = 0.5, foreground_smoothing: float = 1.0, distance_smoothing: float = 1.6, min_size: int = 0, output_mode: Optional[str] = 'binary_mask', tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, n_threads: Optional[int] = None) -> List[Dict[str, Any]]:
1037    def generate(
1038        self,
1039        center_distance_threshold: float = 0.5,
1040        boundary_distance_threshold: float = 0.5,
1041        foreground_threshold: float = 0.5,
1042        foreground_smoothing: float = 1.0,
1043        distance_smoothing: float = 1.6,
1044        min_size: int = 0,
1045        output_mode: Optional[str] = "binary_mask",
1046        tile_shape: Optional[Tuple[int, int]] = None,
1047        halo: Optional[Tuple[int, int]] = None,
1048        n_threads: Optional[int] = None,
1049    ) -> List[Dict[str, Any]]:
1050        """Generate instance segmentation for the currently initialized image.
1051
1052        Args:
1053            center_distance_threshold: Center distance predictions below this value will be
1054                used to find seeds (intersected with thresholded boundary distance predictions).
1055            boundary_distance_threshold: Boundary distance predictions below this value will be
1056                used to find seeds (intersected with thresholded center distance predictions).
1057            foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid
1058                checkerboard artifacts in the prediction.
1059            foreground_threshold: Foreground predictions above this value will be used as foreground mask.
1060            distance_smoothing: Sigma value for smoothing the distance predictions.
1061            min_size: Minimal object size in the segmentation result.
1062            output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
1063            tile_shape: Tile shape for parallelizing the instance segmentation post-processing.
1064                This parameter is independent from the tile shape for computing the embeddings.
1065                If not given then post-processing will not be parallelized.
1066            halo: Halo for parallel post-processing. See also `tile_shape`.
1067            n_threads: Number of threads for parallel post-processing. See also `tile_shape`.
1068
1069        Returns:
1070            The instance segmentation masks.
1071        """
1072        if not self.is_initialized:
1073            raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.")
1074
1075        if foreground_smoothing > 0:
1076            foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing)
1077        else:
1078            foreground = self._foreground
1079
1080        if tile_shape is None:
1081            segmentation = watershed_from_center_and_boundary_distances(
1082                center_distances=self._center_distances,
1083                boundary_distances=self._boundary_distances,
1084                foreground_map=foreground,
1085                center_distance_threshold=center_distance_threshold,
1086                boundary_distance_threshold=boundary_distance_threshold,
1087                foreground_threshold=foreground_threshold,
1088                distance_smoothing=distance_smoothing,
1089                min_size=min_size,
1090            )
1091        else:
1092            if halo is None:
1093                raise ValueError("You must pass a value for halo if tile_shape is given.")
1094            segmentation = _watershed_from_center_and_boundary_distances_parallel(
1095                center_distances=self._center_distances,
1096                boundary_distances=self._boundary_distances,
1097                foreground_map=foreground,
1098                center_distance_threshold=center_distance_threshold,
1099                boundary_distance_threshold=boundary_distance_threshold,
1100                foreground_threshold=foreground_threshold,
1101                distance_smoothing=distance_smoothing,
1102                min_size=min_size,
1103                tile_shape=tile_shape,
1104                halo=halo,
1105                n_threads=n_threads,
1106                verbose=False,
1107            )
1108
1109        if output_mode is not None:
1110            segmentation = self._to_masks(segmentation, output_mode)
1111        return segmentation

Generate instance segmentation for the currently initialized image.

Arguments:
  • center_distance_threshold: Center distance predictions below this value will be used to find seeds (intersected with thresholded boundary distance predictions).
  • boundary_distance_threshold: Boundary distance predictions below this value will be used to find seeds (intersected with thresholded center distance predictions).
  • foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid checkerboard artifacts in the prediction.
  • foreground_threshold: Foreground predictions above this value will be used as foreground mask.
  • distance_smoothing: Sigma value for smoothing the distance predictions.
  • min_size: Minimal object size in the segmentation result.
  • output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
  • tile_shape: Tile shape for parallelizing the instance segmentation post-processing. This parameter is independent from the tile shape for computing the embeddings. If not given then post-processing will not be parallelized.
  • halo: Halo for parallel post-processing. See also tile_shape.
  • n_threads: Number of threads for parallel post-processing. See also tile_shape.
Returns:

The instance segmentation masks.
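
A hedged sketch of the two main calling modes of generate: returning mask dicts (the default) and returning the label image directly with parallel post-processing; the tile shape, halo and thread count are arbitrary example values.

# Default: a list of mask dictionaries (output_mode="binary_mask").
masks = segmenter.generate(center_distance_threshold=0.5, min_size=100)

# Alternative: return the instance segmentation directly and parallelize
# the watershed post-processing over tiles (halo is required with tile_shape).
segmentation = segmenter.generate(
    min_size=100,
    output_mode=None,
    tile_shape=(512, 512),
    halo=(64, 64),
    n_threads=4,
)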

def get_state(self) -> Dict[str, Any]:
1113    def get_state(self) -> Dict[str, Any]:
1114        """Get the initialized state of the instance segmenter.
1115
1116        Returns:
1117            Instance segmentation state.
1118        """
1119        if not self.is_initialized:
1120            raise RuntimeError("The state has not been computed yet. Call initialize first.")
1121
1122        return {
1123            "foreground": self._foreground,
1124            "center_distances": self._center_distances,
1125            "boundary_distances": self._boundary_distances,
1126        }

Get the initialized state of the instance segmenter.

Returns:

Instance segmentation state.

def set_state(self, state: Dict[str, Any]) -> None:
1128    def set_state(self, state: Dict[str, Any]) -> None:
1129        """Set the state of the instance segmenter.
1130
1131        Args:
1132            state: The instance segmentation state
1133        """
1134        self._foreground = state["foreground"]
1135        self._center_distances = state["center_distances"]
1136        self._boundary_distances = state["boundary_distances"]
1137        self._is_initialized = True

Set the state of the instance segmenter.

Arguments:
  • state: The instance segmentation state
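
A sketch of caching the initialized state, e.g. to avoid rerunning the decoder; the npz file name is a placeholder and storing the state with numpy is an assumption (the state values are plain numpy arrays).

import numpy as np

# Save the decoder predictions after initialize ...
state = segmenter.get_state()
np.savez("segmenter_state.npz", **state)

# ... and restore them later without calling initialize again.
loaded = np.load("segmenter_state.npz")
segmenter.set_state({
    "foreground": loaded["foreground"],
    "center_distances": loaded["center_distances"],
    "boundary_distances": loaded["boundary_distances"],
})
masks = segmenter.generate()
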
def clear_state(self):
1139    def clear_state(self):
1140        """Clear the state of the instance segmenter.
1141        """
1142        self._foreground = None
1143        self._center_distances = None
1144        self._boundary_distances = None
1145        self._is_initialized = False

Clear the state of the instance segmenter.

class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder):
1148class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder):
1149    """Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings.
1150    """
1151
1152    @torch.no_grad()
1153    def initialize(
1154        self,
1155        image: np.ndarray,
1156        image_embeddings: Optional[util.ImageEmbeddings] = None,
1157        i: Optional[int] = None,
1158        tile_shape: Optional[Tuple[int, int]] = None,
1159        halo: Optional[Tuple[int, int]] = None,
1160        verbose: bool = False,
1161        pbar_init: Optional[callable] = None,
1162        pbar_update: Optional[callable] = None,
1163    ) -> None:
1164        """Initialize image embeddings and decoder predictions for an image.
1165
1166        Args:
1167            image: The input image, volume or timeseries.
1168            image_embeddings: Optional precomputed image embeddings.
1169                See `util.precompute_image_embeddings` for details.
1170            i: Index for the image data. Required if `image` has three spatial dimensions
1171                or a time dimension and two spatial dimensions.
1172            tile_shape: Shape of the tiles for precomputing image embeddings.
1173            halo: Overlap of the tiles for tiled precomputation of image embeddings.
1174            verbose: Dummy input to be compatible with other function signatures.
1175            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
1176                Can be used together with pbar_update to handle the napari progress bar in another thread.
1177                This enables using this function within a threadworker.
1178            pbar_update: Callback to update an external progress bar.
1179        """
1180        original_size = image.shape[:2]
1181        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
1182            self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose,
1183        )
1184        tiling = blocking([0, 0], original_size, tile_shape)
1185
1186        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
1187        pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder")
1188
1189        foreground = np.zeros(original_size, dtype="float32")
1190        center_distances = np.zeros(original_size, dtype="float32")
1191        boundary_distances = np.zeros(original_size, dtype="float32")
1192
1193        for tile_id in range(tiling.numberOfBlocks):
1194
1195            # Get the image embeddings from the predictor for this tile.
1196            self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id)
1197            embeddings = self._predictor.features
1198            input_shape = tuple(self._predictor.input_size)
1199            original_shape = tuple(self._predictor.original_size)
1200
1201            # Predict with the UNETR decoder for this tile.
1202            output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
1203            assert output.shape[0] == 3, f"{output.shape}"
1204
1205            # Set the predictions in the output for this tile.
1206            block = tiling.getBlockWithHalo(tile_id, halo=list(halo))
1207            local_bb = tuple(
1208                slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end)
1209            )
1210            inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end))
1211
1212            foreground[inner_bb] = output[0][local_bb]
1213            center_distances[inner_bb] = output[1][local_bb]
1214            boundary_distances[inner_bb] = output[2][local_bb]
1215            pbar_update(1)
1216
1217        pbar_close()
1218
1219        # Set the state.
1220        self._foreground = foreground
1221        self._center_distances = center_distances
1222        self._boundary_distances = boundary_distances
1223        self._is_initialized = True

Same as InstanceSegmentationWithDecoder but for tiled image embeddings.

@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = False, pbar_init: Optional[callable] = None, pbar_update: Optional[callable] = None) -> None:
1152    @torch.no_grad()
1153    def initialize(
1154        self,
1155        image: np.ndarray,
1156        image_embeddings: Optional[util.ImageEmbeddings] = None,
1157        i: Optional[int] = None,
1158        tile_shape: Optional[Tuple[int, int]] = None,
1159        halo: Optional[Tuple[int, int]] = None,
1160        verbose: bool = False,
1161        pbar_init: Optional[callable] = None,
1162        pbar_update: Optional[callable] = None,
1163    ) -> None:
1164        """Initialize image embeddings and decoder predictions for an image.
1165
1166        Args:
1167            image: The input image, volume or timeseries.
1168            image_embeddings: Optional precomputed image embeddings.
1169                See `util.precompute_image_embeddings` for details.
1170            i: Index for the image data. Required if `image` has three spatial dimensions
1171                or a time dimension and two spatial dimensions.
1172            tile_shape: Shape of the tiles for precomputing image embeddings.
1173            halo: Overlap of the tiles for tiled precomputation of image embeddings.
1174            verbose: Dummy input to be compatible with other function signatures.
1175            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
1176                Can be used together with pbar_update to handle the napari progress bar in another thread.
1177                This enables using this function within a threadworker.
1178            pbar_update: Callback to update an external progress bar.
1179        """
1180        original_size = image.shape[:2]
1181        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
1182            self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose,
1183        )
1184        tiling = blocking([0, 0], original_size, tile_shape)
1185
1186        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
1187        pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder")
1188
1189        foreground = np.zeros(original_size, dtype="float32")
1190        center_distances = np.zeros(original_size, dtype="float32")
1191        boundary_distances = np.zeros(original_size, dtype="float32")
1192
1193        for tile_id in range(tiling.numberOfBlocks):
1194
1195            # Get the image embeddings from the predictor for this tile.
1196            self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id)
1197            embeddings = self._predictor.features
1198            input_shape = tuple(self._predictor.input_size)
1199            original_shape = tuple(self._predictor.original_size)
1200
1201            # Predict with the UNETR decoder for this tile.
1202            output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
1203            assert output.shape[0] == 3, f"{output.shape}"
1204
1205            # Set the predictions in the output for this tile.
1206            block = tiling.getBlockWithHalo(tile_id, halo=list(halo))
1207            local_bb = tuple(
1208                slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end)
1209            )
1210            inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end))
1211
1212            foreground[inner_bb] = output[0][local_bb]
1213            center_distances[inner_bb] = output[1][local_bb]
1214            boundary_distances[inner_bb] = output[2][local_bb]
1215            pbar_update(1)
1216
1217        pbar_close()
1218
1219        # Set the state.
1220        self._foreground = foreground
1221        self._center_distances = center_distances
1222        self._boundary_distances = boundary_distances
1223        self._is_initialized = True

Initialize image embeddings and decoder predictions for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • tile_shape: Shape of the tiles for precomputing image embeddings.
  • halo: Overlap of the tiles for tiled precomputation of image embeddings.
  • verbose: Dummy input to be compatible with other function signatures.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
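
A minimal sketch for large images, using tiled embeddings; the model type, checkpoint path, image file, tile shape and halo are placeholders.

import imageio.v3 as imageio

from micro_sam.instance_segmentation import (
    TiledInstanceSegmentationWithDecoder,
    get_predictor_and_decoder,
)

large_image = imageio.imread("large_tissue_section.tif")  # placeholder image
predictor, decoder = get_predictor_and_decoder(
    model_type="vit_b_lm", checkpoint_path="finetuned_model.pt",  # placeholders
)
segmenter = TiledInstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(large_image, tile_shape=(1024, 1024), halo=(256, 256), verbose=True)
masks = segmenter.generate()
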
def get_amg( predictor: segment_anything.predictor.SamPredictor, is_tiled: bool, decoder: Optional[torch.nn.modules.module.Module] = None, **kwargs) -> Union[AMGBase, InstanceSegmentationWithDecoder]:
1226def get_amg(
1227    predictor: SamPredictor, is_tiled: bool, decoder: Optional[torch.nn.Module] = None, **kwargs,
1228) -> Union[AMGBase, InstanceSegmentationWithDecoder]:
1229    """Get the automatic mask generator class.
1230
1231    Args:
1232        predictor: The segment anything predictor.
1233        is_tiled: Whether tiled embeddings are used.
1234        decoder: Decoder to predict instance segmentation.
1235        kwargs: The keyword arguments for the amg class.
1236
1237    Returns:
1238        The automatic mask generator.
1239    """
1240    if decoder is None:
1241        segmenter_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator
1242        segmenter = segmenter_class(predictor, **kwargs)
1243    else:
1244        segmenter_class = TiledInstanceSegmentationWithDecoder if is_tiled else InstanceSegmentationWithDecoder
1245        segmenter = segmenter_class(predictor, decoder, **kwargs)
1246
1247    return segmenter

Get the automatic mask generator class.

Arguments:
  • predictor: The segment anything predictor.
  • is_tiled: Whether tiled embeddings are used.
  • decoder: Decoder to predict instance segmentation.
  • kwargs: The keyword arguments for the amg class.
Returns:

The automatic mask generator.
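
A short sketch of how get_amg selects the segmenter class; the model type and checkpoint path are placeholders, and the decoder is only passed when the loaded state contains one.

from micro_sam import util
from micro_sam.instance_segmentation import get_amg, get_decoder

predictor, state = util.get_sam_model(
    model_type="vit_b_lm", checkpoint_path="finetuned_model.pt", return_state=True,  # placeholders
)
decoder = None
if "decoder_state" in state:
    decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"])

# Without a decoder this returns (Tiled)AutomaticMaskGenerator,
# with a decoder it returns (Tiled)InstanceSegmentationWithDecoder.
segmenter = get_amg(predictor, is_tiled=False, decoder=decoder)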