micro_sam.instance_segmentation

Automated instance segmentation functionality. The classes implemented here extend the automatic instance segmentation from Segment Anything: https://computational-cell-analytics.github.io/micro-sam/micro_sam.html

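For orientation, here is a minimal, hedged usage sketch of the decoder-based workflow (the model type, checkpoint path, and image path are placeholders for illustration; it assumes a checkpoint that also stores a decoder state):

```python
import imageio.v3 as imageio

from micro_sam.instance_segmentation import (
    InstanceSegmentationWithDecoder, get_predictor_and_decoder, mask_data_to_segmentation
)

# Placeholder inputs: any 2D image and a checkpoint that contains a decoder state.
image = imageio.imread("cells.tif")
predictor, decoder = get_predictor_and_decoder(model_type="vit_b_lm", checkpoint_path="best.pt")

segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image)  # Predict the image embeddings and decoder outputs (the expensive step).
masks = segmenter.generate(center_distance_threshold=0.5)  # Post-process the predictions into masks.
segmentation = mask_data_to_segmentation(masks, with_background=True, min_object_size=50)
```
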
   1"""Automated instance segmentation functionality.
   2The classes implemented here extend the automatic instance segmentation from Segment Anything:
   3https://computational-cell-analytics.github.io/micro-sam/micro_sam.html
   4"""
   5
   6import os
   7import warnings
   8from abc import ABC
   9from copy import deepcopy
  10from collections import OrderedDict
  11from typing import Any, Dict, List, Optional, Tuple, Union
  12
  13import vigra
  14import numpy as np
  15from skimage.measure import label, regionprops
  16from skimage.segmentation import relabel_sequential
  17
  18import torch
  19from torchvision.ops.boxes import batched_nms, box_area
  20
  21from torch_em.model import UNETR
  22from torch_em.util.segmentation import watershed_from_center_and_boundary_distances
  23
  24import elf.parallel as parallel
  25from elf.parallel.filters import apply_filter
  26
  27from nifty.tools import blocking
  28
  29import segment_anything.utils.amg as amg_utils
  30from segment_anything.predictor import SamPredictor
  31
  32from . import util
  33from ._vendored import batched_mask_to_box, mask_to_rle_pytorch
  34
  35#
  36# Utility Functionality
  37#
  38
  39
  40class _FakeInput:
  41    def __init__(self, shape):
  42        self.shape = shape
  43
  44    def __getitem__(self, index):
  45        block_shape = tuple(ind.stop - ind.start for ind in index)
  46        return np.zeros(block_shape, dtype="float32")
  47
  48
  49def mask_data_to_segmentation(
  50    masks: List[Dict[str, Any]],
  51    with_background: bool,
  52    min_object_size: int = 0,
  53    max_object_size: Optional[int] = None,
  54    label_masks: bool = True,
  55) -> np.ndarray:
  56    """Convert the output of the automatic mask generation to an instance segmentation.
  57
  58    Args:
  59        masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator.
  60            Only supports output_mode=binary_mask.
  61        with_background: Whether the segmentation has background. If yes, this function ensures that the largest
  62            object in the output will be mapped to zero (the background value).
  63        min_object_size: The minimal size of an object in pixels.
  64        max_object_size: The maximal size of an object in pixels.
  65        label_masks: Whether to apply connected components to the result before removing small objects.
  66
  67    Returns:
  68        The instance segmentation.
  69    """
  70
  71    masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
  72    # we could also get the shape from the crop box
  73    shape = next(iter(masks))["segmentation"].shape
  74    segmentation = np.zeros(shape, dtype="uint32")
  75
  76    def require_numpy(mask):
  77        return mask.cpu().numpy() if torch.is_tensor(mask) else mask
  78
  79    seg_id = 1
  80    for mask in masks:
  81        if mask["area"] < min_object_size:
  82            continue
  83        if max_object_size is not None and mask["area"] > max_object_size:
  84            continue
  85
  86        this_seg_id = mask.get("seg_id", seg_id)
  87        segmentation[require_numpy(mask["segmentation"])] = this_seg_id
  88        seg_id = this_seg_id + 1
  89
  90    if label_masks:
  91        segmentation = label(segmentation).astype(segmentation.dtype)
  92
  93    seg_ids, sizes = np.unique(segmentation, return_counts=True)
  94
  95    # In some cases objects may be smaller than previously calculated,
  96    # since they are covered by other objects. We ensure these also get
  97    # filtered out here.
  98    filter_ids = seg_ids[sizes < min_object_size]
  99
 100    # If we run segmentation with background we also map the largest segment
 101    # (the most likely background object) to zero. This is often zero already,
 102    # but it does not hurt to reset that to zero either.
 103    if with_background:
 104        bg_id = seg_ids[np.argmax(sizes)]
 105        filter_ids = np.concatenate([filter_ids, [bg_id]])
 106
 107    segmentation[np.isin(segmentation, filter_ids)] = 0
 108    segmentation = relabel_sequential(segmentation)[0]
 109
 110    return segmentation
 111
 112
 113#
 114# Classes for automatic instance segmentation
 115#
 116
 117
 118class AMGBase(ABC):
 119    """Base class for the automatic mask generators.
 120    """
 121    def __init__(self):
 122        # the state that has to be computed by the 'initialize' method of the child classes
 123        self._is_initialized = False
 124        self._crop_list = None
 125        self._crop_boxes = None
 126        self._original_size = None
 127
 128    @property
 129    def is_initialized(self):
 130        """Whether the mask generator has already been initialized.
 131        """
 132        return self._is_initialized
 133
 134    @property
 135    def crop_list(self):
 136        """The list of mask data after initialization.
 137        """
 138        return self._crop_list
 139
 140    @property
 141    def crop_boxes(self):
 142        """The list of crop boxes.
 143        """
 144        return self._crop_boxes
 145
 146    @property
 147    def original_size(self):
 148        """The original image size.
 149        """
 150        return self._original_size
 151
 152    def _postprocess_batch(
 153        self,
 154        data,
 155        crop_box,
 156        original_size,
 157        pred_iou_thresh,
 158        stability_score_thresh,
 159        box_nms_thresh,
 160    ):
 161        orig_h, orig_w = original_size
 162
 163        # filter by predicted IoU
 164        if pred_iou_thresh > 0.0:
 165            keep_mask = data["iou_preds"] > pred_iou_thresh
 166            data.filter(keep_mask)
 167
 168        # filter by stability score
 169        if stability_score_thresh > 0.0:
 170            keep_mask = data["stability_score"] >= stability_score_thresh
 171            data.filter(keep_mask)
 172
 173        # filter boxes that touch crop boundaries
 174        keep_mask = ~amg_utils.is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
 175        if not torch.all(keep_mask):
 176            data.filter(keep_mask)
 177
 178        # remove duplicates within this crop.
 179        keep_by_nms = batched_nms(
 180            data["boxes"].float(),
 181            data["iou_preds"],
 182            torch.zeros_like(data["boxes"][:, 0]),  # categories
 183            iou_threshold=box_nms_thresh,
 184        )
 185        data.filter(keep_by_nms)
 186
 187        # return to the original image frame
 188        data["boxes"] = amg_utils.uncrop_boxes_xyxy(data["boxes"], crop_box)
 189        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
 190        # the data from embedding based segmentation doesn't have the points
 191        # so we skip if the corresponding key can't be found
 192        try:
 193            data["points"] = amg_utils.uncrop_points(data["points"], crop_box)
 194        except KeyError:
 195            pass
 196
 197        return data
 198
 199    def _postprocess_small_regions(self, mask_data, min_area, nms_thresh):
 200
 201        if len(mask_data["rles"]) == 0:
 202            return mask_data
 203
 204        # filter small disconnected regions and holes
 205        new_masks = []
 206        scores = []
 207        for rle in mask_data["rles"]:
 208            mask = amg_utils.rle_to_mask(rle)
 209
 210            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="holes")
 211            unchanged = not changed
 212            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="islands")
 213            unchanged = unchanged and not changed
 214
 215            new_masks.append(torch.as_tensor(mask, dtype=torch.int).unsqueeze(0))
 216            # give score=0 to changed masks and score=1 to unchanged masks
 217            # so NMS will prefer ones that didn't need postprocessing
 218            scores.append(float(unchanged))
 219
 220        # recalculate boxes and remove any new duplicates
 221        masks = torch.cat(new_masks, dim=0)
 222        boxes = batched_mask_to_box(masks.to(torch.bool))  # Casting this to boolean as we work with one-hot labels.
 223        keep_by_nms = batched_nms(
 224            boxes.float(),
 225            torch.as_tensor(scores, dtype=torch.float),
 226            torch.zeros_like(boxes[:, 0]),  # categories
 227            iou_threshold=nms_thresh,
 228        )
 229
 230        # only recalculate RLEs for masks that have changed
 231        for i_mask in keep_by_nms:
 232            if scores[i_mask] == 0.0:
 233                mask_torch = masks[i_mask].unsqueeze(0)
 234                # mask_data["rles"][i_mask] = amg_utils.mask_to_rle_pytorch(mask_torch)[0]
 235                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
 236                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
 237        mask_data.filter(keep_by_nms)
 238
 239        return mask_data
 240
 241    def _postprocess_masks(self, mask_data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode):
 242        # filter small disconnected regions and holes in masks
 243        if min_mask_region_area > 0:
 244            mask_data = self._postprocess_small_regions(
 245                mask_data,
 246                min_mask_region_area,
 247                max(box_nms_thresh, crop_nms_thresh),
 248            )
 249
 250        # encode masks
 251        if output_mode == "coco_rle":
 252            mask_data["segmentations"] = [amg_utils.coco_encode_rle(rle) for rle in mask_data["rles"]]
 253        elif output_mode == "binary_mask":
 254            mask_data["segmentations"] = [amg_utils.rle_to_mask(rle) for rle in mask_data["rles"]]
 255        else:
 256            mask_data["segmentations"] = mask_data["rles"]
 257
 258        # write mask records
 259        curr_anns = []
 260        for idx in range(len(mask_data["segmentations"])):
 261            ann = {
 262                "segmentation": mask_data["segmentations"][idx],
 263                "area": amg_utils.area_from_rle(mask_data["rles"][idx]),
 264                "bbox": amg_utils.box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
 265                "predicted_iou": mask_data["iou_preds"][idx].item(),
 266                "stability_score": mask_data["stability_score"][idx].item(),
 267                "crop_box": amg_utils.box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
 268            }
 269            # the data from embedding based segmentation doesn't have the points
 270            # so we skip if the corresponding key can't be found
 271            try:
 272                ann["point_coords"] = [mask_data["points"][idx].tolist()]
 273            except KeyError:
 274                pass
 275            curr_anns.append(ann)
 276
 277        return curr_anns
 278
 279    def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None):
 280        orig_h, orig_w = original_size
 281
 282        # serialize predictions and store in MaskData
 283        data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1))
 284        if points is not None:
 285            data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0), dtype=torch.float)
 286
 287        del masks
 288
 289        # calculate the stability scores
 290        data["stability_score"] = amg_utils.calculate_stability_score(
 291            data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset
 292        )
 293
 294        # threshold masks and calculate boxes
 295        data["masks"] = data["masks"] > self._predictor.model.mask_threshold
 296        data["masks"] = data["masks"].type(torch.bool)
 297        data["boxes"] = batched_mask_to_box(data["masks"])
 298
 299        # compress to RLE
 300        data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
 301        # data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"])
 302        data["rles"] = mask_to_rle_pytorch(data["masks"])
 303        del data["masks"]
 304
 305        return data
 306
 307    def get_state(self) -> Dict[str, Any]:
 308        """Get the initialized state of the mask generator.
 309
 310        Returns:
 311            State of the mask generator.
 312        """
 313        if not self.is_initialized:
 314            raise RuntimeError("The state has not been computed yet. Call initialize first.")
 315
 316        return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}
 317
 318    def set_state(self, state: Dict[str, Any]) -> None:
 319        """Set the state of the mask generator.
 320
 321        Args:
 322            state: The state of the mask generator, e.g. from serialized state.
 323        """
 324        self._crop_list = state["crop_list"]
 325        self._crop_boxes = state["crop_boxes"]
 326        self._original_size = state["original_size"]
 327        self._is_initialized = True
 328
 329    def clear_state(self):
 330        """Clear the state of the mask generator.
 331        """
 332        self._crop_list = None
 333        self._crop_boxes = None
 334        self._original_size = None
 335        self._is_initialized = False
 336
 337
 338class AutomaticMaskGenerator(AMGBase):
 339    """Generates an instance segmentation without prompts, using a point grid.
 340
 341    This class implements the same logic as
 342    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
 343    It decouples the computationally expensive mask generation from the cheap post-processing that filters the masks,
 344    which enables grid search over and interactive tuning of the post-processing parameters.
 345
 346    Use this class as follows:
 347    ```python
 348    amg = AutomaticMaskGenerator(predictor)
 349    amg.initialize(image)  # Initialize the masks, this takes care of all expensive computations.
 350    masks = amg.generate(pred_iou_thresh=0.8)  # Generate the masks. This is fast and enables testing parameters
 351    ```
 352
 353    Args:
 354        predictor: The segment anything predictor.
 355        points_per_side: The number of points to be sampled along one side of the image.
 356            If None, `point_grids` must provide explicit point sampling.
 357        points_per_batch: The number of points run simultaneously by the model.
 358            Higher numbers may be faster but use more GPU memory.
 359        crop_n_layers: If >0, the mask prediction will be run again on crops of the image.
 360        crop_overlap_ratio: Sets the degree to which crops overlap.
 361        crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops.
 362        point_grids: A list of explicit grids of points used for sampling masks.
 363            Normalized to [0, 1] with respect to the image coordinate system.
 364        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
 365    """
 366    def __init__(
 367        self,
 368        predictor: SamPredictor,
 369        points_per_side: Optional[int] = 32,
 370        points_per_batch: Optional[int] = None,
 371        crop_n_layers: int = 0,
 372        crop_overlap_ratio: float = 512 / 1500,
 373        crop_n_points_downscale_factor: int = 1,
 374        point_grids: Optional[List[np.ndarray]] = None,
 375        stability_score_offset: float = 1.0,
 376    ):
 377        super().__init__()
 378
 379        if points_per_side is not None:
 380            self.point_grids = amg_utils.build_all_layer_point_grids(
 381                points_per_side, crop_n_layers, crop_n_points_downscale_factor,
 382            )
 383        elif point_grids is not None:
 384            self.point_grids = point_grids
 385        else:
 386            raise ValueError("Either points_per_side or point_grids must be provided.")
 387
 388        self._predictor = predictor
 389        self._points_per_side = points_per_side
 390
 391        # we set the points per batch to 16 for mps for performance reasons
 392        # and otherwise keep them at the default of 64
 393        if points_per_batch is None:
 394            points_per_batch = 16 if str(predictor.device) == "mps" else 64
 395        self._points_per_batch = points_per_batch
 396
 397        self._crop_n_layers = crop_n_layers
 398        self._crop_overlap_ratio = crop_overlap_ratio
 399        self._crop_n_points_downscale_factor = crop_n_points_downscale_factor
 400        self._stability_score_offset = stability_score_offset
 401
 402    def _process_batch(self, points, im_size, crop_box, original_size):
 403        # run model on this batch
 404        transformed_points = self._predictor.transform.apply_coords(points, im_size)
 405        in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float)
 406        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
 407        masks, iou_preds, _ = self._predictor.predict_torch(
 408            point_coords=in_points[:, None, :],
 409            point_labels=in_labels[:, None],
 410            multimask_output=True,
 411            return_logits=True,
 412        )
 413        data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points)
 414        del masks
 415        return data
 416
 417    def _process_crop(self, image, crop_box, crop_layer_idx, precomputed_embeddings, pbar_init=None, pbar_update=None):
 418        # Crop the image and calculate embeddings.
 419        x0, y0, x1, y1 = crop_box
 420        cropped_im = image[y0:y1, x0:x1, :]
 421        cropped_im_size = cropped_im.shape[:2]
 422
 423        if not precomputed_embeddings:
 424            self._predictor.set_image(cropped_im)
 425
 426        # Get the points for this crop.
 427        points_scale = np.array(cropped_im_size)[None, ::-1]
 428        points_for_image = self.point_grids[crop_layer_idx] * points_scale
 429
 430        # Generate masks for this crop in batches.
 431        data = amg_utils.MaskData()
 432        n_batches = len(points_for_image) // self._points_per_batch +\
 433            int(len(points_for_image) % self._points_per_batch != 0)
 434        if pbar_init is not None:
 435            pbar_init(n_batches, "Predict masks for point grid prompts")
 436
 437        for (points,) in amg_utils.batch_iterator(self._points_per_batch, points_for_image):
 438            batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size)
 439            data.cat(batch_data)
 440            del batch_data
 441            if pbar_update is not None:
 442                pbar_update(1)
 443
 444        if not precomputed_embeddings:
 445            self._predictor.reset_image()
 446
 447        return data
 448
 449    @torch.no_grad()
 450    def initialize(
 451        self,
 452        image: np.ndarray,
 453        image_embeddings: Optional[util.ImageEmbeddings] = None,
 454        i: Optional[int] = None,
 455        verbose: bool = False,
 456        pbar_init: Optional[callable] = None,
 457        pbar_update: Optional[callable] = None,
 458    ) -> None:
 459        """Initialize image embeddings and masks for an image.
 460
 461        Args:
 462            image: The input image, volume or timeseries.
 463            image_embeddings: Optional precomputed image embeddings.
 464                See `util.precompute_image_embeddings` for details.
 465            i: Index for the image data. Required if `image` has three spatial dimensions
 466                or a time dimension and two spatial dimensions.
 467            verbose: Whether to print computation progress.
 468            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 469                Can be used together with pbar_update to handle the napari progress bar in another thread.
 470                This enables using this function within a threadworker.
 471            pbar_update: Callback to update an external progress bar.
 472        """
 473        original_size = image.shape[:2]
 474        self._original_size = original_size
 475
 476        crop_boxes, layer_idxs = amg_utils.generate_crop_boxes(
 477            original_size, self._crop_n_layers, self._crop_overlap_ratio
 478        )
 479
 480        # We can set fixed image embeddings if we only have a single crop box (the default setting).
 481        # Otherwise we have to recompute the embeddings for each crop and can't precompute.
 482        if len(crop_boxes) == 1:
 483            if image_embeddings is None:
 484                image_embeddings = util.precompute_image_embeddings(self._predictor, image)
 485            util.set_precomputed(self._predictor, image_embeddings, i=i)
 486            precomputed_embeddings = True
 487        else:
 488            precomputed_embeddings = False
 489
 490        # we need to cast to the image representation that is compatible with SAM
 491        image = util._to_image(image)
 492
 493        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
 494
 495        crop_list = []
 496        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
 497            crop_data = self._process_crop(
 498                image, crop_box, layer_idx,
 499                precomputed_embeddings=precomputed_embeddings,
 500                pbar_init=pbar_init, pbar_update=pbar_update,
 501            )
 502            crop_list.append(crop_data)
 503        pbar_close()
 504
 505        self._is_initialized = True
 506        self._crop_list = crop_list
 507        self._crop_boxes = crop_boxes
 508
 509    @torch.no_grad()
 510    def generate(
 511        self,
 512        pred_iou_thresh: float = 0.88,
 513        stability_score_thresh: float = 0.95,
 514        box_nms_thresh: float = 0.7,
 515        crop_nms_thresh: float = 0.7,
 516        min_mask_region_area: int = 0,
 517        output_mode: str = "binary_mask",
 518    ) -> List[Dict[str, Any]]:
 519        """Generate instance segmentation for the currently initialized image.
 520
 521        Args:
 522            pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
 523            stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
 524                under changes to the cutoff used to binarize the model prediction.
 525            box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
 526            crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
 527            min_mask_region_area: Minimal size for the predicted masks.
 528            output_mode: The form masks are returned in.
 529
 530        Returns:
 531            The instance segmentation masks.
 532        """
 533        if not self.is_initialized:
 534            raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.")
 535
 536        data = amg_utils.MaskData()
 537        for data_, crop_box in zip(self.crop_list, self.crop_boxes):
 538            crop_data = self._postprocess_batch(
 539                data=deepcopy(data_),
 540                crop_box=crop_box, original_size=self.original_size,
 541                pred_iou_thresh=pred_iou_thresh,
 542                stability_score_thresh=stability_score_thresh,
 543                box_nms_thresh=box_nms_thresh
 544            )
 545            data.cat(crop_data)
 546
 547        if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0:
 548            # Prefer masks from smaller crops
 549            scores = 1 / box_area(data["crop_boxes"])
 550            scores = scores.to(data["boxes"].device)
 551            keep_by_nms = batched_nms(
 552                data["boxes"].float(),
 553                scores,
 554                torch.zeros_like(data["boxes"][:, 0]),  # categories
 555                iou_threshold=crop_nms_thresh,
 556            )
 557            data.filter(keep_by_nms)
 558
 559        data.to_numpy()
 560        masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode)
 561        return masks
 562
 563
 564# Helper function for tiled embedding computation and checking consistent state.
 565def _process_tiled_embeddings(predictor, image, image_embeddings, tile_shape, halo, verbose, batch_size):
 566    if image_embeddings is None:
 567        if tile_shape is None or halo is None:
 568            raise ValueError("To compute tiled embeddings the parameters tile_shape and halo have to be passed.")
 569        image_embeddings = util.precompute_image_embeddings(
 570            predictor, image, tile_shape=tile_shape, halo=halo, verbose=verbose, batch_size=batch_size
 571        )
 572
 573    # Use tile shape and halo from the precomputed embeddings if not given.
 574    # Otherwise check that they are consistent.
 575    feats = image_embeddings["features"]
 576    tile_shape_, halo_ = tuple(feats.attrs["tile_shape"]), tuple(feats.attrs["halo"])
 577    if tile_shape is None:
 578        tile_shape = tile_shape_
 579    elif tile_shape != tile_shape_:
 580        raise ValueError(
 581            f"Inconsistent tile_shape parameter {tile_shape} with precomputed embeedings: {tile_shape_}."
 582        )
 583    if halo is None:
 584        halo = halo_
 585    elif halo != halo_:
 586        raise ValueError(f"Inconsistent halo parameter {halo} with precomputed embeedings: {halo_}.")
 587
 588    return image_embeddings, tile_shape, halo
 589
 590
 591class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
 592    """Generates an instance segmentation without prompts, using a point grid.
 593
 594    Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings.
 595
 596    Args:
 597        predictor: The segment anything predictor.
 598        points_per_side: The number of points to be sampled along one side of the image.
 599            If None, `point_grids` must provide explicit point sampling.
 600        points_per_batch: The number of points run simultaneously by the model.
 601            Higher numbers may be faster but use more GPU memory.
 602        point_grids: A list of explicit grids of points used for sampling masks.
 603            Normalized to [0, 1] with respect to the image coordinate system.
 604        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
 605    """
 606
 607    # We only expose the arguments that make sense for the tiled mask generator.
 608    # Anything related to crops doesn't make sense, because we re-use that functionality
 609    # for tiling, so these parameters wouldn't have any effect.
 610    def __init__(
 611        self,
 612        predictor: SamPredictor,
 613        points_per_side: Optional[int] = 32,
 614        points_per_batch: int = 64,
 615        point_grids: Optional[List[np.ndarray]] = None,
 616        stability_score_offset: float = 1.0,
 617    ) -> None:
 618        super().__init__(
 619            predictor=predictor,
 620            points_per_side=points_per_side,
 621            points_per_batch=points_per_batch,
 622            point_grids=point_grids,
 623            stability_score_offset=stability_score_offset,
 624        )
 625
 626    @torch.no_grad()
 627    def initialize(
 628        self,
 629        image: np.ndarray,
 630        image_embeddings: Optional[util.ImageEmbeddings] = None,
 631        i: Optional[int] = None,
 632        tile_shape: Optional[Tuple[int, int]] = None,
 633        halo: Optional[Tuple[int, int]] = None,
 634        verbose: bool = False,
 635        pbar_init: Optional[callable] = None,
 636        pbar_update: Optional[callable] = None,
 637        batch_size: int = 1,
 638    ) -> None:
 639        """Initialize image embeddings and masks for an image.
 640
 641        Args:
 642            image: The input image, volume or timeseries.
 643            image_embeddings: Optional precomputed image embeddings.
 644                See `util.precompute_image_embeddings` for details.
 645            i: Index for the image data. Required if `image` has three spatial dimensions
 646                or a time dimension and two spatial dimensions.
 647            tile_shape: The tile shape for embedding prediction.
 648            halo: The overlap between tiles.
 649            verbose: Whether to print computation progress.
 650            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 651                Can be used together with pbar_update to handle the napari progress bar in another thread.
 652                This enables using this function within a threadworker.
 653            pbar_update: Callback to update an external progress bar.
 654            batch_size: The batch size for image embedding prediction.
 655        """
 656        original_size = image.shape[:2]
 657        self._original_size = original_size
 658
 659        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
 660            self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, batch_size=batch_size
 661        )
 662
 663        tiling = blocking([0, 0], original_size, tile_shape)
 664        n_tiles = tiling.numberOfBlocks
 665
 666        # The crop box is always the full local tile.
 667        tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)]
 668        crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles]
 669
 670        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
 671        pbar_init(n_tiles, "Compute masks for tile")
 672
 673        # We need to cast to the image representation that is compatible with SAM.
 674        image = util._to_image(image)
 675
 676        mask_data = []
 677        for tile_id in range(n_tiles):
 678            # set the pre-computed embeddings for this tile
 679            features = image_embeddings["features"][tile_id]
 680            tile_embeddings = {
 681                "features": features,
 682                "input_size": features.attrs["input_size"],
 683                "original_size": features.attrs["original_size"],
 684            }
 685            util.set_precomputed(self._predictor, tile_embeddings, i)
 686
 687            # compute the mask data for this tile and append it
 688            this_mask_data = self._process_crop(
 689                image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True
 690            )
 691            mask_data.append(this_mask_data)
 692            pbar_update(1)
 693        pbar_close()
 694
 695        # set the initialized data
 696        self._is_initialized = True
 697        self._crop_list = mask_data
 698        self._crop_boxes = crop_boxes
 699
 700
 701#
 702# Instance segmentation functionality based on fine-tuned decoder
 703#
 704
 705
 706class DecoderAdapter(torch.nn.Module):
 707    """Adapter to contain the UNETR decoder in a single module.
 708
 709    It is used to apply the decoder on top of pre-computed embeddings for the segmentation functionality.
 710    See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py
 711    """
 712    def __init__(self, unetr):
 713        super().__init__()
 714
 715        self.base = unetr.base
 716        self.out_conv = unetr.out_conv
 717        self.deconv_out = unetr.deconv_out
 718        self.decoder_head = unetr.decoder_head
 719        self.final_activation = unetr.final_activation
 720        self.postprocess_masks = unetr.postprocess_masks
 721
 722        self.decoder = unetr.decoder
 723        self.deconv1 = unetr.deconv1
 724        self.deconv2 = unetr.deconv2
 725        self.deconv3 = unetr.deconv3
 726        self.deconv4 = unetr.deconv4
 727
 728    def _forward_impl(self, input_):
 729        z12 = input_
 730
 731        z9 = self.deconv1(z12)
 732        z6 = self.deconv2(z9)
 733        z3 = self.deconv3(z6)
 734        z0 = self.deconv4(z3)
 735
 736        updated_from_encoder = [z9, z6, z3]
 737
 738        x = self.base(z12)
 739        x = self.decoder(x, encoder_inputs=updated_from_encoder)
 740        x = self.deconv_out(x)
 741
 742        x = torch.cat([x, z0], dim=1)
 743        x = self.decoder_head(x)
 744
 745        x = self.out_conv(x)
 746        if self.final_activation is not None:
 747            x = self.final_activation(x)
 748        return x
 749
 750    def forward(self, input_, input_shape, original_shape):
 751        x = self._forward_impl(input_)
 752        x = self.postprocess_masks(x, input_shape, original_shape)
 753        return x
 754
 755
 756def get_unetr(
 757    image_encoder: torch.nn.Module,
 758    decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None,
 759    device: Optional[Union[str, torch.device]] = None,
 760    out_channels: int = 3,
 761    flexible_load_checkpoint: bool = False,
 762) -> torch.nn.Module:
 763    """Get UNETR model for automatic instance segmentation.
 764
 765    Args:
 766        image_encoder: The image encoder of the SAM model.
 767            This is used as encoder by the UNETR too.
 768        decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
 769        device: The device.
 770        out_channels: The number of output channels.
 771        flexible_load_checkpoint: Whether to allow reinitialization of parameters
 772            which could not be found in the provided decoder state.
 773
 774    Returns:
 775        The UNETR model.
 776    """
 777    device = util.get_device(device)
 778
 779    unetr = UNETR(
 780        backbone="sam",
 781        encoder=image_encoder,
 782        out_channels=out_channels,
 783        use_sam_stats=True,
 784        final_activation="Sigmoid",
 785        use_skip_connection=False,
 786        resize_input=True,
 787    )
 788    if decoder_state is not None:
 789        unetr_state_dict = unetr.state_dict()
 790        for k, v in unetr_state_dict.items():
 791            if not k.startswith("encoder"):
 792                if flexible_load_checkpoint:  # Whether to allow reinitialization of params that are not found.
 793                    if k in decoder_state:  # First check whether the key is available in the provided decoder state.
 794                        unetr_state_dict[k] = decoder_state[k]
 795                    else:  # Otherwise, keep the default initialization for this parameter.
 796                        warnings.warn(f"Could not find '{k}' in the pretrained state dict. Hence, we reinitialize it.")
 797                        unetr_state_dict[k] = v
 798
 799                else:  # Otherwise, be strict about finding the parameter in the decoder state.
 800                    if k not in decoder_state:
 801                        raise RuntimeError(f"The parameters for '{k}' could not be found.")
 802                    unetr_state_dict[k] = decoder_state[k]
 803
 804        unetr.load_state_dict(unetr_state_dict)
 805
 806    unetr.to(device)
 807    return unetr
 808
 809
 810def get_decoder(
 811    image_encoder: torch.nn.Module,
 812    decoder_state: OrderedDict[str, torch.Tensor],
 813    device: Optional[Union[str, torch.device]] = None,
 814) -> DecoderAdapter:
 815    """Get decoder to predict outputs for automatic instance segmentation
 816
 817    Args:
 818        image_encoder: The image encoder of the SAM model.
 819        decoder_state: State to initialize the weights of the UNETR decoder.
 820        device: The device.
 821
 822    Returns:
 823        The decoder for instance segmentation.
 824    """
 825    unetr = get_unetr(image_encoder, decoder_state, device)
 826    return DecoderAdapter(unetr)
 827
 828
 829def get_predictor_and_decoder(
 830    model_type: str,
 831    checkpoint_path: Union[str, os.PathLike],
 832    device: Optional[Union[str, torch.device]] = None,
 833    peft_kwargs: Optional[Dict] = None,
 834) -> Tuple[SamPredictor, DecoderAdapter]:
 835    """Load the SAM model (predictor) and instance segmentation decoder.
 836
 837    This requires a checkpoint that contains the state for both predictor
 838    and decoder.
 839
 840    Args:
 841        model_type: The type of the image encoder used in the SAM model.
 842        checkpoint_path: Path to the checkpoint from which to load the data.
 843        device: The device.
 844        peft_kwargs: Keyword arguments for the PEFT wrapper class.
 845
 846    Returns:
 847        The SAM predictor.
 848        The decoder for instance segmentation.
 849    """
 850    device = util.get_device(device)
 851    predictor, state = util.get_sam_model(
 852        model_type=model_type,
 853        checkpoint_path=checkpoint_path,
 854        device=device,
 855        return_state=True,
 856        peft_kwargs=peft_kwargs,
 857    )
 858    if "decoder_state" not in state:
 859        raise ValueError(
 860            f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state"
 861        )
 862    decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device)
 863    return predictor, decoder
 864
 865
 866def _watershed_from_center_and_boundary_distances_parallel(
 867    center_distances,
 868    boundary_distances,
 869    foreground_map,
 870    center_distance_threshold,
 871    boundary_distance_threshold,
 872    foreground_threshold,
 873    distance_smoothing,
 874    min_size,
 875    tile_shape,
 876    halo,
 877    n_threads,
 878    verbose=False,
 879):
 880    center_distances = apply_filter(
 881        center_distances, "gaussianSmoothing", sigma=distance_smoothing,
 882        block_shape=tile_shape, n_threads=n_threads
 883    )
 884    boundary_distances = apply_filter(
 885        boundary_distances, "gaussianSmoothing", sigma=distance_smoothing,
 886        block_shape=tile_shape, n_threads=n_threads
 887    )
 888
 889    fg_mask = foreground_map > foreground_threshold
 890
 891    marker_map = np.logical_and(
 892        center_distances < center_distance_threshold, boundary_distances < boundary_distance_threshold
 893    )
 894    marker_map[~fg_mask] = 0
 895
 896    markers = np.zeros(marker_map.shape, dtype="uint64")
 897    markers = parallel.label(
 898        marker_map, out=markers, block_shape=tile_shape, n_threads=n_threads, verbose=verbose,
 899    )
 900
 901    seg = np.zeros_like(markers, dtype="uint64")
 902    seg = parallel.seeded_watershed(
 903        boundary_distances, seeds=markers, out=seg, block_shape=tile_shape,
 904        halo=halo, n_threads=n_threads, verbose=verbose, mask=fg_mask,
 905    )
 906
 907    out = np.zeros_like(seg, dtype="uint64")
 908    out = parallel.size_filter(
 909        seg, out=out, min_size=min_size, block_shape=tile_shape, n_threads=n_threads, verbose=verbose
 910    )
 911
 912    return out
 913
 914
 915class InstanceSegmentationWithDecoder:
 916    """Generates an instance segmentation without prompts, using a decoder.
 917
 918    Implements the same interface as `AutomaticMaskGenerator`.
 919
 920    Use this class as follows:
 921    ```python
 922    segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
 923    segmenter.initialize(image)   # Predict the image embeddings and decoder outputs.
 924    masks = segmenter.generate(center_distance_threshold=0.75)  # Generate the instance segmentation.
 925    ```
 926
 927    Args:
 928        predictor: The segment anything predictor.
 929        decoder: The decoder to predict intermediate representations
 930            for instance segmentation.
 931    """
 932    def __init__(
 933        self,
 934        predictor: SamPredictor,
 935        decoder: torch.nn.Module,
 936    ) -> None:
 937        self._predictor = predictor
 938        self._decoder = decoder
 939
 940        # The decoder outputs.
 941        self._foreground = None
 942        self._center_distances = None
 943        self._boundary_distances = None
 944
 945        self._is_initialized = False
 946
 947    @property
 948    def is_initialized(self):
 949        """Whether the mask generator has already been initialized.
 950        """
 951        return self._is_initialized
 952
 953    @torch.no_grad()
 954    def initialize(
 955        self,
 956        image: np.ndarray,
 957        image_embeddings: Optional[util.ImageEmbeddings] = None,
 958        i: Optional[int] = None,
 959        verbose: bool = False,
 960        pbar_init: Optional[callable] = None,
 961        pbar_update: Optional[callable] = None,
 962    ) -> None:
 963        """Initialize image embeddings and decoder predictions for an image.
 964
 965        Args:
 966            image: The input image, volume or timeseries.
 967            image_embeddings: Optional precomputed image embeddings.
 968                See `util.precompute_image_embeddings` for details.
 969            i: Index for the image data. Required if `image` has three spatial dimensions
 970                or a time dimension and two spatial dimensions.
 971            verbose: Whether to be verbose.
 972            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 973                Can be used together with pbar_update to handle the napari progress bar in another thread.
 974                This enables using this function within a threadworker.
 975            pbar_update: Callback to update an external progress bar.
 976        """
 977        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
 978        pbar_init(1, "Initialize instance segmentation with decoder")
 979
 980        if image_embeddings is None:
 981            image_embeddings = util.precompute_image_embeddings(self._predictor, image)
 982
 983        # Get the image embeddings from the predictor.
 984        self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i)
 985        embeddings = self._predictor.features
 986        input_shape = tuple(self._predictor.input_size)
 987        original_shape = tuple(self._predictor.original_size)
 988
 989        # Run prediction with the UNETR decoder.
 990        output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
 991        assert output.shape[0] == 3, f"{output.shape}"
 992        pbar_update(1)
 993        pbar_close()
 994
 995        # Set the state.
 996        self._foreground = output[0]
 997        self._center_distances = output[1]
 998        self._boundary_distances = output[2]
 999        self._is_initialized = True
1000
1001    def _to_masks(self, segmentation, output_mode):
1002        if output_mode != "binary_mask":
1003            raise NotImplementedError
1004
1005        props = regionprops(segmentation)
1006        ndim = segmentation.ndim
1007        assert ndim in (2, 3)
1008
1009        shape = segmentation.shape
1010        if ndim == 2:
1011            crop_box = [0, shape[1], 0, shape[0]]
1012        else:
1013            crop_box = [0, shape[2], 0, shape[1], 0, shape[0]]
1014
1015        # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h]
1016        def to_bbox_2d(bbox):
1017            y0, x0 = bbox[0], bbox[1]
1018            w = bbox[3] - x0
1019            h = bbox[2] - y0
1020            return [x0, w, y0, h]
1021
1022        def to_bbox_3d(bbox):
1023            z0, y0, x0 = bbox[0], bbox[1], bbox[2]
1024            w = bbox[5] - x0
1025            h = bbox[4] - y0
1026            d = bbox[3] - z0
1027            return [x0, w, y0, h, z0, d]
1028
1029        to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d
1030        masks = [
1031            {
1032                "segmentation": segmentation == prop.label,
1033                "area": prop.area,
1034                "bbox": to_bbox(prop.bbox),
1035                "crop_box": crop_box,
1036                "seg_id": prop.label,
1037            } for prop in props
1038        ]
1039        return masks
1040
1041    def generate(
1042        self,
1043        center_distance_threshold: float = 0.5,
1044        boundary_distance_threshold: float = 0.5,
1045        foreground_threshold: float = 0.5,
1046        foreground_smoothing: float = 1.0,
1047        distance_smoothing: float = 1.6,
1048        min_size: int = 0,
1049        output_mode: Optional[str] = "binary_mask",
1050        tile_shape: Optional[Tuple[int, int]] = None,
1051        halo: Optional[Tuple[int, int]] = None,
1052        n_threads: Optional[int] = None,
1053    ) -> List[Dict[str, Any]]:
1054        """Generate instance segmentation for the currently initialized image.
1055
1056        Args:
1057            center_distance_threshold: Center distance predictions below this value will be
1058                used to find seeds (intersected with thresholded boundary distance predictions).
1059            boundary_distance_threshold: Boundary distance predictions below this value will be
1060                used to find seeds (intersected with thresholded center distance predictions).
1061            foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid
1062                checkerboard artifacts in the prediction.
1063            foreground_threshold: Foreground predictions above this value will be used as foreground mask.
1064            distance_smoothing: Sigma value for smoothing the distance predictions.
1065            min_size: Minimal object size in the segmentation result.
1066            output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
1067            tile_shape: Tile shape for parallelizing the instance segmentation post-processing.
1068                This parameter is independent of the tile shape for computing the embeddings.
1069                If not given then post-processing will not be parallelized.
1070            halo: Halo for parallel post-processing. See also `tile_shape`.
1071            n_threads: Number of threads for parallel post-processing. See also `tile_shape`.
1072
1073        Returns:
1074            The instance segmentation masks.
1075        """
1076        if not self.is_initialized:
1077            raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.")
1078
1079        if foreground_smoothing > 0:
1080            foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing)
1081        else:
1082            foreground = self._foreground
1083
1084        if tile_shape is None:
1085            segmentation = watershed_from_center_and_boundary_distances(
1086                center_distances=self._center_distances,
1087                boundary_distances=self._boundary_distances,
1088                foreground_map=foreground,
1089                center_distance_threshold=center_distance_threshold,
1090                boundary_distance_threshold=boundary_distance_threshold,
1091                foreground_threshold=foreground_threshold,
1092                distance_smoothing=distance_smoothing,
1093                min_size=min_size,
1094            )
1095        else:
1096            if halo is None:
1097                raise ValueError("You must pass a value for halo if tile_shape is given.")
1098            segmentation = _watershed_from_center_and_boundary_distances_parallel(
1099                center_distances=self._center_distances,
1100                boundary_distances=self._boundary_distances,
1101                foreground_map=foreground,
1102                center_distance_threshold=center_distance_threshold,
1103                boundary_distance_threshold=boundary_distance_threshold,
1104                foreground_threshold=foreground_threshold,
1105                distance_smoothing=distance_smoothing,
1106                min_size=min_size,
1107                tile_shape=tile_shape,
1108                halo=halo,
1109                n_threads=n_threads,
1110                verbose=False,
1111            )
1112
1113        if output_mode is not None:
1114            segmentation = self._to_masks(segmentation, output_mode)
1115        return segmentation
1116
1117    def get_state(self) -> Dict[str, Any]:
1118        """Get the initialized state of the instance segmenter.
1119
1120        Returns:
1121            Instance segmentation state.
1122        """
1123        if not self.is_initialized:
1124            raise RuntimeError("The state has not been computed yet. Call initialize first.")
1125
1126        return {
1127            "foreground": self._foreground,
1128            "center_distances": self._center_distances,
1129            "boundary_distances": self._boundary_distances,
1130        }
1131
1132    def set_state(self, state: Dict[str, Any]) -> None:
1133        """Set the state of the instance segmenter.
1134
1135        Args:
1136            state: The instance segmentation state.
1137        """
1138        self._foreground = state["foreground"]
1139        self._center_distances = state["center_distances"]
1140        self._boundary_distances = state["boundary_distances"]
1141        self._is_initialized = True
1142
1143    def clear_state(self):
1144        """Clear the state of the instance segmenter.
1145        """
1146        self._foreground = None
1147        self._center_distances = None
1148        self._boundary_distances = None
1149        self._is_initialized = False
1150
1151
1152class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder):
1153    """Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings.
1154    """
1155
1156    # Apply the decoder in a batched fashion, and then perform the resizing independently per output.
1157    # This is necessary, because the individual tiles may have different tile shapes due to border tiles.
1158    def _predict_decoder(self, batched_embeddings, input_shapes, original_shapes):
1159        batched_embeddings = torch.cat(batched_embeddings)
1160        output = self._decoder._forward_impl(batched_embeddings)
1161
1162        batched_output = []
1163        for x, input_shape, original_shape in zip(output, input_shapes, original_shapes):
1164            x = self._decoder.postprocess_masks(x.unsqueeze(0), input_shape, original_shape).squeeze(0)
1165            batched_output.append(x.cpu().numpy())
1166        return batched_output
1167
1168    @torch.no_grad()
1169    def initialize(
1170        self,
1171        image: np.ndarray,
1172        image_embeddings: Optional[util.ImageEmbeddings] = None,
1173        i: Optional[int] = None,
1174        tile_shape: Optional[Tuple[int, int]] = None,
1175        halo: Optional[Tuple[int, int]] = None,
1176        verbose: bool = False,
1177        pbar_init: Optional[callable] = None,
1178        pbar_update: Optional[callable] = None,
1179        batch_size: int = 1,
1180    ) -> None:
1181        """Initialize image embeddings and decoder predictions for an image.
1182
1183        Args:
1184            image: The input image, volume or timeseries.
1185            image_embeddings: Optional precomputed image embeddings.
1186                See `util.precompute_image_embeddings` for details.
1187            i: Index for the image data. Required if `image` has three spatial dimensions
1188                or a time dimension and two spatial dimensions.
1189            tile_shape: Shape of the tiles for precomputing image embeddings.
1190            halo: Overlap of the tiles for tiled precomputation of image embeddings.
1191            verbose: Whether to print computation progress.
1192            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
1193                Can be used together with pbar_update to handle the napari progress bar in another thread.
1194                This enables using this function within a threadworker.
1195            pbar_update: Callback to update an external progress bar.
1196            batch_size: The batch size for image embedding computation and segmentation decoder prediction.
1197        """
1198        original_size = image.shape[:2]
1199        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
1200            self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, batch_size=batch_size
1201        )
1202        tiling = blocking([0, 0], original_size, tile_shape)
1203
1204        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
1205        pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder")
1206
1207        foreground = np.zeros(original_size, dtype="float32")
1208        center_distances = np.zeros(original_size, dtype="float32")
1209        boundary_distances = np.zeros(original_size, dtype="float32")
1210
1211        n_tiles = tiling.numberOfBlocks
1212        n_batches = int(np.ceil(n_tiles / batch_size))
1213
1214        for batch_id in range(n_batches):
1215            tile_start = batch_id * batch_size
1216            tile_stop = min(tile_start + batch_size, n_tiles)
1217
1218            batched_embeddings, input_shapes, original_shapes = [], [], []
1219            for tile_id in range(tile_start, tile_stop):
1220                # Get the image embeddings from the predictor for this tile.
1221                self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id)
1222
1223                batched_embeddings.append(self._predictor.features)
1224                input_shapes.append(tuple(self._predictor.input_size))
1225                original_shapes.append(tuple(self._predictor.original_size))
1226
1227            batched_output = self._predict_decoder(batched_embeddings, input_shapes, original_shapes)
1228
1229            for output_id, tile_id in enumerate(range(tile_start, tile_stop)):
1230                output = batched_output[output_id]
1231                assert output.shape[0] == 3
1232
1233                # Set the predictions in the output for this tile.
1234                block = tiling.getBlockWithHalo(tile_id, halo=list(halo))
1235                local_bb = tuple(
1236                    slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end)
1237                )
1238                inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end))
1239
1240                foreground[inner_bb] = output[0][local_bb]
1241                center_distances[inner_bb] = output[1][local_bb]
1242                boundary_distances[inner_bb] = output[2][local_bb]
1243                pbar_update(1)
1244
1245        pbar_close()
1246
1247        # Set the state.
1248        self._foreground = foreground
1249        self._center_distances = center_distances
1250        self._boundary_distances = boundary_distances
1251        self._is_initialized = True
1252
1253
1254def get_amg(
1255    predictor: SamPredictor, is_tiled: bool, decoder: Optional[torch.nn.Module] = None, **kwargs,
1256) -> Union[AMGBase, InstanceSegmentationWithDecoder]:
1257    """Get the automatic mask generator class.
1258
1259    Args:
1260        predictor: The segment anything predictor.
1261        is_tiled: Whether tiled embeddings are used.
1262        decoder: The decoder to predict the instance segmentation.
1263        kwargs: The keyword arguments for the amg class.
1264
1265    Returns:
1266        The automatic mask generator.
1267    """
1268    if decoder is None:
1269        segmenter_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator
1270        segmenter = segmenter_class(predictor, **kwargs)
1271    else:
1272        segmenter_class = TiledInstanceSegmentationWithDecoder if is_tiled else InstanceSegmentationWithDecoder
1273        segmenter = segmenter_class(predictor, decoder, **kwargs)
1274
1275    return segmenter
def mask_data_to_segmentation( masks: List[Dict[str, Any]], with_background: bool, min_object_size: int = 0, max_object_size: Optional[int] = None, label_masks: bool = True) -> numpy.ndarray:
 50def mask_data_to_segmentation(
 51    masks: List[Dict[str, Any]],
 52    with_background: bool,
 53    min_object_size: int = 0,
 54    max_object_size: Optional[int] = None,
 55    label_masks: bool = True,
 56) -> np.ndarray:
 57    """Convert the output of the automatic mask generation to an instance segmentation.
 58
 59    Args:
 60        masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator.
 61            Only supports output_mode=binary_mask.
 62        with_background: Whether the segmentation has background. If yes, this function ensures that the largest
 63            object in the output will be mapped to zero (the background value).
 64        min_object_size: The minimal size of an object in pixels.
 65        max_object_size: The maximal size of an object in pixels.
 66        label_masks: Whether to apply connected components to the result before removing small objects.
 67
 68    Returns:
 69        The instance segmentation.
 70    """
 71
 72    masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
 73    # we could also get the shape from the crop box
 74    shape = next(iter(masks))["segmentation"].shape
 75    segmentation = np.zeros(shape, dtype="uint32")
 76
 77    def require_numpy(mask):
 78        return mask.cpu().numpy() if torch.is_tensor(mask) else mask
 79
 80    seg_id = 1
 81    for mask in masks:
 82        if mask["area"] < min_object_size:
 83            continue
 84        if max_object_size is not None and mask["area"] > max_object_size:
 85            continue
 86
 87        this_seg_id = mask.get("seg_id", seg_id)
 88        segmentation[require_numpy(mask["segmentation"])] = this_seg_id
 89        seg_id = this_seg_id + 1
 90
 91    if label_masks:
 92        segmentation = label(segmentation).astype(segmentation.dtype)
 93
 94    seg_ids, sizes = np.unique(segmentation, return_counts=True)
 95
 96    # In some cases objects may be smaller than previously calculated,
 97    # since they are covered by other objects. We ensure these also get
 98    # filtered out here.
 99    filter_ids = seg_ids[sizes < min_object_size]
100
101    # If we run segmentation with background we also map the largest segment
102    # (the most likely background object) to zero. This is often zero already,
103    # but it does not hurt to reset that to zero either.
104    if with_background:
105        bg_id = seg_ids[np.argmax(sizes)]
106        filter_ids = np.concatenate([filter_ids, [bg_id]])
107
108    segmentation[np.isin(segmentation, filter_ids)] = 0
109    segmentation = relabel_sequential(segmentation)[0]
110
111    return segmentation

Convert the output of the automatic mask generation to an instance segmentation.

Arguments:
  • masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator. Only supports output_mode=binary_mask.
  • with_background: Whether the segmentation has background. If yes, this function ensures that the largest object in the output will be mapped to zero (the background value).
  • min_object_size: The minimal size of an object in pixels.
  • max_object_size: The maximal size of an object in pixels.
  • label_masks: Whether to apply connected components to the result before removing small objects.
Returns:

The instance segmentation.
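
A short sketch of how this function is typically chained after one of the mask generators; it assumes `masks` were produced by a `generate` call with the default `output_mode="binary_mask"`, and the minimal object size of 50 pixels is just an example value.

```python
from micro_sam.instance_segmentation import mask_data_to_segmentation

# 'masks' is the list of mask dictionaries returned by AutomaticMaskGenerator.generate.
segmentation = mask_data_to_segmentation(masks, with_background=True, min_object_size=50)
# 'segmentation' is a label image: background is 0, each object has a unique positive id.
```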

class AMGBase(abc.ABC):
119class AMGBase(ABC):
120    """Base class for the automatic mask generators.
121    """
122    def __init__(self):
123        # the state that has to be computed by the 'initialize' method of the child classes
124        self._is_initialized = False
125        self._crop_list = None
126        self._crop_boxes = None
127        self._original_size = None
128
129    @property
130    def is_initialized(self):
131        """Whether the mask generator has already been initialized.
132        """
133        return self._is_initialized
134
135    @property
136    def crop_list(self):
137        """The list of mask data after initialization.
138        """
139        return self._crop_list
140
141    @property
142    def crop_boxes(self):
143        """The list of crop boxes.
144        """
145        return self._crop_boxes
146
147    @property
148    def original_size(self):
149        """The original image size.
150        """
151        return self._original_size
152
153    def _postprocess_batch(
154        self,
155        data,
156        crop_box,
157        original_size,
158        pred_iou_thresh,
159        stability_score_thresh,
160        box_nms_thresh,
161    ):
162        orig_h, orig_w = original_size
163
164        # filter by predicted IoU
165        if pred_iou_thresh > 0.0:
166            keep_mask = data["iou_preds"] > pred_iou_thresh
167            data.filter(keep_mask)
168
169        # filter by stability score
170        if stability_score_thresh > 0.0:
171            keep_mask = data["stability_score"] >= stability_score_thresh
172            data.filter(keep_mask)
173
174        # filter boxes that touch crop boundaries
175        keep_mask = ~amg_utils.is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
176        if not torch.all(keep_mask):
177            data.filter(keep_mask)
178
179        # remove duplicates within this crop.
180        keep_by_nms = batched_nms(
181            data["boxes"].float(),
182            data["iou_preds"],
183            torch.zeros_like(data["boxes"][:, 0]),  # categories
184            iou_threshold=box_nms_thresh,
185        )
186        data.filter(keep_by_nms)
187
188        # return to the original image frame
189        data["boxes"] = amg_utils.uncrop_boxes_xyxy(data["boxes"], crop_box)
190        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
191        # the data from embedding based segmentation doesn't have the points
192        # so we skip if the corresponding key can't be found
193        try:
194            data["points"] = amg_utils.uncrop_points(data["points"], crop_box)
195        except KeyError:
196            pass
197
198        return data
199
200    def _postprocess_small_regions(self, mask_data, min_area, nms_thresh):
201
202        if len(mask_data["rles"]) == 0:
203            return mask_data
204
205        # filter small disconnected regions and holes
206        new_masks = []
207        scores = []
208        for rle in mask_data["rles"]:
209            mask = amg_utils.rle_to_mask(rle)
210
211            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="holes")
212            unchanged = not changed
213            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="islands")
214            unchanged = unchanged and not changed
215
216            new_masks.append(torch.as_tensor(mask, dtype=torch.int).unsqueeze(0))
217            # give score=0 to changed masks and score=1 to unchanged masks
218            # so NMS will prefer ones that didn't need postprocessing
219            scores.append(float(unchanged))
220
221        # recalculate boxes and remove any new duplicates
222        masks = torch.cat(new_masks, dim=0)
223        boxes = batched_mask_to_box(masks.to(torch.bool))  # Casting this to boolean as we work with one-hot labels.
224        keep_by_nms = batched_nms(
225            boxes.float(),
226            torch.as_tensor(scores, dtype=torch.float),
227            torch.zeros_like(boxes[:, 0]),  # categories
228            iou_threshold=nms_thresh,
229        )
230
231        # only recalculate RLEs for masks that have changed
232        for i_mask in keep_by_nms:
233            if scores[i_mask] == 0.0:
234                mask_torch = masks[i_mask].unsqueeze(0)
235                # mask_data["rles"][i_mask] = amg_utils.mask_to_rle_pytorch(mask_torch)[0]
236                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
237                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
238        mask_data.filter(keep_by_nms)
239
240        return mask_data
241
242    def _postprocess_masks(self, mask_data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode):
243        # filter small disconnected regions and holes in masks
244        if min_mask_region_area > 0:
245            mask_data = self._postprocess_small_regions(
246                mask_data,
247                min_mask_region_area,
248                max(box_nms_thresh, crop_nms_thresh),
249            )
250
251        # encode masks
252        if output_mode == "coco_rle":
253            mask_data["segmentations"] = [amg_utils.coco_encode_rle(rle) for rle in mask_data["rles"]]
254        elif output_mode == "binary_mask":
255            mask_data["segmentations"] = [amg_utils.rle_to_mask(rle) for rle in mask_data["rles"]]
256        else:
257            mask_data["segmentations"] = mask_data["rles"]
258
259        # write mask records
260        curr_anns = []
261        for idx in range(len(mask_data["segmentations"])):
262            ann = {
263                "segmentation": mask_data["segmentations"][idx],
264                "area": amg_utils.area_from_rle(mask_data["rles"][idx]),
265                "bbox": amg_utils.box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
266                "predicted_iou": mask_data["iou_preds"][idx].item(),
267                "stability_score": mask_data["stability_score"][idx].item(),
268                "crop_box": amg_utils.box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
269            }
270            # the data from embedding based segmentation doesn't have the points
271            # so we skip if the corresponding key can't be found
272            try:
273                ann["point_coords"] = [mask_data["points"][idx].tolist()]
274            except KeyError:
275                pass
276            curr_anns.append(ann)
277
278        return curr_anns
279
280    def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None):
281        orig_h, orig_w = original_size
282
283        # serialize predictions and store in MaskData
284        data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1))
285        if points is not None:
286            data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0), dtype=torch.float)
287
288        del masks
289
290        # calculate the stability scores
291        data["stability_score"] = amg_utils.calculate_stability_score(
292            data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset
293        )
294
295        # threshold masks and calculate boxes
296        data["masks"] = data["masks"] > self._predictor.model.mask_threshold
297        data["masks"] = data["masks"].type(torch.bool)
298        data["boxes"] = batched_mask_to_box(data["masks"])
299
300        # compress to RLE
301        data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
302        # data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"])
303        data["rles"] = mask_to_rle_pytorch(data["masks"])
304        del data["masks"]
305
306        return data
307
308    def get_state(self) -> Dict[str, Any]:
309        """Get the initialized state of the mask generator.
310
311        Returns:
312            State of the mask generator.
313        """
314        if not self.is_initialized:
315            raise RuntimeError("The state has not been computed yet. Call initialize first.")
316
317        return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}
318
319    def set_state(self, state: Dict[str, Any]) -> None:
320        """Set the state of the mask generator.
321
322        Args:
323            state: The state of the mask generator, e.g. from serialized state.
324        """
325        self._crop_list = state["crop_list"]
326        self._crop_boxes = state["crop_boxes"]
327        self._original_size = state["original_size"]
328        self._is_initialized = True
329
330    def clear_state(self):
331        """Clear the state of the mask generator.
332        """
333        self._crop_list = None
334        self._crop_boxes = None
335        self._original_size = None
336        self._is_initialized = False

Base class for the automatic mask generators.

is_initialized
129    @property
130    def is_initialized(self):
131        """Whether the mask generator has already been initialized.
132        """
133        return self._is_initialized

Whether the mask generator has already been initialized.

crop_list
135    @property
136    def crop_list(self):
137        """The list of mask data after initialization.
138        """
139        return self._crop_list

The list of mask data after initialization.

crop_boxes
141    @property
142    def crop_boxes(self):
143        """The list of crop boxes.
144        """
145        return self._crop_boxes

The list of crop boxes.

original_size
147    @property
148    def original_size(self):
149        """The original image size.
150        """
151        return self._original_size

The original image size.

def get_state(self) -> Dict[str, Any]:
308    def get_state(self) -> Dict[str, Any]:
309        """Get the initialized state of the mask generator.
310
311        Returns:
312            State of the mask generator.
313        """
314        if not self.is_initialized:
315            raise RuntimeError("The state has not been computed yet. Call initialize first.")
316
317        return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}

Get the initialized state of the mask generator.

Returns:

State of the mask generator.

def set_state(self, state: Dict[str, Any]) -> None:
319    def set_state(self, state: Dict[str, Any]) -> None:
320        """Set the state of the mask generator.
321
322        Args:
323            state: The state of the mask generator, e.g. from serialized state.
324        """
325        self._crop_list = state["crop_list"]
326        self._crop_boxes = state["crop_boxes"]
327        self._original_size = state["original_size"]
328        self._is_initialized = True

Set the state of the mask generator.

Arguments:
  • state: The state of the mask generator, e.g. from serialized state.
def clear_state(self):
330    def clear_state(self):
331        """Clear the state of the mask generator.
332        """
333        self._crop_list = None
334        self._crop_boxes = None
335        self._original_size = None
336        self._is_initialized = False

Clear the state of the mask generator.
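
The state accessors above allow the expensive initialization to be computed once and re-used, for example by transferring it to another generator instance. A sketch, assuming `amg` has already been initialized for an image and `predictor` is the same predictor that was used to create it:

```python
# Export the precomputed state from an initialized generator.
state = amg.get_state()

# Transfer it to a fresh generator instead of re-running initialize.
amg_new = AutomaticMaskGenerator(predictor)
amg_new.set_state(state)
masks = amg_new.generate(pred_iou_thresh=0.88)

# Clear the state before initializing the generator for a different image.
amg_new.clear_state()
```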

class AutomaticMaskGenerator(AMGBase):
339class AutomaticMaskGenerator(AMGBase):
340    """Generates an instance segmentation without prompts, using a point grid.
341
342    This class implements the same logic as
343    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
344    It decouples the computationally expensive steps of generating masks from the cheap post-processing operation
345    to filter these masks to enable grid search and interactively changing the post-processing.
346
347    Use this class as follows:
348    ```python
349    amg = AutomaticMaskGenerator(predictor)
350    amg.initialize(image)  # Initialize the masks, this takes care of all expensive computations.
351    masks = amg.generate(pred_iou_thresh=0.8)  # Generate the masks. This is fast and enables testing parameters
352    ```
353
354    Args:
355        predictor: The segment anything predictor.
356        points_per_side: The number of points to be sampled along one side of the image.
357            If None, `point_grids` must provide explicit point sampling.
358        points_per_batch: The number of points run simultaneously by the model.
359            Higher numbers may be faster but use more GPU memory.
360        crop_n_layers: If >0, the mask prediction will be run again on crops of the image.
361        crop_overlap_ratio: Sets the degree to which crops overlap.
362        crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops.
363        point_grids: A list of explicit grids of points used for sampling masks.
364            Normalized to [0, 1] with respect to the image coordinate system.
365        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
366    """
367    def __init__(
368        self,
369        predictor: SamPredictor,
370        points_per_side: Optional[int] = 32,
371        points_per_batch: Optional[int] = None,
372        crop_n_layers: int = 0,
373        crop_overlap_ratio: float = 512 / 1500,
374        crop_n_points_downscale_factor: int = 1,
375        point_grids: Optional[List[np.ndarray]] = None,
376        stability_score_offset: float = 1.0,
377    ):
378        super().__init__()
379
380        if points_per_side is not None:
381            self.point_grids = amg_utils.build_all_layer_point_grids(
382                points_per_side, crop_n_layers, crop_n_points_downscale_factor,
383            )
384        elif point_grids is not None:
385            self.point_grids = point_grids
386        else:
387            raise ValueError("Can't have both points_per_side and point_grid be None or not None.")
388
389        self._predictor = predictor
390        self._points_per_side = points_per_side
391
392        # we set the points per batch to 16 for mps for performance reasons
393        # and otherwise keep them at the default of 64
394        if points_per_batch is None:
395            points_per_batch = 16 if str(predictor.device) == "mps" else 64
396        self._points_per_batch = points_per_batch
397
398        self._crop_n_layers = crop_n_layers
399        self._crop_overlap_ratio = crop_overlap_ratio
400        self._crop_n_points_downscale_factor = crop_n_points_downscale_factor
401        self._stability_score_offset = stability_score_offset
402
403    def _process_batch(self, points, im_size, crop_box, original_size):
404        # run model on this batch
405        transformed_points = self._predictor.transform.apply_coords(points, im_size)
406        in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float)
407        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
408        masks, iou_preds, _ = self._predictor.predict_torch(
409            point_coords=in_points[:, None, :],
410            point_labels=in_labels[:, None],
411            multimask_output=True,
412            return_logits=True,
413        )
414        data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points)
415        del masks
416        return data
417
418    def _process_crop(self, image, crop_box, crop_layer_idx, precomputed_embeddings, pbar_init=None, pbar_update=None):
419        # Crop the image and calculate embeddings.
420        x0, y0, x1, y1 = crop_box
421        cropped_im = image[y0:y1, x0:x1, :]
422        cropped_im_size = cropped_im.shape[:2]
423
424        if not precomputed_embeddings:
425            self._predictor.set_image(cropped_im)
426
427        # Get the points for this crop.
428        points_scale = np.array(cropped_im_size)[None, ::-1]
429        points_for_image = self.point_grids[crop_layer_idx] * points_scale
430
431        # Generate masks for this crop in batches.
432        data = amg_utils.MaskData()
433        n_batches = len(points_for_image) // self._points_per_batch +\
434            int(len(points_for_image) % self._points_per_batch != 0)
435        if pbar_init is not None:
436            pbar_init(n_batches, "Predict masks for point grid prompts")
437
438        for (points,) in amg_utils.batch_iterator(self._points_per_batch, points_for_image):
439            batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size)
440            data.cat(batch_data)
441            del batch_data
442            if pbar_update is not None:
443                pbar_update(1)
444
445        if not precomputed_embeddings:
446            self._predictor.reset_image()
447
448        return data
449
450    @torch.no_grad()
451    def initialize(
452        self,
453        image: np.ndarray,
454        image_embeddings: Optional[util.ImageEmbeddings] = None,
455        i: Optional[int] = None,
456        verbose: bool = False,
457        pbar_init: Optional[callable] = None,
458        pbar_update: Optional[callable] = None,
459    ) -> None:
460        """Initialize image embeddings and masks for an image.
461
462        Args:
463            image: The input image, volume or timeseries.
464            image_embeddings: Optional precomputed image embeddings.
465                See `util.precompute_image_embeddings` for details.
466            i: Index for the image data. Required if `image` has three spatial dimensions
467                or a time dimension and two spatial dimensions.
468            verbose: Whether to print computation progress.
469            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
470                Can be used together with pbar_update to handle the napari progress bar in another thread.
471                This enables using this function within a threadworker.
472            pbar_update: Callback to update an external progress bar.
473        """
474        original_size = image.shape[:2]
475        self._original_size = original_size
476
477        crop_boxes, layer_idxs = amg_utils.generate_crop_boxes(
478            original_size, self._crop_n_layers, self._crop_overlap_ratio
479        )
480
481        # We can set fixed image embeddings if we only have a single crop box (the default setting).
482        # Otherwise we have to recompute the embeddings for each crop and can't precompute.
483        if len(crop_boxes) == 1:
484            if image_embeddings is None:
485                image_embeddings = util.precompute_image_embeddings(self._predictor, image)
486            util.set_precomputed(self._predictor, image_embeddings, i=i)
487            precomputed_embeddings = True
488        else:
489            precomputed_embeddings = False
490
491        # we need to cast to the image representation that is compatible with SAM
492        image = util._to_image(image)
493
494        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
495
496        crop_list = []
497        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
498            crop_data = self._process_crop(
499                image, crop_box, layer_idx,
500                precomputed_embeddings=precomputed_embeddings,
501                pbar_init=pbar_init, pbar_update=pbar_update,
502            )
503            crop_list.append(crop_data)
504        pbar_close()
505
506        self._is_initialized = True
507        self._crop_list = crop_list
508        self._crop_boxes = crop_boxes
509
510    @torch.no_grad()
511    def generate(
512        self,
513        pred_iou_thresh: float = 0.88,
514        stability_score_thresh: float = 0.95,
515        box_nms_thresh: float = 0.7,
516        crop_nms_thresh: float = 0.7,
517        min_mask_region_area: int = 0,
518        output_mode: str = "binary_mask",
519    ) -> List[Dict[str, Any]]:
520        """Generate instance segmentation for the currently initialized image.
521
522        Args:
523            pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
524            stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
525                under changes to the cutoff used to binarize the model prediction.
526            box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
527            crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
528            min_mask_region_area: Minimal size for the predicted masks.
529            output_mode: The form masks are returned in.
530
531        Returns:
532            The instance segmentation masks.
533        """
534        if not self.is_initialized:
535            raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.")
536
537        data = amg_utils.MaskData()
538        for data_, crop_box in zip(self.crop_list, self.crop_boxes):
539            crop_data = self._postprocess_batch(
540                data=deepcopy(data_),
541                crop_box=crop_box, original_size=self.original_size,
542                pred_iou_thresh=pred_iou_thresh,
543                stability_score_thresh=stability_score_thresh,
544                box_nms_thresh=box_nms_thresh
545            )
546            data.cat(crop_data)
547
548        if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0:
549            # Prefer masks from smaller crops
550            scores = 1 / box_area(data["crop_boxes"])
551            scores = scores.to(data["boxes"].device)
552            keep_by_nms = batched_nms(
553                data["boxes"].float(),
554                scores,
555                torch.zeros_like(data["boxes"][:, 0]),  # categories
556                iou_threshold=crop_nms_thresh,
557            )
558            data.filter(keep_by_nms)
559
560        data.to_numpy()
561        masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode)
562        return masks

Generates an instance segmentation without prompts, using a point grid.

This class implements the same logic as https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py. It decouples the computationally expensive step of generating masks from the cheap post-processing that filters them, which enables grid search and interactive tuning of the post-processing.

Use this class as follows:

amg = AutomaticMaskGenerator(predictor)
amg.initialize(image)  # Initialize the masks, this takes care of all expensive computations.
masks = amg.generate(pred_iou_thresh=0.8)  # Generate the masks. This is fast and enables testing parameters
Arguments:
  • predictor: The segment anything predictor.
  • points_per_side: The number of points to be sampled along one side of the image. If None, point_grids must provide explicit point sampling.
  • points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory.
  • crop_n_layers: If >0, the mask prediction will be run again on crops of the image.
  • crop_overlap_ratio: Sets the degree to which crops overlap.
  • crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops.
  • point_grids: A list of explicit grids of points used for sampling masks. Normalized to [0, 1] with respect to the image coordinate system. See the sketch after this list.
  • stability_score_offset: The amount to shift the cutoff when calculating the stability score.
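
If explicit point grids are preferred over the regular grid, `points_per_side` can be set to `None` and `point_grids` passed instead. A sketch with an arbitrary 16 x 16 grid (`predictor` as in the usage example above); with the default `crop_n_layers=0` a single grid is sufficient:

```python
import numpy as np
from micro_sam.instance_segmentation import AutomaticMaskGenerator

# One explicit grid of point prompts, normalized to [0, 1] (x, y order).
side = np.linspace(0.0, 1.0, 16)
xx, yy = np.meshgrid(side, side)
grid = np.stack([xx.ravel(), yy.ravel()], axis=-1)

amg = AutomaticMaskGenerator(predictor, points_per_side=None, point_grids=[grid])
```
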
AutomaticMaskGenerator( predictor: segment_anything.predictor.SamPredictor, points_per_side: Optional[int] = 32, points_per_batch: Optional[int] = None, crop_n_layers: int = 0, crop_overlap_ratio: float = 0.3413333333333333, crop_n_points_downscale_factor: int = 1, point_grids: Optional[List[numpy.ndarray]] = None, stability_score_offset: float = 1.0)
367    def __init__(
368        self,
369        predictor: SamPredictor,
370        points_per_side: Optional[int] = 32,
371        points_per_batch: Optional[int] = None,
372        crop_n_layers: int = 0,
373        crop_overlap_ratio: float = 512 / 1500,
374        crop_n_points_downscale_factor: int = 1,
375        point_grids: Optional[List[np.ndarray]] = None,
376        stability_score_offset: float = 1.0,
377    ):
378        super().__init__()
379
380        if points_per_side is not None:
381            self.point_grids = amg_utils.build_all_layer_point_grids(
382                points_per_side, crop_n_layers, crop_n_points_downscale_factor,
383            )
384        elif point_grids is not None:
385            self.point_grids = point_grids
386        else:
387            raise ValueError("Can't have both points_per_side and point_grid be None or not None.")
388
389        self._predictor = predictor
390        self._points_per_side = points_per_side
391
392        # we set the points per batch to 16 for mps for performance reasons
393        # and otherwise keep them at the default of 64
394        if points_per_batch is None:
395            points_per_batch = 16 if str(predictor.device) == "mps" else 64
396        self._points_per_batch = points_per_batch
397
398        self._crop_n_layers = crop_n_layers
399        self._crop_overlap_ratio = crop_overlap_ratio
400        self._crop_n_points_downscale_factor = crop_n_points_downscale_factor
401        self._stability_score_offset = stability_score_offset
@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, verbose: bool = False, pbar_init: Optional[<built-in function callable>] = None, pbar_update: Optional[<built-in function callable>] = None) -> None:
450    @torch.no_grad()
451    def initialize(
452        self,
453        image: np.ndarray,
454        image_embeddings: Optional[util.ImageEmbeddings] = None,
455        i: Optional[int] = None,
456        verbose: bool = False,
457        pbar_init: Optional[callable] = None,
458        pbar_update: Optional[callable] = None,
459    ) -> None:
460        """Initialize image embeddings and masks for an image.
461
462        Args:
463            image: The input image, volume or timeseries.
464            image_embeddings: Optional precomputed image embeddings.
465                See `util.precompute_image_embeddings` for details.
466            i: Index for the image data. Required if `image` has three spatial dimensions
467                or a time dimension and two spatial dimensions.
468            verbose: Whether to print computation progress.
469            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
470                Can be used together with pbar_update to handle the napari progress bar in another thread.
471                This enables using this function within a threadworker.
472            pbar_update: Callback to update an external progress bar.
473        """
474        original_size = image.shape[:2]
475        self._original_size = original_size
476
477        crop_boxes, layer_idxs = amg_utils.generate_crop_boxes(
478            original_size, self._crop_n_layers, self._crop_overlap_ratio
479        )
480
481        # We can set fixed image embeddings if we only have a single crop box (the default setting).
482        # Otherwise we have to recompute the embeddings for each crop and can't precompute.
483        if len(crop_boxes) == 1:
484            if image_embeddings is None:
485                image_embeddings = util.precompute_image_embeddings(self._predictor, image)
486            util.set_precomputed(self._predictor, image_embeddings, i=i)
487            precomputed_embeddings = True
488        else:
489            precomputed_embeddings = False
490
491        # we need to cast to the image representation that is compatible with SAM
492        image = util._to_image(image)
493
494        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
495
496        crop_list = []
497        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
498            crop_data = self._process_crop(
499                image, crop_box, layer_idx,
500                precomputed_embeddings=precomputed_embeddings,
501                pbar_init=pbar_init, pbar_update=pbar_update,
502            )
503            crop_list.append(crop_data)
504        pbar_close()
505
506        self._is_initialized = True
507        self._crop_list = crop_list
508        self._crop_boxes = crop_boxes

Initialize image embeddings and masks for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • verbose: Whether to print computation progress.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
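
For volumetric or timeseries data the embeddings are typically precomputed once and the index `i` selects the slice to segment. A sketch with a random placeholder volume, assuming `util.precompute_image_embeddings` is called with its default arguments:

```python
import numpy as np
from micro_sam import util
from micro_sam.instance_segmentation import AutomaticMaskGenerator

volume = np.random.randint(0, 255, size=(8, 512, 512)).astype("uint8")  # placeholder volume / timeseries
embeddings = util.precompute_image_embeddings(predictor, volume)

amg = AutomaticMaskGenerator(predictor)
# Segment the slice at index 2; the precomputed embeddings avoid re-running the image encoder.
amg.initialize(volume[2], image_embeddings=embeddings, i=2)
masks = amg.generate()
```
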
@torch.no_grad()
def generate( self, pred_iou_thresh: float = 0.88, stability_score_thresh: float = 0.95, box_nms_thresh: float = 0.7, crop_nms_thresh: float = 0.7, min_mask_region_area: int = 0, output_mode: str = 'binary_mask') -> List[Dict[str, Any]]:
510    @torch.no_grad()
511    def generate(
512        self,
513        pred_iou_thresh: float = 0.88,
514        stability_score_thresh: float = 0.95,
515        box_nms_thresh: float = 0.7,
516        crop_nms_thresh: float = 0.7,
517        min_mask_region_area: int = 0,
518        output_mode: str = "binary_mask",
519    ) -> List[Dict[str, Any]]:
520        """Generate instance segmentation for the currently initialized image.
521
522        Args:
523            pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
524            stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
525                under changes to the cutoff used to binarize the model prediction.
526            box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
527            crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
528            min_mask_region_area: Minimal size for the predicted masks.
529            output_mode: The form masks are returned in.
530
531        Returns:
532            The instance segmentation masks.
533        """
534        if not self.is_initialized:
535            raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.")
536
537        data = amg_utils.MaskData()
538        for data_, crop_box in zip(self.crop_list, self.crop_boxes):
539            crop_data = self._postprocess_batch(
540                data=deepcopy(data_),
541                crop_box=crop_box, original_size=self.original_size,
542                pred_iou_thresh=pred_iou_thresh,
543                stability_score_thresh=stability_score_thresh,
544                box_nms_thresh=box_nms_thresh
545            )
546            data.cat(crop_data)
547
548        if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0:
549            # Prefer masks from smaller crops
550            scores = 1 / box_area(data["crop_boxes"])
551            scores = scores.to(data["boxes"].device)
552            keep_by_nms = batched_nms(
553                data["boxes"].float(),
554                scores,
555                torch.zeros_like(data["boxes"][:, 0]),  # categories
556                iou_threshold=crop_nms_thresh,
557            )
558            data.filter(keep_by_nms)
559
560        data.to_numpy()
561        masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode)
562        return masks

Generate instance segmentation for the currently initialized image.

Arguments:
  • pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
  • stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask under changes to the cutoff used to binarize the model prediction.
  • box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
  • crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
  • min_mask_region_area: Minimal size for the predicted masks.
  • output_mode: The form masks are returned in.
Returns:

The instance segmentation masks.
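
Since `generate` only runs the cheap post-processing, different filter thresholds can be compared on an initialized generator without recomputing embeddings or masks. A sketch (the threshold values are arbitrary):

```python
# 'amg' has already been initialized for an image.
for pred_iou_thresh in (0.7, 0.8, 0.88):
    for stability_score_thresh in (0.9, 0.95):
        masks = amg.generate(
            pred_iou_thresh=pred_iou_thresh,
            stability_score_thresh=stability_score_thresh,
        )
        print(pred_iou_thresh, stability_score_thresh, "->", len(masks), "masks")
```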

class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
592class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
593    """Generates an instance segmentation without prompts, using a point grid.
594
595    Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings.
596
597    Args:
598        predictor: The segment anything predictor.
599        points_per_side: The number of points to be sampled along one side of the image.
600            If None, `point_grids` must provide explicit point sampling.
601        points_per_batch: The number of points run simultaneously by the model.
602            Higher numbers may be faster but use more GPU memory.
603        point_grids: A list of explicit grids of points used for sampling masks.
604            Normalized to [0, 1] with respect to the image coordinate system.
605        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
606    """
607
608    # We only expose the arguments that make sense for the tiled mask generator.
609    # Anything related to crops doesn't make sense, because we re-use that functionality
610    # for tiling, so these parameters wouldn't have any effect.
611    def __init__(
612        self,
613        predictor: SamPredictor,
614        points_per_side: Optional[int] = 32,
615        points_per_batch: int = 64,
616        point_grids: Optional[List[np.ndarray]] = None,
617        stability_score_offset: float = 1.0,
618    ) -> None:
619        super().__init__(
620            predictor=predictor,
621            points_per_side=points_per_side,
622            points_per_batch=points_per_batch,
623            point_grids=point_grids,
624            stability_score_offset=stability_score_offset,
625        )
626
627    @torch.no_grad()
628    def initialize(
629        self,
630        image: np.ndarray,
631        image_embeddings: Optional[util.ImageEmbeddings] = None,
632        i: Optional[int] = None,
633        tile_shape: Optional[Tuple[int, int]] = None,
634        halo: Optional[Tuple[int, int]] = None,
635        verbose: bool = False,
636        pbar_init: Optional[callable] = None,
637        pbar_update: Optional[callable] = None,
638        batch_size: int = 1,
639    ) -> None:
640        """Initialize image embeddings and masks for an image.
641
642        Args:
643            image: The input image, volume or timeseries.
644            image_embeddings: Optional precomputed image embeddings.
645                See `util.precompute_image_embeddings` for details.
646            i: Index for the image data. Required if `image` has three spatial dimensions
647                or a time dimension and two spatial dimensions.
648            tile_shape: The tile shape for embedding prediction.
649        halo: The overlap between tiles.
650            verbose: Whether to print computation progress.
651            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
652                Can be used together with pbar_update to handle the napari progress bar in another thread.
653                This enables using this function within a threadworker.
654            pbar_update: Callback to update an external progress bar.
655            batch_size: The batch size for image embedding prediction.
656        """
657        original_size = image.shape[:2]
658        self._original_size = original_size
659
660        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
661            self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, batch_size=batch_size
662        )
663
664        tiling = blocking([0, 0], original_size, tile_shape)
665        n_tiles = tiling.numberOfBlocks
666
667        # The crop box is always the full local tile.
668        tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)]
669        crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles]
670
671        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
672        pbar_init(n_tiles, "Compute masks for tile")
673
674        # We need to cast to the image representation that is compatible with SAM.
675        image = util._to_image(image)
676
677        mask_data = []
678        for tile_id in range(n_tiles):
679            # set the pre-computed embeddings for this tile
680            features = image_embeddings["features"][tile_id]
681            tile_embeddings = {
682                "features": features,
683                "input_size": features.attrs["input_size"],
684                "original_size": features.attrs["original_size"],
685            }
686            util.set_precomputed(self._predictor, tile_embeddings, i)
687
688            # compute the mask data for this tile and append it
689            this_mask_data = self._process_crop(
690                image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True
691            )
692            mask_data.append(this_mask_data)
693            pbar_update(1)
694        pbar_close()
695
696        # set the initialized data
697        self._is_initialized = True
698        self._crop_list = mask_data
699        self._crop_boxes = crop_boxes

Generates an instance segmentation without prompts, using a point grid.

Implements the same functionality as AutomaticMaskGenerator but for tiled embeddings.

Arguments:
  • predictor: The segment anything predictor.
  • points_per_side: The number of points to be sampled along one side of the image. If None, point_grids must provide explicit point sampling.
  • points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory.
  • point_grids: A list of explicit grids of points used for sampling masks. Normalized to [0, 1] with respect to the image coordinate system.
  • stability_score_offset: The amount to shift the cutoff when calculating the stability score.
TiledAutomaticMaskGenerator( predictor: segment_anything.predictor.SamPredictor, points_per_side: Optional[int] = 32, points_per_batch: int = 64, point_grids: Optional[List[numpy.ndarray]] = None, stability_score_offset: float = 1.0)
611    def __init__(
612        self,
613        predictor: SamPredictor,
614        points_per_side: Optional[int] = 32,
615        points_per_batch: int = 64,
616        point_grids: Optional[List[np.ndarray]] = None,
617        stability_score_offset: float = 1.0,
618    ) -> None:
619        super().__init__(
620            predictor=predictor,
621            points_per_side=points_per_side,
622            points_per_batch=points_per_batch,
623            point_grids=point_grids,
624            stability_score_offset=stability_score_offset,
625        )
@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = False, pbar_init: Optional[<built-in function callable>] = None, pbar_update: Optional[<built-in function callable>] = None, batch_size: int = 1) -> None:
627    @torch.no_grad()
628    def initialize(
629        self,
630        image: np.ndarray,
631        image_embeddings: Optional[util.ImageEmbeddings] = None,
632        i: Optional[int] = None,
633        tile_shape: Optional[Tuple[int, int]] = None,
634        halo: Optional[Tuple[int, int]] = None,
635        verbose: bool = False,
636        pbar_init: Optional[callable] = None,
637        pbar_update: Optional[callable] = None,
638        batch_size: int = 1,
639    ) -> None:
640        """Initialize image embeddings and masks for an image.
641
642        Args:
643            image: The input image, volume or timeseries.
644            image_embeddings: Optional precomputed image embeddings.
645                See `util.precompute_image_embeddings` for details.
646            i: Index for the image data. Required if `image` has three spatial dimensions
647                or a time dimension and two spatial dimensions.
648            tile_shape: The tile shape for embedding prediction.
649        halo: The overlap between tiles.
650            verbose: Whether to print computation progress.
651            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
652                Can be used together with pbar_update to handle the napari progress bar in another thread.
653                This enables using this function within a threadworker.
654            pbar_update: Callback to update an external progress bar.
655            batch_size: The batch size for image embedding prediction.
656        """
657        original_size = image.shape[:2]
658        self._original_size = original_size
659
660        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
661            self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, batch_size=batch_size
662        )
663
664        tiling = blocking([0, 0], original_size, tile_shape)
665        n_tiles = tiling.numberOfBlocks
666
667        # The crop box is always the full local tile.
668        tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)]
669        crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles]
670
671        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
672        pbar_init(n_tiles, "Compute masks for tile")
673
674        # We need to cast to the image representation that is compatible with SAM.
675        image = util._to_image(image)
676
677        mask_data = []
678        for tile_id in range(n_tiles):
679            # set the pre-computed embeddings for this tile
680            features = image_embeddings["features"][tile_id]
681            tile_embeddings = {
682                "features": features,
683                "input_size": features.attrs["input_size"],
684                "original_size": features.attrs["original_size"],
685            }
686            util.set_precomputed(self._predictor, tile_embeddings, i)
687
688            # compute the mask data for this tile and append it
689            this_mask_data = self._process_crop(
690                image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True
691            )
692            mask_data.append(this_mask_data)
693            pbar_update(1)
694        pbar_close()
695
696        # set the initialized data
697        self._is_initialized = True
698        self._crop_list = mask_data
699        self._crop_boxes = crop_boxes

Initialize image embeddings and masks for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • tile_shape: The tile shape for embedding prediction.
  • halo: The overlap between tiles.
  • verbose: Whether to print computation progress.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
  • batch_size: The batch size for image embedding prediction.
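
A sketch of tiled automatic segmentation for a large image; the image, tile shape and halo are placeholder values:

```python
import numpy as np
from micro_sam import util
from micro_sam.instance_segmentation import TiledAutomaticMaskGenerator, mask_data_to_segmentation

large_image = np.random.randint(0, 255, size=(2048, 2048)).astype("uint8")  # placeholder image
predictor = util.get_sam_model(model_type="vit_b")

amg = TiledAutomaticMaskGenerator(predictor)
# Embeddings are computed per tile; tile_shape and halo control the tiling.
amg.initialize(large_image, tile_shape=(1024, 1024), halo=(256, 256), verbose=True)
masks = amg.generate(pred_iou_thresh=0.88)
segmentation = mask_data_to_segmentation(masks, with_background=True)
```
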
class DecoderAdapter(torch.nn.modules.module.Module):
707class DecoderAdapter(torch.nn.Module):
708    """Adapter to contain the UNETR decoder in a single module.
709
710    To apply the decoder on top of pre-computed embeddings for the segmentation functionality.
711    See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py
712    """
713    def __init__(self, unetr):
714        super().__init__()
715
716        self.base = unetr.base
717        self.out_conv = unetr.out_conv
718        self.deconv_out = unetr.deconv_out
719        self.decoder_head = unetr.decoder_head
720        self.final_activation = unetr.final_activation
721        self.postprocess_masks = unetr.postprocess_masks
722
723        self.decoder = unetr.decoder
724        self.deconv1 = unetr.deconv1
725        self.deconv2 = unetr.deconv2
726        self.deconv3 = unetr.deconv3
727        self.deconv4 = unetr.deconv4
728
729    def _forward_impl(self, input_):
730        z12 = input_
731
732        z9 = self.deconv1(z12)
733        z6 = self.deconv2(z9)
734        z3 = self.deconv3(z6)
735        z0 = self.deconv4(z3)
736
737        updated_from_encoder = [z9, z6, z3]
738
739        x = self.base(z12)
740        x = self.decoder(x, encoder_inputs=updated_from_encoder)
741        x = self.deconv_out(x)
742
743        x = torch.cat([x, z0], dim=1)
744        x = self.decoder_head(x)
745
746        x = self.out_conv(x)
747        if self.final_activation is not None:
748            x = self.final_activation(x)
749        return x
750
751    def forward(self, input_, input_shape, original_shape):
752        x = self._forward_impl(input_)
753        x = self.postprocess_masks(x, input_shape, original_shape)
754        return x

Adapter to contain the UNETR decoder in a single module.

To apply the decoder on top of pre-computed embeddings for the segmentation functionality. See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py
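
A hedged sketch of applying the adapter to precomputed SAM embeddings, mirroring what the decoder-based segmenters in this module do internally; `features`, `input_size` and `original_size` are attributes of `SamPredictor` after an image has been set (e.g. via `util.set_precomputed`):

```python
import torch

# 'predictor' has an image set and 'decoder' is a DecoderAdapter,
# e.g. obtained from get_predictor_and_decoder below.
with torch.no_grad():
    output = decoder(
        predictor.features,                # the image embeddings
        tuple(predictor.input_size),       # size of the resized input to the image encoder
        tuple(predictor.original_size),    # the original image size
    )

# With out_channels=3 the output contains foreground probabilities,
# center distances and boundary distances.
foreground, center_distances, boundary_distances = output.cpu().numpy().squeeze(0)
```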

DecoderAdapter(unetr)
713    def __init__(self, unetr):
714        super().__init__()
715
716        self.base = unetr.base
717        self.out_conv = unetr.out_conv
718        self.deconv_out = unetr.deconv_out
719        self.decoder_head = unetr.decoder_head
720        self.final_activation = unetr.final_activation
721        self.postprocess_masks = unetr.postprocess_masks
722
723        self.decoder = unetr.decoder
724        self.deconv1 = unetr.deconv1
725        self.deconv2 = unetr.deconv2
726        self.deconv3 = unetr.deconv3
727        self.deconv4 = unetr.deconv4

Initialize internal Module state, shared by both nn.Module and ScriptModule.

base
out_conv
deconv_out
decoder_head
final_activation
postprocess_masks
decoder
deconv1
deconv2
deconv3
deconv4
def forward(self, input_, input_shape, original_shape):
751    def forward(self, input_, input_shape, original_shape):
752        x = self._forward_impl(input_)
753        x = self.postprocess_masks(x, input_shape, original_shape)
754        return x

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

def get_unetr( image_encoder: torch.nn.modules.module.Module, decoder_state: Optional[collections.OrderedDict[str, torch.Tensor]] = None, device: Union[str, torch.device, NoneType] = None, out_channels: int = 3, flexible_load_checkpoint: bool = False) -> torch.nn.modules.module.Module:
757def get_unetr(
758    image_encoder: torch.nn.Module,
759    decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None,
760    device: Optional[Union[str, torch.device]] = None,
761    out_channels: int = 3,
762    flexible_load_checkpoint: bool = False,
763) -> torch.nn.Module:
764    """Get UNETR model for automatic instance segmentation.
765
766    Args:
767        image_encoder: The image encoder of the SAM model.
768            This is used as encoder by the UNETR too.
769        decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
770        device: The device.
771        out_channels: The number of output channels.
772        flexible_load_checkpoint: Whether to allow reinitialization of parameters
773            which could not be found in the provided decoder state.
774
775    Returns:
776        The UNETR model.
777    """
778    device = util.get_device(device)
779
780    unetr = UNETR(
781        backbone="sam",
782        encoder=image_encoder,
783        out_channels=out_channels,
784        use_sam_stats=True,
785        final_activation="Sigmoid",
786        use_skip_connection=False,
787        resize_input=True,
788    )
789    if decoder_state is not None:
790        unetr_state_dict = unetr.state_dict()
791        for k, v in unetr_state_dict.items():
792            if not k.startswith("encoder"):
793                if flexible_load_checkpoint:  # Allow reinitialization of params that are not found in the decoder state.
794                    if k in decoder_state:  # First check whether the key is available in the provided decoder state.
795                        unetr_state_dict[k] = decoder_state[k]
796                    else:  # Otherwise, keep the default initialization for this parameter.
797                        warnings.warn(f"Could not find '{k}' in the pretrained state dict. Hence, we reinitialize it.")
798                        unetr_state_dict[k] = v
799
800                else:  # Be strict about finding the parameter in the decoder state.
801                    if k not in decoder_state:
802                        raise RuntimeError(f"The parameters for '{k}' could not be found.")
803                    unetr_state_dict[k] = decoder_state[k]
804
805        unetr.load_state_dict(unetr_state_dict)
806
807    unetr.to(device)
808    return unetr

Get UNETR model for automatic instance segmentation.

Arguments:
  • image_encoder: The image encoder of the SAM model. This is used as encoder by the UNETR too.
  • decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
  • device: The device.
  • out_channels: The number of output channels.
  • flexible_load_checkpoint: Whether to allow reinitialization of parameters which could not be found in the provided decoder state.
Returns:

The UNETR model.
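
A sketch of building the UNETR on top of a SAM image encoder; without a `decoder_state` the decoder weights are freshly initialized, so this is mainly useful as a starting point for training:

```python
from micro_sam import util
from micro_sam.instance_segmentation import get_unetr

# The SAM image encoder is re-used as the UNETR encoder.
predictor = util.get_sam_model(model_type="vit_b")
unetr = get_unetr(predictor.model.image_encoder, decoder_state=None, out_channels=3)
```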

def get_decoder( image_encoder: torch.nn.modules.module.Module, decoder_state: collections.OrderedDict[str, torch.Tensor], device: Union[str, torch.device, NoneType] = None) -> DecoderAdapter:
811def get_decoder(
812    image_encoder: torch.nn.Module,
813    decoder_state: OrderedDict[str, torch.Tensor],
814    device: Optional[Union[str, torch.device]] = None,
815) -> DecoderAdapter:
816    """Get decoder to predict outputs for automatic instance segmentation
817
818    Args:
819        image_encoder: The image encoder of the SAM model.
820        decoder_state: State to initialize the weights of the UNETR decoder.
821        device: The device.
822
823    Returns:
824        The decoder for instance segmentation.
825    """
826    unetr = get_unetr(image_encoder, decoder_state, device)
827    return DecoderAdapter(unetr)

Get the decoder to predict outputs for automatic instance segmentation.

Arguments:
  • image_encoder: The image encoder of the SAM model.
  • decoder_state: State to initialize the weights of the UNETR decoder.
  • device: The device.
Returns:

The decoder for instance segmentation.
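
A sketch that mirrors what `get_predictor_and_decoder` (below) does internally; the checkpoint path is a placeholder and the loaded state is assumed to contain a `decoder_state` entry:

```python
from micro_sam import util
from micro_sam.instance_segmentation import get_decoder

predictor, state = util.get_sam_model(
    model_type="vit_b", checkpoint_path="finetuned_sam_with_decoder.pt", return_state=True
)
decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"])
```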

def get_predictor_and_decoder( model_type: str, checkpoint_path: Union[str, os.PathLike], device: Union[str, torch.device, NoneType] = None, peft_kwargs: Optional[Dict] = None) -> Tuple[segment_anything.predictor.SamPredictor, DecoderAdapter]:
830def get_predictor_and_decoder(
831    model_type: str,
832    checkpoint_path: Union[str, os.PathLike],
833    device: Optional[Union[str, torch.device]] = None,
834    peft_kwargs: Optional[Dict] = None,
835) -> Tuple[SamPredictor, DecoderAdapter]:
836    """Load the SAM model (predictor) and instance segmentation decoder.
837
838    This requires a checkpoint that contains the state for both predictor
839    and decoder.
840
841    Args:
842        model_type: The type of the image encoder used in the SAM model.
843        checkpoint_path: Path to the checkpoint from which to load the data.
844        device: The device.
845        peft_kwargs: Keyword arguments for the PEFT wrapper class.
846
847    Returns:
848        The SAM predictor.
849        The decoder for instance segmentation.
850    """
851    device = util.get_device(device)
852    predictor, state = util.get_sam_model(
853        model_type=model_type,
854        checkpoint_path=checkpoint_path,
855        device=device,
856        return_state=True,
857        peft_kwargs=peft_kwargs,
858    )
859    if "decoder_state" not in state:
860        raise ValueError(
861            f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state"
862        )
863    decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device)
864    return predictor, decoder

Load the SAM model (predictor) and instance segmentation decoder.

This requires a checkpoint that contains the state for both predictor and decoder.

Arguments:
  • model_type: The type of the image encoder used in the SAM model.
  • checkpoint_path: Path to the checkpoint from which to load the data.
  • device: The device.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class.
Returns:

The SAM predictor. The decoder for instance segmentation.
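A minimal sketch of the typical loading path; the checkpoint path is a placeholder and must point to a checkpoint that also stores a `decoder_state`:

```python
# Minimal sketch: load predictor and decoder from a single checkpoint and
# construct the decoder-based segmenter from them.
from micro_sam.instance_segmentation import (
    InstanceSegmentationWithDecoder, get_predictor_and_decoder,
)

predictor, decoder = get_predictor_and_decoder(
    model_type="vit_b",                        # encoder type used by the checkpoint
    checkpoint_path="/path/to/checkpoint.pt",  # placeholder path
)
segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
```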

class InstanceSegmentationWithDecoder:
 916class InstanceSegmentationWithDecoder:
 917    """Generates an instance segmentation without prompts, using a decoder.
 918
 919    Implements the same interface as `AutomaticMaskGenerator`.
 920
 921    Use this class as follows:
 922    ```python
 923    segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
 924    segmenter.initialize(image)   # Predict the image embeddings and decoder outputs.
 925    masks = segmenter.generate(center_distance_threshold=0.75)  # Generate the instance segmentation.
 926    ```
 927
 928    Args:
 929        predictor: The segment anything predictor.
 930        decoder: The decoder to predict intermediate representations
 931            for instance segmentation.
 932    """
 933    def __init__(
 934        self,
 935        predictor: SamPredictor,
 936        decoder: torch.nn.Module,
 937    ) -> None:
 938        self._predictor = predictor
 939        self._decoder = decoder
 940
 941        # The decoder outputs.
 942        self._foreground = None
 943        self._center_distances = None
 944        self._boundary_distances = None
 945
 946        self._is_initialized = False
 947
 948    @property
 949    def is_initialized(self):
 950        """Whether the mask generator has already been initialized.
 951        """
 952        return self._is_initialized
 953
 954    @torch.no_grad()
 955    def initialize(
 956        self,
 957        image: np.ndarray,
 958        image_embeddings: Optional[util.ImageEmbeddings] = None,
 959        i: Optional[int] = None,
 960        verbose: bool = False,
 961        pbar_init: Optional[callable] = None,
 962        pbar_update: Optional[callable] = None,
 963    ) -> None:
 964        """Initialize image embeddings and decoder predictions for an image.
 965
 966        Args:
 967            image: The input image, volume or timeseries.
 968            image_embeddings: Optional precomputed image embeddings.
 969                See `util.precompute_image_embeddings` for details.
 970            i: Index for the image data. Required if `image` has three spatial dimensions
 971                or a time dimension and two spatial dimensions.
 972            verbose: Whether to be verbose.
 973            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 974                Can be used together with pbar_update to handle the napari progress bar in another thread.
 975                This enables using this function within a threadworker.
 976            pbar_update: Callback to update an external progress bar.
 977        """
 978        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
 979        pbar_init(1, "Initialize instance segmentation with decoder")
 980
 981        if image_embeddings is None:
 982            image_embeddings = util.precompute_image_embeddings(self._predictor, image)
 983
 984        # Get the image embeddings from the predictor.
 985        self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i)
 986        embeddings = self._predictor.features
 987        input_shape = tuple(self._predictor.input_size)
 988        original_shape = tuple(self._predictor.original_size)
 989
 990        # Run prediction with the UNETR decoder.
 991        output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
 992        assert output.shape[0] == 3, f"{output.shape}"
 993        pbar_update(1)
 994        pbar_close()
 995
 996        # Set the state.
 997        self._foreground = output[0]
 998        self._center_distances = output[1]
 999        self._boundary_distances = output[2]
1000        self._is_initialized = True
1001
1002    def _to_masks(self, segmentation, output_mode):
1003        if output_mode != "binary_mask":
1004            raise NotImplementedError
1005
1006        props = regionprops(segmentation)
1007        ndim = segmentation.ndim
1008        assert ndim in (2, 3)
1009
1010        shape = segmentation.shape
1011        if ndim == 2:
1012            crop_box = [0, shape[1], 0, shape[0]]
1013        else:
1014            crop_box = [0, shape[2], 0, shape[1], 0, shape[0]]
1015
1016        # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h]
1017        def to_bbox_2d(bbox):
1018            y0, x0 = bbox[0], bbox[1]
1019            w = bbox[3] - x0
1020            h = bbox[2] - y0
1021            return [x0, w, y0, h]
1022
1023        def to_bbox_3d(bbox):
1024            z0, y0, x0 = bbox[0], bbox[1], bbox[2]
1025            w = bbox[5] - x0
1026            h = bbox[4] - y0
1027            d = bbox[3] - z0  # The depth is the extent along z.
1028            return [x0, w, y0, h, z0, d]
1029
1030        to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d
1031        masks = [
1032            {
1033                "segmentation": segmentation == prop.label,
1034                "area": prop.area,
1035                "bbox": to_bbox(prop.bbox),
1036                "crop_box": crop_box,
1037                "seg_id": prop.label,
1038            } for prop in props
1039        ]
1040        return masks
1041
1042    def generate(
1043        self,
1044        center_distance_threshold: float = 0.5,
1045        boundary_distance_threshold: float = 0.5,
1046        foreground_threshold: float = 0.5,
1047        foreground_smoothing: float = 1.0,
1048        distance_smoothing: float = 1.6,
1049        min_size: int = 0,
1050        output_mode: Optional[str] = "binary_mask",
1051        tile_shape: Optional[Tuple[int, int]] = None,
1052        halo: Optional[Tuple[int, int]] = None,
1053        n_threads: Optional[int] = None,
1054    ) -> List[Dict[str, Any]]:
1055        """Generate instance segmentation for the currently initialized image.
1056
1057        Args:
1058            center_distance_threshold: Center distance predictions below this value will be
1059                used to find seeds (intersected with thresholded boundary distance predictions).
1060            boundary_distance_threshold: Boundary distance predictions below this value will be
1061                used to find seeds (intersected with thresholded center distance predictions).
1062            foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid
1063                checkerboard artifacts in the prediction.
1064            foreground_threshold: Foreground predictions above this value will be used as foreground mask.
1065            distance_smoothing: Sigma value for smoothing the distance predictions.
1066            min_size: Minimal object size in the segmentation result.
1067            output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
1068            tile_shape: Tile shape for parallelizing the instance segmentation post-processing.
1069                This parameter is independent from the tile shape for computing the embeddings.
1070                If not given then post-processing will not be parallelized.
1071            halo: Halo for parallel post-processing. See also `tile_shape`.
1072            n_threads: Number of threads for parallel post-processing. See also `tile_shape`.
1073
1074        Returns:
1075            The instance segmentation masks.
1076        """
1077        if not self.is_initialized:
1078            raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.")
1079
1080        if foreground_smoothing > 0:
1081            foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing)
1082        else:
1083            foreground = self._foreground
1084
1085        if tile_shape is None:
1086            segmentation = watershed_from_center_and_boundary_distances(
1087                center_distances=self._center_distances,
1088                boundary_distances=self._boundary_distances,
1089                foreground_map=foreground,
1090                center_distance_threshold=center_distance_threshold,
1091                boundary_distance_threshold=boundary_distance_threshold,
1092                foreground_threshold=foreground_threshold,
1093                distance_smoothing=distance_smoothing,
1094                min_size=min_size,
1095            )
1096        else:
1097            if halo is None:
1098                raise ValueError("You must pass a value for halo if tile_shape is given.")
1099            segmentation = _watershed_from_center_and_boundary_distances_parallel(
1100                center_distances=self._center_distances,
1101                boundary_distances=self._boundary_distances,
1102                foreground_map=foreground,
1103                center_distance_threshold=center_distance_threshold,
1104                boundary_distance_threshold=boundary_distance_threshold,
1105                foreground_threshold=foreground_threshold,
1106                distance_smoothing=distance_smoothing,
1107                min_size=min_size,
1108                tile_shape=tile_shape,
1109                halo=halo,
1110                n_threads=n_threads,
1111                verbose=False,
1112            )
1113
1114        if output_mode is not None:
1115            segmentation = self._to_masks(segmentation, output_mode)
1116        return segmentation
1117
1118    def get_state(self) -> Dict[str, Any]:
1119        """Get the initialized state of the instance segmenter.
1120
1121        Returns:
1122            Instance segmentation state.
1123        """
1124        if not self.is_initialized:
1125            raise RuntimeError("The state has not been computed yet. Call initialize first.")
1126
1127        return {
1128            "foreground": self._foreground,
1129            "center_distances": self._center_distances,
1130            "boundary_distances": self._boundary_distances,
1131        }
1132
1133    def set_state(self, state: Dict[str, Any]) -> None:
1134        """Set the state of the instance segmenter.
1135
1136        Args:
1137            state: The instance segmentation state
1138        """
1139        self._foreground = state["foreground"]
1140        self._center_distances = state["center_distances"]
1141        self._boundary_distances = state["boundary_distances"]
1142        self._is_initialized = True
1143
1144    def clear_state(self):
1145        """Clear the state of the instance segmenter.
1146        """
1147        self._foreground = None
1148        self._center_distances = None
1149        self._boundary_distances = None
1150        self._is_initialized = False

Generates an instance segmentation without prompts, using a decoder.

Implements the same interface as AutomaticMaskGenerator.

Use this class as follows:

segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image)   # Predict the image embeddings and decoder outputs.
masks = segmenter.generate(center_distance_threshold=0.75)  # Generate the instance segmentation.
Arguments:
  • predictor: The segment anything predictor.
  • decoder: The decoder to predict intermediate representations for instance segmentation.
InstanceSegmentationWithDecoder( predictor: segment_anything.predictor.SamPredictor, decoder: torch.nn.modules.module.Module)
933    def __init__(
934        self,
935        predictor: SamPredictor,
936        decoder: torch.nn.Module,
937    ) -> None:
938        self._predictor = predictor
939        self._decoder = decoder
940
941        # The decoder outputs.
942        self._foreground = None
943        self._center_distances = None
944        self._boundary_distances = None
945
946        self._is_initialized = False
is_initialized
948    @property
949    def is_initialized(self):
950        """Whether the mask generator has already been initialized.
951        """
952        return self._is_initialized

Whether the mask generator has already been initialized.

@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, verbose: bool = False, pbar_init: Optional[callable] = None, pbar_update: Optional[callable] = None) -> None:
 954    @torch.no_grad()
 955    def initialize(
 956        self,
 957        image: np.ndarray,
 958        image_embeddings: Optional[util.ImageEmbeddings] = None,
 959        i: Optional[int] = None,
 960        verbose: bool = False,
 961        pbar_init: Optional[callable] = None,
 962        pbar_update: Optional[callable] = None,
 963    ) -> None:
 964        """Initialize image embeddings and decoder predictions for an image.
 965
 966        Args:
 967            image: The input image, volume or timeseries.
 968            image_embeddings: Optional precomputed image embeddings.
 969                See `util.precompute_image_embeddings` for details.
 970            i: Index for the image data. Required if `image` has three spatial dimensions
 971                or a time dimension and two spatial dimensions.
 972            verbose: Whether to be verbose.
 973            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 974                Can be used together with pbar_update to handle the napari progress bar in another thread.
 975                This enables using this function within a threadworker.
 976            pbar_update: Callback to update an external progress bar.
 977        """
 978        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
 979        pbar_init(1, "Initialize instance segmentation with decoder")
 980
 981        if image_embeddings is None:
 982            image_embeddings = util.precompute_image_embeddings(self._predictor, image)
 983
 984        # Get the image embeddings from the predictor.
 985        self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i)
 986        embeddings = self._predictor.features
 987        input_shape = tuple(self._predictor.input_size)
 988        original_shape = tuple(self._predictor.original_size)
 989
 990        # Run prediction with the UNETR decoder.
 991        output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
 992        assert output.shape[0] == 3, f"{output.shape}"
 993        pbar_update(1)
 994        pbar_close()
 995
 996        # Set the state.
 997        self._foreground = output[0]
 998        self._center_distances = output[1]
 999        self._boundary_distances = output[2]
1000        self._is_initialized = True

Initialize image embeddings and decoder predictions for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • verbose: Whether to be verbose.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
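A minimal sketch of initialization with precomputed embeddings; `predictor`, `segmenter`, and a 2D `image` array are assumed from the examples above (the embeddings can also be omitted, in which case `initialize` computes them itself):

```python
# Minimal sketch: precompute the image embeddings once and reuse them when
# initializing the segmenter, e.g. if the same image is processed repeatedly.
from micro_sam import util

image_embeddings = util.precompute_image_embeddings(predictor, image)
segmenter.initialize(image, image_embeddings=image_embeddings)
```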
def generate( self, center_distance_threshold: float = 0.5, boundary_distance_threshold: float = 0.5, foreground_threshold: float = 0.5, foreground_smoothing: float = 1.0, distance_smoothing: float = 1.6, min_size: int = 0, output_mode: Optional[str] = 'binary_mask', tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, n_threads: Optional[int] = None) -> List[Dict[str, Any]]:
1042    def generate(
1043        self,
1044        center_distance_threshold: float = 0.5,
1045        boundary_distance_threshold: float = 0.5,
1046        foreground_threshold: float = 0.5,
1047        foreground_smoothing: float = 1.0,
1048        distance_smoothing: float = 1.6,
1049        min_size: int = 0,
1050        output_mode: Optional[str] = "binary_mask",
1051        tile_shape: Optional[Tuple[int, int]] = None,
1052        halo: Optional[Tuple[int, int]] = None,
1053        n_threads: Optional[int] = None,
1054    ) -> List[Dict[str, Any]]:
1055        """Generate instance segmentation for the currently initialized image.
1056
1057        Args:
1058            center_distance_threshold: Center distance predictions below this value will be
1059                used to find seeds (intersected with thresholded boundary distance predictions).
1060            boundary_distance_threshold: Boundary distance predictions below this value will be
1061                used to find seeds (intersected with thresholded center distance predictions).
1062            foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid
1063                checkerboard artifacts in the prediction.
1064            foreground_threshold: Foreground predictions above this value will be used as foreground mask.
1065            distance_smoothing: Sigma value for smoothing the distance predictions.
1066            min_size: Minimal object size in the segmentation result.
1067            output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
1068            tile_shape: Tile shape for parallelizing the instance segmentation post-processing.
1069                This parameter is independent from the tile shape for computing the embeddings.
1070                If not given then post-processing will not be parallelized.
1071            halo: Halo for parallel post-processing. See also `tile_shape`.
1072            n_threads: Number of threads for parallel post-processing. See also `tile_shape`.
1073
1074        Returns:
1075            The instance segmentation masks.
1076        """
1077        if not self.is_initialized:
1078            raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.")
1079
1080        if foreground_smoothing > 0:
1081            foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing)
1082        else:
1083            foreground = self._foreground
1084
1085        if tile_shape is None:
1086            segmentation = watershed_from_center_and_boundary_distances(
1087                center_distances=self._center_distances,
1088                boundary_distances=self._boundary_distances,
1089                foreground_map=foreground,
1090                center_distance_threshold=center_distance_threshold,
1091                boundary_distance_threshold=boundary_distance_threshold,
1092                foreground_threshold=foreground_threshold,
1093                distance_smoothing=distance_smoothing,
1094                min_size=min_size,
1095            )
1096        else:
1097            if halo is None:
1098                raise ValueError("You must pass a value for halo if tile_shape is given.")
1099            segmentation = _watershed_from_center_and_boundary_distances_parallel(
1100                center_distances=self._center_distances,
1101                boundary_distances=self._boundary_distances,
1102                foreground_map=foreground,
1103                center_distance_threshold=center_distance_threshold,
1104                boundary_distance_threshold=boundary_distance_threshold,
1105                foreground_threshold=foreground_threshold,
1106                distance_smoothing=distance_smoothing,
1107                min_size=min_size,
1108                tile_shape=tile_shape,
1109                halo=halo,
1110                n_threads=n_threads,
1111                verbose=False,
1112            )
1113
1114        if output_mode is not None:
1115            segmentation = self._to_masks(segmentation, output_mode)
1116        return segmentation

Generate instance segmentation for the currently initialized image.

Arguments:
  • center_distance_threshold: Center distance predictions below this value will be used to find seeds (intersected with thresholded boundary distance predictions).
  • boundary_distance_threshold: Boundary distance predictions below this value will be used to find seeds (intersected with thresholded center distance predictions).
  • foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid checkerboard artifacts in the prediction.
  • foreground_threshold: Foreground predictions above this value will be used as foreground mask.
  • distance_smoothing: Sigma value for smoothing the distance predictions.
  • min_size: Minimal object size in the segmentation result.
  • output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
  • tile_shape: Tile shape for parallelizing the instance segmentation post-processing. This parameter is independent from the tile shape for computing the embeddings. If not given then post-processing will not be parallelized.
  • halo: Halo for parallel post-processing. See also tile_shape.
  • n_threads: Number of threads for parallel post-processing. See also tile_shape.
Returns:

The instance segmentation masks.
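A minimal sketch of generating the segmentation once `initialize` has been called; the threshold and size values are illustrative only:

```python
# Minimal sketch: generate mask dictionaries, or the label image directly.
masks = segmenter.generate(
    center_distance_threshold=0.5,
    boundary_distance_threshold=0.5,
    min_size=50,                  # illustrative minimal object size in pixels
)

# Passing output_mode=None returns the instance segmentation as a label image.
label_image = segmenter.generate(output_mode=None)
```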

def get_state(self) -> Dict[str, Any]:
1118    def get_state(self) -> Dict[str, Any]:
1119        """Get the initialized state of the instance segmenter.
1120
1121        Returns:
1122            Instance segmentation state.
1123        """
1124        if not self.is_initialized:
1125            raise RuntimeError("The state has not been computed yet. Call initialize first.")
1126
1127        return {
1128            "foreground": self._foreground,
1129            "center_distances": self._center_distances,
1130            "boundary_distances": self._boundary_distances,
1131        }

Get the initialized state of the instance segmenter.

Returns:

Instance segmentation state.

def set_state(self, state: Dict[str, Any]) -> None:
1133    def set_state(self, state: Dict[str, Any]) -> None:
1134        """Set the state of the instance segmenter.
1135
1136        Args:
1137            state: The instance segmentation state
1138        """
1139        self._foreground = state["foreground"]
1140        self._center_distances = state["center_distances"]
1141        self._boundary_distances = state["boundary_distances"]
1142        self._is_initialized = True

Set the state of the instance segmenter.

Arguments:
  • state: The instance segmentation state
def clear_state(self):
1144    def clear_state(self):
1145        """Clear the state of the instance segmenter.
1146        """
1147        self._foreground = None
1148        self._center_distances = None
1149        self._boundary_distances = None
1150        self._is_initialized = False

Clear the state of the instance segmenter.
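A minimal sketch of caching and restoring the segmenter state, so that `generate` can be re-run with different parameters without recomputing embeddings or decoder predictions; `predictor`, `decoder`, and an initialized `segmenter` are assumed from the examples above:

```python
# Minimal sketch: save the decoder predictions and restore them in a fresh
# segmenter instance, then free them again.
from micro_sam.instance_segmentation import InstanceSegmentationWithDecoder

state = segmenter.get_state()       # requires initialize to have been called

fresh_segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
fresh_segmenter.set_state(state)    # the segmenter is now initialized
masks = fresh_segmenter.generate(center_distance_threshold=0.4)

fresh_segmenter.clear_state()       # drop the cached predictions
```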

class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder):
1153class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder):
1154    """Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings.
1155    """
1156
1157    # Apply the decoder in a batched fashion, and then perform the resizing independently per output.
1158    # This is necessary, because the individual tiles may have different tile shapes due to border tiles.
1159    def _predict_decoder(self, batched_embeddings, input_shapes, original_shapes):
1160        batched_embeddings = torch.cat(batched_embeddings)
1161        output = self._decoder._forward_impl(batched_embeddings)
1162
1163        batched_output = []
1164        for x, input_shape, original_shape in zip(output, input_shapes, original_shapes):
1165            x = self._decoder.postprocess_masks(x.unsqueeze(0), input_shape, original_shape).squeeze(0)
1166            batched_output.append(x.cpu().numpy())
1167        return batched_output
1168
1169    @torch.no_grad()
1170    def initialize(
1171        self,
1172        image: np.ndarray,
1173        image_embeddings: Optional[util.ImageEmbeddings] = None,
1174        i: Optional[int] = None,
1175        tile_shape: Optional[Tuple[int, int]] = None,
1176        halo: Optional[Tuple[int, int]] = None,
1177        verbose: bool = False,
1178        pbar_init: Optional[callable] = None,
1179        pbar_update: Optional[callable] = None,
1180        batch_size: int = 1,
1181    ) -> None:
1182        """Initialize image embeddings and decoder predictions for an image.
1183
1184        Args:
1185            image: The input image, volume or timeseries.
1186            image_embeddings: Optional precomputed image embeddings.
1187                See `util.precompute_image_embeddings` for details.
1188            i: Index for the image data. Required if `image` has three spatial dimensions
1189                or a time dimension and two spatial dimensions.
1190            tile_shape: Shape of the tiles for precomputing image embeddings.
1191            halo: Overlap of the tiles for tiled precomputation of image embeddings.
1192            verbose: Whether to be verbose.
1193            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
1194                Can be used together with pbar_update to handle the napari progress bar in another thread.
1195                This enables using this function within a threadworker.
1196            pbar_update: Callback to update an external progress bar.
1197            batch_size: The batch size for image embedding computation and segmentation decoder prediction.
1198        """
1199        original_size = image.shape[:2]
1200        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
1201            self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, batch_size=batch_size
1202        )
1203        tiling = blocking([0, 0], original_size, tile_shape)
1204
1205        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
1206        pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder")
1207
1208        foreground = np.zeros(original_size, dtype="float32")
1209        center_distances = np.zeros(original_size, dtype="float32")
1210        boundary_distances = np.zeros(original_size, dtype="float32")
1211
1212        n_tiles = tiling.numberOfBlocks
1213        n_batches = int(np.ceil(n_tiles / batch_size))
1214
1215        for batch_id in range(n_batches):
1216            tile_start = batch_id * batch_size
1217            tile_stop = min(tile_start + batch_size, n_tiles)
1218
1219            batched_embeddings, input_shapes, original_shapes = [], [], []
1220            for tile_id in range(tile_start, tile_stop):
1221                # Get the image embeddings from the predictor for this tile.
1222                self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id)
1223
1224                batched_embeddings.append(self._predictor.features)
1225                input_shapes.append(tuple(self._predictor.input_size))
1226                original_shapes.append(tuple(self._predictor.original_size))
1227
1228            batched_output = self._predict_decoder(batched_embeddings, input_shapes, original_shapes)
1229
1230            for output_id, tile_id in enumerate(range(tile_start, tile_stop)):
1231                output = batched_output[output_id]
1232                assert output.shape[0] == 3
1233
1234                # Set the predictions in the output for this tile.
1235                block = tiling.getBlockWithHalo(tile_id, halo=list(halo))
1236                local_bb = tuple(
1237                    slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end)
1238                )
1239                inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end))
1240
1241                foreground[inner_bb] = output[0][local_bb]
1242                center_distances[inner_bb] = output[1][local_bb]
1243                boundary_distances[inner_bb] = output[2][local_bb]
1244                pbar_update(1)
1245
1246        pbar_close()
1247
1248        # Set the state.
1249        self._foreground = foreground
1250        self._center_distances = center_distances
1251        self._boundary_distances = boundary_distances
1252        self._is_initialized = True

Same as InstanceSegmentationWithDecoder but for tiled image embeddings.

@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = False, pbar_init: Optional[callable] = None, pbar_update: Optional[callable] = None, batch_size: int = 1) -> None:
1169    @torch.no_grad()
1170    def initialize(
1171        self,
1172        image: np.ndarray,
1173        image_embeddings: Optional[util.ImageEmbeddings] = None,
1174        i: Optional[int] = None,
1175        tile_shape: Optional[Tuple[int, int]] = None,
1176        halo: Optional[Tuple[int, int]] = None,
1177        verbose: bool = False,
1178        pbar_init: Optional[callable] = None,
1179        pbar_update: Optional[callable] = None,
1180        batch_size: int = 1,
1181    ) -> None:
1182        """Initialize image embeddings and decoder predictions for an image.
1183
1184        Args:
1185            image: The input image, volume or timeseries.
1186            image_embeddings: Optional precomputed image embeddings.
1187                See `util.precompute_image_embeddings` for details.
1188            i: Index for the image data. Required if `image` has three spatial dimensions
1189                or a time dimension and two spatial dimensions.
1190            tile_shape: Shape of the tiles for precomputing image embeddings.
1191            halo: Overlap of the tiles for tiled precomputation of image embeddings.
1192            verbose: Whether to be verbose.
1193            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
1194                Can be used together with pbar_update to handle the napari progress bar in another thread.
1195                This enables using this function within a threadworker.
1196            pbar_update: Callback to update an external progress bar.
1197            batch_size: The batch size for image embedding computation and segmentation decoder prediction.
1198        """
1199        original_size = image.shape[:2]
1200        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
1201            self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose, batch_size=batch_size
1202        )
1203        tiling = blocking([0, 0], original_size, tile_shape)
1204
1205        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
1206        pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder")
1207
1208        foreground = np.zeros(original_size, dtype="float32")
1209        center_distances = np.zeros(original_size, dtype="float32")
1210        boundary_distances = np.zeros(original_size, dtype="float32")
1211
1212        n_tiles = tiling.numberOfBlocks
1213        n_batches = int(np.ceil(n_tiles / batch_size))
1214
1215        for batch_id in range(n_batches):
1216            tile_start = batch_id * batch_size
1217            tile_stop = min(tile_start + batch_size, n_tiles)
1218
1219            batched_embeddings, input_shapes, original_shapes = [], [], []
1220            for tile_id in range(tile_start, tile_stop):
1221                # Get the image embeddings from the predictor for this tile.
1222                self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id)
1223
1224                batched_embeddings.append(self._predictor.features)
1225                input_shapes.append(tuple(self._predictor.input_size))
1226                original_shapes.append(tuple(self._predictor.original_size))
1227
1228            batched_output = self._predict_decoder(batched_embeddings, input_shapes, original_shapes)
1229
1230            for output_id, tile_id in enumerate(range(tile_start, tile_stop)):
1231                output = batched_output[output_id]
1232                assert output.shape[0] == 3
1233
1234                # Set the predictions in the output for this tile.
1235                block = tiling.getBlockWithHalo(tile_id, halo=list(halo))
1236                local_bb = tuple(
1237                    slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end)
1238                )
1239                inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end))
1240
1241                foreground[inner_bb] = output[0][local_bb]
1242                center_distances[inner_bb] = output[1][local_bb]
1243                boundary_distances[inner_bb] = output[2][local_bb]
1244                pbar_update(1)
1245
1246        pbar_close()
1247
1248        # Set the state.
1249        self._foreground = foreground
1250        self._center_distances = center_distances
1251        self._boundary_distances = boundary_distances
1252        self._is_initialized = True

Initialize image embeddings and decoder predictions for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • tile_shape: Shape of the tiles for precomputing image embeddings.
  • halo: Overlap of the tiles for tiled precomputation of image embeddings.
  • verbose: Whether to be verbose.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
  • batch_size: The batch size for image embedding computation and segmentation decoder prediction.
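A minimal sketch of the tiled workflow; the input array, tile shape, halo, and batch size are placeholders, and `predictor` and `decoder` are assumed from the examples above:

```python
# Minimal sketch: tiled embeddings plus decoder-based segmentation for a large image.
from micro_sam.instance_segmentation import TiledInstanceSegmentationWithDecoder

tiled_segmenter = TiledInstanceSegmentationWithDecoder(predictor, decoder)
tiled_segmenter.initialize(
    large_image,                 # placeholder: a 2D image larger than one tile
    tile_shape=(1024, 1024),
    halo=(256, 256),
    batch_size=2,                # number of tiles per decoder forward pass
)
segmentation = tiled_segmenter.generate(output_mode=None)
```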
def get_amg( predictor: segment_anything.predictor.SamPredictor, is_tiled: bool, decoder: Optional[torch.nn.modules.module.Module] = None, **kwargs) -> Union[AMGBase, InstanceSegmentationWithDecoder]:
1255def get_amg(
1256    predictor: SamPredictor, is_tiled: bool, decoder: Optional[torch.nn.Module] = None, **kwargs,
1257) -> Union[AMGBase, InstanceSegmentationWithDecoder]:
1258    """Get the automatic mask generator class.
1259
1260    Args:
1261        predictor: The segment anything predictor.
1262        is_tiled: Whether tiled embeddings are used.
1263        decoder: Decoder to predict the instance segmentation.
1264        kwargs: The keyword arguments for the amg class.
1265
1266    Returns:
1267        The automatic mask generator.
1268    """
1269    if decoder is None:
1270        segmenter_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator
1271        segmenter = segmenter_class(predictor, **kwargs)
1272    else:
1273        segmenter_class = TiledInstanceSegmentationWithDecoder if is_tiled else InstanceSegmentationWithDecoder
1274        segmenter = segmenter_class(predictor, decoder, **kwargs)
1275
1276    return segmenter

Get the automatic mask generator class.

Arguments:
  • predictor: The segment anything predictor.
  • is_tiled: Whether tiled embeddings are used.
  • decoder: Decoder to predict the instance segmentation.
  • kwargs: The keyword arguments for the amg class.
Returns:

The automatic mask generator.
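A minimal sketch of how `get_amg` selects the segmenter class; `predictor` and (for the decoder-based variant) `decoder` are assumed to be loaded already:

```python
# Minimal sketch: without a decoder, an AMG-style mask generator is returned;
# with a decoder, the decoder-based segmenter; is_tiled selects the tiled variants.
from micro_sam.instance_segmentation import get_amg

amg = get_amg(predictor, is_tiled=False)                   # AutomaticMaskGenerator
ais = get_amg(predictor, is_tiled=True, decoder=decoder)   # TiledInstanceSegmentationWithDecoder
```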