micro_sam.instance_segmentation

Automated instance segmentation functionality. The classes implemented here extend the automatic instance segmentation from Segment Anything: https://computational-cell-analytics.github.io/micro-sam/micro_sam.html
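
The typical workflow is to create a mask generator for a SAM predictor, initialize it on an image (the expensive step), generate the masks (the cheap, tunable step), and convert them into a label image. A minimal sketch, assuming a 2D image given as a numpy array and the `vit_b` model (both are placeholders to adapt to your data):

```python
# Minimal sketch of point-grid based automatic instance segmentation.
# The random image and the "vit_b" model type are placeholders.
import numpy as np

from micro_sam import util
from micro_sam.instance_segmentation import AutomaticMaskGenerator, mask_data_to_segmentation

image = np.random.randint(0, 255, size=(512, 512), dtype="uint8")  # placeholder image

predictor = util.get_sam_model(model_type="vit_b")
amg = AutomaticMaskGenerator(predictor)
amg.initialize(image)                       # expensive: compute embeddings and mask proposals
masks = amg.generate(pred_iou_thresh=0.88)  # cheap: filter and post-process the proposals
segmentation = mask_data_to_segmentation(masks, with_background=True, min_object_size=25)
```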

   1"""
   2Automated instance segmentation functionality.
   3The classes implemented here extend the automatic instance segmentation from Segment Anything:
   4https://computational-cell-analytics.github.io/micro-sam/micro_sam.html
   5"""
   6
   7import os
   8from abc import ABC
   9from copy import deepcopy
  10from collections import OrderedDict
  11from typing import Any, Dict, List, Optional, Tuple, Union
  12
  13import vigra
  14import numpy as np
  15import torch
  16
  17from skimage.measure import label, regionprops
  18from skimage.segmentation import relabel_sequential
  19from torchvision.ops.boxes import batched_nms, box_area
  20
  21from torch_em.model import UNETR
  22from torch_em.util.segmentation import watershed_from_center_and_boundary_distances
  23
  24from nifty.tools import blocking
  25
  26import segment_anything.utils.amg as amg_utils
  27from segment_anything.predictor import SamPredictor
  28
  29from . import util
  30from ._vendored import batched_mask_to_box, mask_to_rle_pytorch
  31
  32#
  33# Utility Functionality
  34#
  35
  36
  37class _FakeInput:
  38    def __init__(self, shape):
  39        self.shape = shape
  40
  41    def __getitem__(self, index):
  42        block_shape = tuple(ind.stop - ind.start for ind in index)
  43        return np.zeros(block_shape, dtype="float32")
  44
  45
  46def mask_data_to_segmentation(
  47    masks: List[Dict[str, Any]],
  48    with_background: bool,
  49    min_object_size: int = 0,
  50    max_object_size: Optional[int] = None,
  51    label_masks: bool = True,
  52) -> np.ndarray:
  53    """Convert the output of the automatic mask generation to an instance segmentation.
  54
  55    Args:
  56        masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator.
  57            Only supports output_mode=binary_mask.
  58        with_background: Whether the segmentation has background. If yes, this function ensures that the largest
  59            object in the output will be mapped to zero (the background value).
  60        min_object_size: The minimal size of an object in pixels.
  61        max_object_size: The maximal size of an object in pixels.
  62<<<<<<< HEAD
  63
  64=======
  65        label_masks: Whether to apply connected components to the result before remving small objects.
  66>>>>>>> master
  67    Returns:
  68        The instance segmentation.
  69    """
  70
  71    masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
  72    # we could also get the shape from the crop box
  73    shape = next(iter(masks))["segmentation"].shape
  74    segmentation = np.zeros(shape, dtype="uint32")
  75
  76    def require_numpy(mask):
  77        return mask.cpu().numpy() if torch.is_tensor(mask) else mask
  78
  79    seg_id = 1
  80    for mask in masks:
  81        if mask["area"] < min_object_size:
  82            continue
  83        if max_object_size is not None and mask["area"] > max_object_size:
  84            continue
  85
  86        this_seg_id = mask.get("seg_id", seg_id)
  87        segmentation[require_numpy(mask["segmentation"])] = this_seg_id
  88        seg_id = this_seg_id + 1
  89
  90    if label_masks:
  91        segmentation = label(segmentation)
  92    seg_ids, sizes = np.unique(segmentation, return_counts=True)
  93
  94    # In some cases objects may be smaller than previously calculated,
  95    # since they are covered by other objects. We ensure these also get
  96    # filtered out here.
  97    filter_ids = seg_ids[sizes < min_object_size]
  98
  99    # If we run segmentation with background we also map the largest segment
 100    # (the most likely background object) to zero. This is often zero already,
 101    # but it does not hurt to reset that to zero either.
 102    if with_background:
 103        bg_id = seg_ids[np.argmax(sizes)]
 104        filter_ids = np.concatenate([filter_ids, [bg_id]])
 105
 106    segmentation[np.isin(segmentation, filter_ids)] = 0
 107    segmentation = relabel_sequential(segmentation)[0]
 108
 109    return segmentation
 110
 111
 112#
 113# Classes for automatic instance segmentation
 114#
 115
 116
 117class AMGBase(ABC):
 118    """Base class for the automatic mask generators.
 119    """
 120    def __init__(self):
 121        # the state that has to be computed by the 'initialize' method of the child classes
 122        self._is_initialized = False
 123        self._crop_list = None
 124        self._crop_boxes = None
 125        self._original_size = None
 126
 127    @property
 128    def is_initialized(self):
 129        """Whether the mask generator has already been initialized.
 130        """
 131        return self._is_initialized
 132
 133    @property
 134    def crop_list(self):
 135        """The list of mask data after initialization.
 136        """
 137        return self._crop_list
 138
 139    @property
 140    def crop_boxes(self):
 141        """The list of crop boxes.
 142        """
 143        return self._crop_boxes
 144
 145    @property
 146    def original_size(self):
 147        """The original image size.
 148        """
 149        return self._original_size
 150
 151    def _postprocess_batch(
 152        self,
 153        data,
 154        crop_box,
 155        original_size,
 156        pred_iou_thresh,
 157        stability_score_thresh,
 158        box_nms_thresh,
 159    ):
 160        orig_h, orig_w = original_size
 161
 162        # filter by predicted IoU
 163        if pred_iou_thresh > 0.0:
 164            keep_mask = data["iou_preds"] > pred_iou_thresh
 165            data.filter(keep_mask)
 166
 167        # filter by stability score
 168        if stability_score_thresh > 0.0:
 169            keep_mask = data["stability_score"] >= stability_score_thresh
 170            data.filter(keep_mask)
 171
 172        # filter boxes that touch crop boundaries
 173        keep_mask = ~amg_utils.is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
 174        if not torch.all(keep_mask):
 175            data.filter(keep_mask)
 176
 177        # remove duplicates within this crop.
 178        keep_by_nms = batched_nms(
 179            data["boxes"].float(),
 180            data["iou_preds"],
 181            torch.zeros_like(data["boxes"][:, 0]),  # categories
 182            iou_threshold=box_nms_thresh,
 183        )
 184        data.filter(keep_by_nms)
 185
 186        # return to the original image frame
 187        data["boxes"] = amg_utils.uncrop_boxes_xyxy(data["boxes"], crop_box)
 188        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
 189        # the data from embedding based segmentation doesn't have the points
 190        # so we skip if the corresponding key can't be found
 191        try:
 192            data["points"] = amg_utils.uncrop_points(data["points"], crop_box)
 193        except KeyError:
 194            pass
 195
 196        return data
 197
 198    def _postprocess_small_regions(self, mask_data, min_area, nms_thresh):
 199
 200        if len(mask_data["rles"]) == 0:
 201            return mask_data
 202
 203        # filter small disconnected regions and holes
 204        new_masks = []
 205        scores = []
 206        for rle in mask_data["rles"]:
 207            mask = amg_utils.rle_to_mask(rle)
 208
 209            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="holes")
 210            unchanged = not changed
 211            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="islands")
 212            unchanged = unchanged and not changed
 213
 214            new_masks.append(torch.as_tensor(mask, dtype=torch.int).unsqueeze(0))
 215            # give score=0 to changed masks and score=1 to unchanged masks
 216            # so NMS will prefer ones that didn't need postprocessing
 217            scores.append(float(unchanged))
 218
 219        # recalculate boxes and remove any new duplicates
 220        masks = torch.cat(new_masks, dim=0)
 221        boxes = batched_mask_to_box(masks)
 222        keep_by_nms = batched_nms(
 223            boxes.float(),
 224            torch.as_tensor(scores, dtype=torch.float),
 225            torch.zeros_like(boxes[:, 0]),  # categories
 226            iou_threshold=nms_thresh,
 227        )
 228
 229        # only recalculate RLEs for masks that have changed
 230        for i_mask in keep_by_nms:
 231            if scores[i_mask] == 0.0:
 232                mask_torch = masks[i_mask].unsqueeze(0)
 233                # mask_data["rles"][i_mask] = amg_utils.mask_to_rle_pytorch(mask_torch)[0]
 234                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
 235                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
 236        mask_data.filter(keep_by_nms)
 237
 238        return mask_data
 239
 240    def _postprocess_masks(self, mask_data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode):
 241        # filter small disconnected regions and holes in masks
 242        if min_mask_region_area > 0:
 243            mask_data = self._postprocess_small_regions(
 244                mask_data,
 245                min_mask_region_area,
 246                max(box_nms_thresh, crop_nms_thresh),
 247            )
 248
 249        # encode masks
 250        if output_mode == "coco_rle":
 251            mask_data["segmentations"] = [amg_utils.coco_encode_rle(rle) for rle in mask_data["rles"]]
 252        elif output_mode == "binary_mask":
 253            mask_data["segmentations"] = [amg_utils.rle_to_mask(rle) for rle in mask_data["rles"]]
 254        else:
 255            mask_data["segmentations"] = mask_data["rles"]
 256
 257        # write mask records
 258        curr_anns = []
 259        for idx in range(len(mask_data["segmentations"])):
 260            ann = {
 261                "segmentation": mask_data["segmentations"][idx],
 262                "area": amg_utils.area_from_rle(mask_data["rles"][idx]),
 263                "bbox": amg_utils.box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
 264                "predicted_iou": mask_data["iou_preds"][idx].item(),
 265                "stability_score": mask_data["stability_score"][idx].item(),
 266                "crop_box": amg_utils.box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
 267            }
 268            # the data from embedding based segmentation doesn't have the points
 269            # so we skip if the corresponding key can't be found
 270            try:
 271                ann["point_coords"] = [mask_data["points"][idx].tolist()]
 272            except KeyError:
 273                pass
 274            curr_anns.append(ann)
 275
 276        return curr_anns
 277
 278    def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None):
 279        orig_h, orig_w = original_size
 280
 281        # serialize predictions and store in MaskData
 282        data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1))
 283        if points is not None:
 284            data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0), dtype=torch.float)
 285
 286        del masks
 287
 288        # calculate the stability scores
 289        data["stability_score"] = amg_utils.calculate_stability_score(
 290            data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset
 291        )
 292
 293        # threshold masks and calculate boxes
 294        data["masks"] = data["masks"] > self._predictor.model.mask_threshold
 295        data["masks"] = data["masks"].type(torch.bool)
 296        data["boxes"] = batched_mask_to_box(data["masks"])
 297
 298        # compress to RLE
 299        data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
 300        # data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"])
 301        data["rles"] = mask_to_rle_pytorch(data["masks"])
 302        del data["masks"]
 303
 304        return data
 305
 306    def get_state(self) -> Dict[str, Any]:
 307        """Get the initialized state of the mask generator.
 308
 309        Returns:
 310            State of the mask generator.
 311        """
 312        if not self.is_initialized:
 313            raise RuntimeError("The state has not been computed yet. Call initialize first.")
 314
 315        return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}
 316
 317    def set_state(self, state: Dict[str, Any]) -> None:
 318        """Set the state of the mask generator.
 319
 320        Args:
 321            state: The state of the mask generator, e.g. from serialized state.
 322        """
 323        self._crop_list = state["crop_list"]
 324        self._crop_boxes = state["crop_boxes"]
 325        self._original_size = state["original_size"]
 326        self._is_initialized = True
 327
 328    def clear_state(self):
 329        """Clear the state of the mask generator.
 330        """
 331        self._crop_list = None
 332        self._crop_boxes = None
 333        self._original_size = None
 334        self._is_initialized = False
 335
 336
 337class AutomaticMaskGenerator(AMGBase):
 338    """Generates an instance segmentation without prompts, using a point grid.
 339
 340    This class implements the same logic as
 341    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
 342    It decouples the computationally expensive step of generating masks from the cheap post-processing
 343    that filters these masks, which enables grid search and interactively changing the post-processing.
 344
 345    Use this class as follows:
 346    ```python
 347    amg = AutomaticMaskGenerator(predictor)
 348    amg.initialize(image)  # Initialize the masks, this takes care of all expensive computations.
 349    masks = amg.generate(pred_iou_thresh=0.8)  # Generate the masks. This is fast and enables testing parameters
 350    ```
 351
 352    Args:
 353        predictor: The segment anything predictor.
 354        points_per_side: The number of points to be sampled along one side of the image.
 355            If None, `point_grids` must provide explicit point sampling.
 356        points_per_batch: The number of points run simultaneously by the model.
 357            Higher numbers may be faster but use more GPU memory.
 358        crop_n_layers: If >0, the mask prediction will be run again on crops of the image.
 359        crop_overlap_ratio: Sets the degree to which crops overlap.
 360        crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops.
 361        point_grids: A list of explicit grids of points used for sampling masks.
 362            Normalized to [0, 1] with respect to the image coordinate system.
 363        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
 364    """
 365    def __init__(
 366        self,
 367        predictor: SamPredictor,
 368        points_per_side: Optional[int] = 32,
 369        points_per_batch: Optional[int] = None,
 370        crop_n_layers: int = 0,
 371        crop_overlap_ratio: float = 512 / 1500,
 372        crop_n_points_downscale_factor: int = 1,
 373        point_grids: Optional[List[np.ndarray]] = None,
 374        stability_score_offset: float = 1.0,
 375    ):
 376        super().__init__()
 377
 378        if points_per_side is not None:
 379            self.point_grids = amg_utils.build_all_layer_point_grids(
 380                points_per_side,
 381                crop_n_layers,
 382                crop_n_points_downscale_factor,
 383            )
 384        elif point_grids is not None:
 385            self.point_grids = point_grids
 386        else:
 387            raise ValueError("Either points_per_side or point_grids has to be passed.")
 388
 389        self._predictor = predictor
 390        self._points_per_side = points_per_side
 391
 392        # we set the points per batch to 16 for mps for performance reasons
 393        # and otherwise keep them at the default of 64
 394        if points_per_batch is None:
 395            points_per_batch = 16 if str(predictor.device) == "mps" else 64
 396        self._points_per_batch = points_per_batch
 397
 398        self._crop_n_layers = crop_n_layers
 399        self._crop_overlap_ratio = crop_overlap_ratio
 400        self._crop_n_points_downscale_factor = crop_n_points_downscale_factor
 401        self._stability_score_offset = stability_score_offset
 402
 403    def _process_batch(self, points, im_size, crop_box, original_size):
 404        # run model on this batch
 405        transformed_points = self._predictor.transform.apply_coords(points, im_size)
 406        in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float)
 407        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
 408        masks, iou_preds, _ = self._predictor.predict_torch(
 409            in_points[:, None, :],
 410            in_labels[:, None],
 411            multimask_output=True,
 412            return_logits=True,
 413        )
 414        data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points)
 415        del masks
 416        return data
 417
 418    def _process_crop(self, image, crop_box, crop_layer_idx, precomputed_embeddings, pbar_init=None, pbar_update=None):
 419        # Crop the image and calculate embeddings.
 420        x0, y0, x1, y1 = crop_box
 421        cropped_im = image[y0:y1, x0:x1, :]
 422        cropped_im_size = cropped_im.shape[:2]
 423
 424        if not precomputed_embeddings:
 425            self._predictor.set_image(cropped_im)
 426
 427        # Get the points for this crop.
 428        points_scale = np.array(cropped_im_size)[None, ::-1]
 429        points_for_image = self.point_grids[crop_layer_idx] * points_scale
 430
 431        # Generate masks for this crop in batches.
 432        data = amg_utils.MaskData()
 433        n_batches = len(points_for_image) // self._points_per_batch +\
 434            int(len(points_for_image) % self._points_per_batch != 0)
 435        if pbar_init is not None:
 436            pbar_init(n_batches, "Predict masks for point grid prompts")
 437
 438        for (points,) in amg_utils.batch_iterator(self._points_per_batch, points_for_image):
 439            batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size)
 440            data.cat(batch_data)
 441            del batch_data
 442            if pbar_update is not None:
 443                pbar_update(1)
 444
 445        if not precomputed_embeddings:
 446            self._predictor.reset_image()
 447
 448        return data
 449
 450    @torch.no_grad()
 451    def initialize(
 452        self,
 453        image: np.ndarray,
 454        image_embeddings: Optional[util.ImageEmbeddings] = None,
 455        i: Optional[int] = None,
 456        verbose: bool = False,
 457        pbar_init: Optional[callable] = None,
 458        pbar_update: Optional[callable] = None,
 459    ) -> None:
 460        """Initialize image embeddings and masks for an image.
 461
 462        Args:
 463            image: The input image, volume or timeseries.
 464            image_embeddings: Optional precomputed image embeddings.
 465                See `util.precompute_image_embeddings` for details.
 466            i: Index for the image data. Required if `image` has three spatial dimensions
 467                or a time dimension and two spatial dimensions.
 468            verbose: Whether to print computation progress.
 469            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 470                Can be used together with pbar_update to handle the napari progress bar in another thread.
 471                This enables using this function within a threadworker.
 472            pbar_update: Callback to update an external progress bar.
 473        """
 474        original_size = image.shape[:2]
 475        self._original_size = original_size
 476
 477        crop_boxes, layer_idxs = amg_utils.generate_crop_boxes(
 478            original_size, self._crop_n_layers, self._crop_overlap_ratio
 479        )
 480
 481        # We can set fixed image embeddings if we only have a single crop box (the default setting).
 482        # Otherwise we have to recompute the embeddings for each crop and can't precompute.
 483        if len(crop_boxes) == 1:
 484            if image_embeddings is None:
 485                image_embeddings = util.precompute_image_embeddings(self._predictor, image)
 486            util.set_precomputed(self._predictor, image_embeddings, i=i)
 487            precomputed_embeddings = True
 488        else:
 489            precomputed_embeddings = False
 490
 491        # we need to cast to the image representation that is compatible with SAM
 492        image = util._to_image(image)
 493
 494        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
 495
 496        crop_list = []
 497        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
 498            crop_data = self._process_crop(
 499                image, crop_box, layer_idx,
 500                precomputed_embeddings=precomputed_embeddings,
 501                pbar_init=pbar_init, pbar_update=pbar_update,
 502            )
 503            crop_list.append(crop_data)
 504        pbar_close()
 505
 506        self._is_initialized = True
 507        self._crop_list = crop_list
 508        self._crop_boxes = crop_boxes
 509
 510    @torch.no_grad()
 511    def generate(
 512        self,
 513        pred_iou_thresh: float = 0.88,
 514        stability_score_thresh: float = 0.95,
 515        box_nms_thresh: float = 0.7,
 516        crop_nms_thresh: float = 0.7,
 517        min_mask_region_area: int = 0,
 518        output_mode: str = "binary_mask",
 519    ) -> List[Dict[str, Any]]:
 520        """Generate instance segmentation for the currently initialized image.
 521
 522        Args:
 523            pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
 524            stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
 525                under changes to the cutoff used to binarize the model prediction.
 526            box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
 527            crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
 528            min_mask_region_area: Minimal size for the predicted masks.
 529            output_mode: The form masks are returned in.
 530
 531        Returns:
 532            The instance segmentation masks.
 533        """
 534        if not self.is_initialized:
 535            raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.")
 536
 537        data = amg_utils.MaskData()
 538        for data_, crop_box in zip(self.crop_list, self.crop_boxes):
 539            crop_data = self._postprocess_batch(
 540                data=deepcopy(data_),
 541                crop_box=crop_box, original_size=self.original_size,
 542                pred_iou_thresh=pred_iou_thresh,
 543                stability_score_thresh=stability_score_thresh,
 544                box_nms_thresh=box_nms_thresh
 545            )
 546            data.cat(crop_data)
 547
 548        if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0:
 549            # Prefer masks from smaller crops
 550            scores = 1 / box_area(data["crop_boxes"])
 551            scores = scores.to(data["boxes"].device)
 552            keep_by_nms = batched_nms(
 553                data["boxes"].float(),
 554                scores,
 555                torch.zeros_like(data["boxes"][:, 0]),  # categories
 556                iou_threshold=crop_nms_thresh,
 557            )
 558            data.filter(keep_by_nms)
 559
 560        data.to_numpy()
 561        masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode)
 562        return masks
 563
 564
 565# Helper function for tiled embedding computation and checking consistent state.
 566def _process_tiled_embeddings(predictor, image, image_embeddings, tile_shape, halo):
 567    if image_embeddings is None:
 568        if tile_shape is None or halo is None:
 569            raise ValueError("To compute tiled embeddings the parameters tile_shape and halo have to be passed.")
 570        image_embeddings = util.precompute_image_embeddings(predictor, image, tile_shape=tile_shape, halo=halo)
 571
 572    # Use tile shape and halo from the precomputed embeddings if not given.
 573    # Otherwise check that they are consistent.
 574    feats = image_embeddings["features"]
 575    tile_shape_, halo_ = feats.attrs["tile_shape"], feats.attrs["halo"]
 576    if tile_shape is None:
 577        tile_shape = tile_shape_
 578    elif tile_shape != tile_shape_:
 579        raise ValueError(
 580            f"Inconsistent tile_shape parameter {tile_shape} with precomputed embeddings: {tile_shape_}."
 581        )
 582    if halo is None:
 583        halo = halo_
 584    elif halo != halo_:
 585        raise ValueError(f"Inconsistent halo parameter {halo} with precomputed embeddings: {halo_}.")
 586
 587    return image_embeddings, tile_shape, halo
 588
 589
 590class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
 591    """Generates an instance segmentation without prompts, using a point grid.
 592
 593    Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings.
 594
 595    Args:
 596        predictor: The segment anything predictor.
 597        points_per_side: The number of points to be sampled along one side of the image.
 598            If None, `point_grids` must provide explicit point sampling.
 599        points_per_batch: The number of points run simultaneously by the model.
 600            Higher numbers may be faster but use more GPU memory.
 601        point_grids: A list of explicit grids of points used for sampling masks.
 602            Normalized to [0, 1] with respect to the image coordinate system.
 603        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
 604    """
 605
 606    # We only expose the arguments that make sense for the tiled mask generator.
 607    # Anything related to crops doesn't make sense, because we re-use that functionality
 608    # for tiling, so these parameters wouldn't have any effect.
 609    def __init__(
 610        self,
 611        predictor: SamPredictor,
 612        points_per_side: Optional[int] = 32,
 613        points_per_batch: int = 64,
 614        point_grids: Optional[List[np.ndarray]] = None,
 615        stability_score_offset: float = 1.0,
 616    ) -> None:
 617        super().__init__(
 618            predictor=predictor,
 619            points_per_side=points_per_side,
 620            points_per_batch=points_per_batch,
 621            point_grids=point_grids,
 622            stability_score_offset=stability_score_offset,
 623        )
 624
 625    @torch.no_grad()
 626    def initialize(
 627        self,
 628        image: np.ndarray,
 629        image_embeddings: Optional[util.ImageEmbeddings] = None,
 630        i: Optional[int] = None,
 631        tile_shape: Optional[Tuple[int, int]] = None,
 632        halo: Optional[Tuple[int, int]] = None,
 633        verbose: bool = False,
 634        pbar_init: Optional[callable] = None,
 635        pbar_update: Optional[callable] = None,
 636    ) -> None:
 637        """Initialize image embeddings and masks for an image.
 638
 639        Args:
 640            image: The input image, volume or timeseries.
 641            image_embeddings: Optional precomputed image embeddings.
 642                See `util.precompute_image_embeddings` for details.
 643            i: Index for the image data. Required if `image` has three spatial dimensions
 644                or a time dimension and two spatial dimensions.
 645            tile_shape: The tile shape for embedding prediction.
 646            halo: The overlap between tiles.
 647            verbose: Whether to print computation progress.
 648            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 649                Can be used together with pbar_update to handle the napari progress bar in another thread.
 650                This enables using this function within a threadworker.
 651            pbar_update: Callback to update an external progress bar.
 652        """
 653        original_size = image.shape[:2]
 654        self._original_size = original_size
 655
 656        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
 657            self._predictor, image, image_embeddings, tile_shape, halo
 658        )
 659
 660        tiling = blocking([0, 0], original_size, tile_shape)
 661        n_tiles = tiling.numberOfBlocks
 662
 663        # The crop box is always the full local tile.
 664        tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)]
 665        crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles]
 666
 667        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
 668        pbar_init(n_tiles, "Compute masks for tile")
 669
 670        # We need to cast to the image representation that is compatible with SAM.
 671        image = util._to_image(image)
 672
 673        mask_data = []
 674        for tile_id in range(n_tiles):
 675            # set the pre-computed embeddings for this tile
 676            features = image_embeddings["features"][tile_id]
 677            tile_embeddings = {
 678                "features": features,
 679                "input_size": features.attrs["input_size"],
 680                "original_size": features.attrs["original_size"],
 681            }
 682            util.set_precomputed(self._predictor, tile_embeddings, i)
 683
 684            # compute the mask data for this tile and append it
 685            this_mask_data = self._process_crop(
 686                image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True
 687            )
 688            mask_data.append(this_mask_data)
 689            pbar_update(1)
 690        pbar_close()
 691
 692        # set the initialized data
 693        self._is_initialized = True
 694        self._crop_list = mask_data
 695        self._crop_boxes = crop_boxes
 696
 697
 698#
 699# Instance segmentation functionality based on fine-tuned decoder
 700#
 701
 702
 703class DecoderAdapter(torch.nn.Module):
 704    """Adapter to contain the UNETR decoder in a single module.
 705
 706    This allows applying the decoder on top of pre-computed embeddings
 707    for the segmentation functionality.
 708    See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py
 709    """
 710    def __init__(self, unetr):
 711        super().__init__()
 712
 713        self.base = unetr.base
 714        self.out_conv = unetr.out_conv
 715        self.deconv_out = unetr.deconv_out
 716        self.decoder_head = unetr.decoder_head
 717        self.final_activation = unetr.final_activation
 718        self.postprocess_masks = unetr.postprocess_masks
 719
 720        self.decoder = unetr.decoder
 721        self.deconv1 = unetr.deconv1
 722        self.deconv2 = unetr.deconv2
 723        self.deconv3 = unetr.deconv3
 724        self.deconv4 = unetr.deconv4
 725
 726    def forward(self, input_, input_shape, original_shape):
 727        z12 = input_
 728
 729        z9 = self.deconv1(z12)
 730        z6 = self.deconv2(z9)
 731        z3 = self.deconv3(z6)
 732        z0 = self.deconv4(z3)
 733
 734        updated_from_encoder = [z9, z6, z3]
 735
 736        x = self.base(z12)
 737        x = self.decoder(x, encoder_inputs=updated_from_encoder)
 738        x = self.deconv_out(x)
 739
 740        x = torch.cat([x, z0], dim=1)
 741        x = self.decoder_head(x)
 742
 743        x = self.out_conv(x)
 744        if self.final_activation is not None:
 745            x = self.final_activation(x)
 746
 747        x = self.postprocess_masks(x, input_shape, original_shape)
 748        return x
 749
 750
 751def get_unetr(
 752    image_encoder: torch.nn.Module,
 753    decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None,
 754    device: Optional[Union[str, torch.device]] = None,
 755) -> torch.nn.Module:
 756    """Get UNETR model for automatic instance segmentation.
 757
 758    Args:
 759        image_encoder: The image encoder of the SAM model.
 760            This is used as encoder by the UNETR too.
 761        decoder_state: Optional decoder state to initialize the weights
 762            of the UNETR decoder.
 763        device: The device.
 764    Returns:
 765        The UNETR model.
 766    """
 767    device = util.get_device(device)
 768
 769    unetr = UNETR(
 770        backbone="sam",
 771        encoder=image_encoder,
 772        out_channels=3,
 773        use_sam_stats=True,
 774        final_activation="Sigmoid",
 775        use_skip_connection=False,
 776        resize_input=True,
 777    )
 778    if decoder_state is not None:
 779        unetr_state_dict = unetr.state_dict()
 780        for k, v in unetr_state_dict.items():
 781            if not k.startswith("encoder"):
 782                unetr_state_dict[k] = decoder_state[k]
 783        unetr.load_state_dict(unetr_state_dict)
 784
 785    unetr.to(device)
 786    return unetr
 787
 788
 789def get_decoder(
 790    image_encoder: torch.nn.Module,
 791    decoder_state: OrderedDict[str, torch.Tensor],
 792    device: Optional[Union[str, torch.device]] = None,
 793) -> DecoderAdapter:
 794    """Get decoder to predict outputs for automatic instance segmentation
 795
 796    Args:
 797        image_encoder: The image encoder of the SAM model.
 798        decoder_state: State to initialize the weights of the UNETR decoder.
 799        device: The device.
 800    Returns:
 801        The decoder for instance segmentation.
 802    """
 803    unetr = get_unetr(image_encoder, decoder_state, device)
 804    return DecoderAdapter(unetr)
 805
 806
 807def get_predictor_and_decoder(
 808    model_type: str,
 809    checkpoint_path: Union[str, os.PathLike],
 810    device: Optional[Union[str, torch.device]] = None,
 811    peft_kwargs: Optional[Dict] = None,
 812) -> Tuple[SamPredictor, DecoderAdapter]:
 813    """Load the SAM model (predictor) and instance segmentation decoder.
 814
 815    This requires a checkpoint that contains the state for both predictor
 816    and decoder.
 817
 818    Args:
 819        model_type: The type of the image encoder used in the SAM model.
 820        checkpoint_path: Path to the checkpoint from which to load the data.
 821        device: The device.
 822        peft_kwargs: Keyword arguments for parameter efficient fine-tuning (PEFT).
 823
 824    Returns:
 825        The SAM predictor.
 826        The decoder for instance segmentation.
 827    """
 828    device = util.get_device(device)
 829    predictor, state = util.get_sam_model(
 830        model_type=model_type,
 831        checkpoint_path=checkpoint_path,
 832        device=device,
 833        return_state=True,
 834        peft_kwargs=peft_kwargs,
 835    )
 836    if "decoder_state" not in state:
 837        raise ValueError(
 838            f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state"
 839        )
 840    decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device)
 841    return predictor, decoder
 842
 843
 844class InstanceSegmentationWithDecoder:
 845    """Generates an instance segmentation without prompts, using a decoder.
 846
 847    Implements the same interface as `AutomaticMaskGenerator`.
 848
 849    Use this class as follows:
 850    ```python
 851    segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
 852    segmenter.initialize(image)   # Predict the image embeddings and decoder outputs.
 853    masks = segmenter.generate(center_distance_threshold=0.75)  # Generate the instance segmentation.
 854    ```
 855
 856    Args:
 857        predictor: The segment anything predictor.
 858        decoder: The decoder to predict intermediate representations
 859            for instance segmentation.
 860    """
 861    def __init__(
 862        self,
 863        predictor: SamPredictor,
 864        decoder: torch.nn.Module,
 865    ) -> None:
 866        self._predictor = predictor
 867        self._decoder = decoder
 868
 869        # The decoder outputs.
 870        self._foreground = None
 871        self._center_distances = None
 872        self._boundary_distances = None
 873
 874        self._is_initialized = False
 875
 876    @property
 877    def is_initialized(self):
 878        """Whether the mask generator has already been initialized.
 879        """
 880        return self._is_initialized
 881
 882    @torch.no_grad()
 883    def initialize(
 884        self,
 885        image: np.ndarray,
 886        image_embeddings: Optional[util.ImageEmbeddings] = None,
 887        i: Optional[int] = None,
 888        verbose: bool = False,
 889        pbar_init: Optional[callable] = None,
 890        pbar_update: Optional[callable] = None,
 891    ) -> None:
 892        """Initialize image embeddings and decoder predictions for an image.
 893
 894        Args:
 895            image: The input image, volume or timeseries.
 896            image_embeddings: Optional precomputed image embeddings.
 897                See `util.precompute_image_embeddings` for details.
 898            i: Index for the image data. Required if `image` has three spatial dimensions
 899                or a time dimension and two spatial dimensions.
 900            verbose: Whether to be verbose.
 901            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 902                Can be used together with pbar_update to handle the napari progress bar in another thread.
 903                This enables using this function within a threadworker.
 904            pbar_update: Callback to update an external progress bar.
 905        """
 906        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
 907        pbar_init(1, "Initialize instance segmentation with decoder")
 908
 909        if image_embeddings is None:
 910            image_embeddings = util.precompute_image_embeddings(self._predictor, image)
 911
 912        # Get the image embeddings from the predictor.
 913        self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i)
 914        embeddings = self._predictor.features
 915        input_shape = tuple(self._predictor.input_size)
 916        original_shape = tuple(self._predictor.original_size)
 917
 918        # Run prediction with the UNETR decoder.
 919        output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
 920        assert output.shape[0] == 3, f"{output.shape}"
 921        pbar_update(1)
 922        pbar_close()
 923
 924        # Set the state.
 925        self._foreground = output[0]
 926        self._center_distances = output[1]
 927        self._boundary_distances = output[2]
 928        self._is_initialized = True
 929
 930    def _to_masks(self, segmentation, output_mode):
 931        if output_mode != "binary_mask":
 932            raise NotImplementedError
 933
 934        props = regionprops(segmentation)
 935        ndim = segmentation.ndim
 936        assert ndim in (2, 3)
 937
 938        shape = segmentation.shape
 939        if ndim == 2:
 940            crop_box = [0, shape[1], 0, shape[0]]
 941        else:
 942            crop_box = [0, shape[2], 0, shape[1], 0, shape[0]]
 943
 944        # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h]
 945        def to_bbox_2d(bbox):
 946            y0, x0 = bbox[0], bbox[1]
 947            w = bbox[3] - x0
 948            h = bbox[2] - y0
 949            return [x0, w, y0, h]
 950
 951        def to_bbox_3d(bbox):
 952            z0, y0, x0 = bbox[0], bbox[1], bbox[2]
 953            w = bbox[5] - x0
 954            h = bbox[4] - y0
 955            d = bbox[3] - z0
 956            return [x0, w, y0, h, z0, d]
 957
 958        to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d
 959        masks = [
 960            {
 961                "segmentation": segmentation == prop.label,
 962                "area": prop.area,
 963                "bbox": to_bbox(prop.bbox),
 964                "crop_box": crop_box,
 965                "seg_id": prop.label,
 966            } for prop in props
 967        ]
 968        return masks
 969
 970    def generate(
 971        self,
 972        center_distance_threshold: float = 0.5,
 973        boundary_distance_threshold: float = 0.5,
 974        foreground_threshold: float = 0.5,
 975        foreground_smoothing: float = 1.0,
 976        distance_smoothing: float = 1.6,
 977        min_size: int = 0,
 978        output_mode: Optional[str] = "binary_mask",
 979    ) -> List[Dict[str, Any]]:
 980        """Generate instance segmentation for the currently initialized image.
 981
 982        Args:
 983            center_distance_threshold: Center distance predictions below this value will be
 984                used to find seeds (intersected with thresholded boundary distance predictions).
 985            boundary_distance_threshold: Boundary distance predictions below this value will be
 986                used to find seeds (intersected with thresholded center distance predictions).
 987            foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid
 988                checkerboard artifacts in the prediction.
 989            foreground_threshold: Foreground predictions above this value will be used as foreground mask.
 990            distance_smoothing: Sigma value for smoothing the distance predictions.
 991            min_size: Minimal object size in the segmentation result.
 992            output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
 993
 994        Returns:
 995            The instance segmentation masks.
 996        """
 997        if not self.is_initialized:
 998            raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.")
 999
1000        if foreground_smoothing > 0:
1001            foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing)
1002        else:
1003            foreground = self._foreground
1004        # Further optimization: parallel implementation using elf.parallel functionality.
1005        # (Make sure to expose n_threads to avoid over-subscription in case of outer parallelization)
1006        segmentation = watershed_from_center_and_boundary_distances(
1007            self._center_distances, self._boundary_distances, foreground,
1008            center_distance_threshold=center_distance_threshold,
1009            boundary_distance_threshold=boundary_distance_threshold,
1010            foreground_threshold=foreground_threshold,
1011            distance_smoothing=distance_smoothing,
1012            min_size=min_size,
1013        )
1014        if output_mode is not None:
1015            segmentation = self._to_masks(segmentation, output_mode)
1016        return segmentation
1017
1018    def get_state(self) -> Dict[str, Any]:
1019        """Get the initialized state of the instance segmenter.
1020
1021        Returns:
1022            Instance segmentation state.
1023        """
1024        if not self.is_initialized:
1025            raise RuntimeError("The state has not been computed yet. Call initialize first.")
1026
1027        return {
1028            "foreground": self._foreground,
1029            "center_distances": self._center_distances,
1030            "boundary_distances": self._boundary_distances,
1031        }
1032
1033    def set_state(self, state: Dict[str, Any]) -> None:
1034        """Set the state of the instance segmenter.
1035
1036        Args:
1037            state: The instance segmentation state
1038        """
1039        self._foreground = state["foreground"]
1040        self._center_distances = state["center_distances"]
1041        self._boundary_distances = state["boundary_distances"]
1042        self._is_initialized = True
1043
1044    def clear_state(self):
1045        """Clear the state of the instance segmenter.
1046        """
1047        self._foreground = None
1048        self._center_distances = None
1049        self._boundary_distances = None
1050        self._is_initialized = False
1051
1052
1053class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder):
1054    """Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings.
1055    """
1056
1057    @torch.no_grad()
1058    def initialize(
1059        self,
1060        image: np.ndarray,
1061        image_embeddings: Optional[util.ImageEmbeddings] = None,
1062        i: Optional[int] = None,
1063        tile_shape: Optional[Tuple[int, int]] = None,
1064        halo: Optional[Tuple[int, int]] = None,
1065        verbose: bool = False,
1066        pbar_init: Optional[callable] = None,
1067        pbar_update: Optional[callable] = None,
1068    ) -> None:
1069        """Initialize image embeddings and decoder predictions for an image.
1070
1071        Args:
1072            image: The input image, volume or timeseries.
1073            image_embeddings: Optional precomputed image embeddings.
1074                See `util.precompute_image_embeddings` for details.
1075            i: Index for the image data. Required if `image` has three spatial dimensions
1076                or a time dimension and two spatial dimensions.
1077            verbose: Whether to print computation progress.
1078            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
1079                Can be used together with pbar_update to handle the napari progress bar in another thread.
1080                This enables using this function within a threadworker.
1081            pbar_update: Callback to update an external progress bar.
1082        """
1083        original_size = image.shape[:2]
1084        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
1085            self._predictor, image, image_embeddings, tile_shape, halo
1086        )
1087        tiling = blocking([0, 0], original_size, tile_shape)
1088
1089        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
1090        pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder")
1091
1092        foreground = np.zeros(original_size, dtype="float32")
1093        center_distances = np.zeros(original_size, dtype="float32")
1094        boundary_distances = np.zeros(original_size, dtype="float32")
1095
1096        for tile_id in range(tiling.numberOfBlocks):
1097
1098            # Get the image embeddings from the predictor for this tile.
1099            self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id)
1100            embeddings = self._predictor.features
1101            input_shape = tuple(self._predictor.input_size)
1102            original_shape = tuple(self._predictor.original_size)
1103
1104            # Predict with the UNETR decoder for this tile.
1105            output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
1106            assert output.shape[0] == 3, f"{output.shape}"
1107
1108            # Set the predictions in the output for this tile.
1109            block = tiling.getBlockWithHalo(tile_id, halo=list(halo))
1110            local_bb = tuple(
1111                slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end)
1112            )
1113            inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end))
1114
1115            foreground[inner_bb] = output[0][local_bb]
1116            center_distances[inner_bb] = output[1][local_bb]
1117            boundary_distances[inner_bb] = output[2][local_bb]
1118        pbar_close()
1119
1120        # Set the state.
1121        self._foreground = foreground
1122        self._center_distances = center_distances
1123        self._boundary_distances = boundary_distances
1124        self._is_initialized = True
1125
1126
1127def get_amg(
1128    predictor: SamPredictor,
1129    is_tiled: bool,
1130    decoder: Optional[torch.nn.Module] = None,
1131    **kwargs,
1132) -> Union[AMGBase, InstanceSegmentationWithDecoder]:
1133    """Get the automatic mask generator class.
1134
1135    Args:
1136        predictor: The segment anything predictor.
1137        is_tiled: Whether tiled embeddings are used.
1138        decoder: Decoder to predict the instance segmentation.
1139        kwargs: The keyword arguments for the amg class.
1140
1141    Returns:
1142        The automatic mask generator.
1143    """
1144    if decoder is None:
1145        segmenter = TiledAutomaticMaskGenerator(predictor, **kwargs) if is_tiled else\
1146            AutomaticMaskGenerator(predictor, **kwargs)
1147    else:
1148        segmenter = TiledInstanceSegmentationWithDecoder(predictor, decoder, **kwargs) if is_tiled else\
1149            InstanceSegmentationWithDecoder(predictor, decoder, **kwargs)
1150    return segmenter
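
For the decoder-based segmentation, a similar sketch applies (the model type and checkpoint path are placeholders; the checkpoint must also contain a decoder state, see `get_predictor_and_decoder`):

```python
# Minimal sketch of decoder-based automatic instance segmentation via get_amg.
# The model type and checkpoint path are placeholders; the checkpoint must also
# store a decoder state in addition to the SAM weights.
import numpy as np

from micro_sam.instance_segmentation import get_predictor_and_decoder, get_amg, mask_data_to_segmentation

image = np.random.randint(0, 255, size=(512, 512), dtype="uint8")  # placeholder image

predictor, decoder = get_predictor_and_decoder(model_type="vit_b", checkpoint_path="/path/to/checkpoint.pt")
segmenter = get_amg(predictor, is_tiled=False, decoder=decoder)
segmenter.initialize(image)  # predicts the image embeddings and the decoder outputs
masks = segmenter.generate(center_distance_threshold=0.5, boundary_distance_threshold=0.5, min_size=50)
instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=50)
```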
def mask_data_to_segmentation( masks: List[Dict[str, Any]], with_background: bool, min_object_size: int = 0, max_object_size: Optional[int] = None, label_masks: bool = True) -> numpy.ndarray:
 47def mask_data_to_segmentation(
 48    masks: List[Dict[str, Any]],
 49    with_background: bool,
 50    min_object_size: int = 0,
 51    max_object_size: Optional[int] = None,
 52    label_masks: bool = True,
 53) -> np.ndarray:
 54    """Convert the output of the automatic mask generation to an instance segmentation.
 55
 56    Args:
 57        masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator.
 58            Only supports output_mode=binary_mask.
 59        with_background: Whether the segmentation has background. If yes, this function ensures that the largest
 60            object in the output will be mapped to zero (the background value).
 61        min_object_size: The minimal size of an object in pixels.
 62        max_object_size: The maximal size of an object in pixels.
 63        label_masks: Whether to apply connected components to the result before removing small objects.
 64
 68    Returns:
 69        The instance segmentation.
 70    """
 71
 72    masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
 73    # we could also get the shape from the crop box
 74    shape = next(iter(masks))["segmentation"].shape
 75    segmentation = np.zeros(shape, dtype="uint32")
 76
 77    def require_numpy(mask):
 78        return mask.cpu().numpy() if torch.is_tensor(mask) else mask
 79
 80    seg_id = 1
 81    for mask in masks:
 82        if mask["area"] < min_object_size:
 83            continue
 84        if max_object_size is not None and mask["area"] > max_object_size:
 85            continue
 86
 87        this_seg_id = mask.get("seg_id", seg_id)
 88        segmentation[require_numpy(mask["segmentation"])] = this_seg_id
 89        seg_id = this_seg_id + 1
 90
 91    if label_masks:
 92        segmentation = label(segmentation)
 93    seg_ids, sizes = np.unique(segmentation, return_counts=True)
 94
 95    # In some cases objects may be smaller than previously calculated,
 96    # since they are covered by other objects. We ensure these also get
 97    # filtered out here.
 98    filter_ids = seg_ids[sizes < min_object_size]
 99
100    # If we run segmentation with background we also map the largest segment
101    # (the most likely background object) to zero. This is often zero already,
102    # but it does not hurt to reset that to zero either.
103    if with_background:
104        bg_id = seg_ids[np.argmax(sizes)]
105        filter_ids = np.concatenate([filter_ids, [bg_id]])
106
107    segmentation[np.isin(segmentation, filter_ids)] = 0
108    segmentation = relabel_sequential(segmentation)[0]
109
110    return segmentation

Convert the output of the automatic mask generation to an instance segmentation.

Args:
    masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator.
        Only supports output_mode=binary_mask.
    with_background: Whether the segmentation has background. If yes, this function ensures that the largest
        object in the output will be mapped to zero (the background value).
    min_object_size: The minimal size of an object in pixels.
    max_object_size: The maximal size of an object in pixels.
    label_masks: Whether to apply connected components to the result before removing small objects.

    Returns:
        The instance segmentation.
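
A short usage sketch, assuming `amg` is an already initialized `AutomaticMaskGenerator` (see the classes documented below):

```python
# Convert the mask dictionaries produced by an initialized mask generator into a label image.
masks = amg.generate(output_mode="binary_mask")  # list of per-object dictionaries
seg = mask_data_to_segmentation(
    masks,
    with_background=True,  # map the largest object (most likely the background) to label 0
    min_object_size=50,    # filter out objects smaller than 50 pixels
    label_masks=True,      # apply connected components before the size filtering
)
```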
class AMGBase(abc.ABC):
118class AMGBase(ABC):
119    """Base class for the automatic mask generators.
120    """
121    def __init__(self):
122        # the state that has to be computed by the 'initialize' method of the child classes
123        self._is_initialized = False
124        self._crop_list = None
125        self._crop_boxes = None
126        self._original_size = None
127
128    @property
129    def is_initialized(self):
130        """Whether the mask generator has already been initialized.
131        """
132        return self._is_initialized
133
134    @property
135    def crop_list(self):
136        """The list of mask data after initialization.
137        """
138        return self._crop_list
139
140    @property
141    def crop_boxes(self):
142        """The list of crop boxes.
143        """
144        return self._crop_boxes
145
146    @property
147    def original_size(self):
148        """The original image size.
149        """
150        return self._original_size
151
152    def _postprocess_batch(
153        self,
154        data,
155        crop_box,
156        original_size,
157        pred_iou_thresh,
158        stability_score_thresh,
159        box_nms_thresh,
160    ):
161        orig_h, orig_w = original_size
162
163        # filter by predicted IoU
164        if pred_iou_thresh > 0.0:
165            keep_mask = data["iou_preds"] > pred_iou_thresh
166            data.filter(keep_mask)
167
168        # filter by stability score
169        if stability_score_thresh > 0.0:
170            keep_mask = data["stability_score"] >= stability_score_thresh
171            data.filter(keep_mask)
172
173        # filter boxes that touch crop boundaries
174        keep_mask = ~amg_utils.is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
175        if not torch.all(keep_mask):
176            data.filter(keep_mask)
177
178        # remove duplicates within this crop.
179        keep_by_nms = batched_nms(
180            data["boxes"].float(),
181            data["iou_preds"],
182            torch.zeros_like(data["boxes"][:, 0]),  # categories
183            iou_threshold=box_nms_thresh,
184        )
185        data.filter(keep_by_nms)
186
187        # return to the original image frame
188        data["boxes"] = amg_utils.uncrop_boxes_xyxy(data["boxes"], crop_box)
189        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
190        # the data from embedding based segmentation doesn't have the points
191        # so we skip if the corresponding key can't be found
192        try:
193            data["points"] = amg_utils.uncrop_points(data["points"], crop_box)
194        except KeyError:
195            pass
196
197        return data
198
199    def _postprocess_small_regions(self, mask_data, min_area, nms_thresh):
200
201        if len(mask_data["rles"]) == 0:
202            return mask_data
203
204        # filter small disconnected regions and holes
205        new_masks = []
206        scores = []
207        for rle in mask_data["rles"]:
208            mask = amg_utils.rle_to_mask(rle)
209
210            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="holes")
211            unchanged = not changed
212            mask, changed = amg_utils.remove_small_regions(mask, min_area, mode="islands")
213            unchanged = unchanged and not changed
214
215            new_masks.append(torch.as_tensor(mask, dtype=torch.int).unsqueeze(0))
216            # give score=0 to changed masks and score=1 to unchanged masks
217            # so NMS will prefer ones that didn't need postprocessing
218            scores.append(float(unchanged))
219
220        # recalculate boxes and remove any new duplicates
221        masks = torch.cat(new_masks, dim=0)
222        boxes = batched_mask_to_box(masks)
223        keep_by_nms = batched_nms(
224            boxes.float(),
225            torch.as_tensor(scores, dtype=torch.float),
226            torch.zeros_like(boxes[:, 0]),  # categories
227            iou_threshold=nms_thresh,
228        )
229
230        # only recalculate RLEs for masks that have changed
231        for i_mask in keep_by_nms:
232            if scores[i_mask] == 0.0:
233                mask_torch = masks[i_mask].unsqueeze(0)
234                # mask_data["rles"][i_mask] = amg_utils.mask_to_rle_pytorch(mask_torch)[0]
235                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
236                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
237        mask_data.filter(keep_by_nms)
238
239        return mask_data
240
241    def _postprocess_masks(self, mask_data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode):
242        # filter small disconnected regions and holes in masks
243        if min_mask_region_area > 0:
244            mask_data = self._postprocess_small_regions(
245                mask_data,
246                min_mask_region_area,
247                max(box_nms_thresh, crop_nms_thresh),
248            )
249
250        # encode masks
251        if output_mode == "coco_rle":
252            mask_data["segmentations"] = [amg_utils.coco_encode_rle(rle) for rle in mask_data["rles"]]
253        elif output_mode == "binary_mask":
254            mask_data["segmentations"] = [amg_utils.rle_to_mask(rle) for rle in mask_data["rles"]]
255        else:
256            mask_data["segmentations"] = mask_data["rles"]
257
258        # write mask records
259        curr_anns = []
260        for idx in range(len(mask_data["segmentations"])):
261            ann = {
262                "segmentation": mask_data["segmentations"][idx],
263                "area": amg_utils.area_from_rle(mask_data["rles"][idx]),
264                "bbox": amg_utils.box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
265                "predicted_iou": mask_data["iou_preds"][idx].item(),
266                "stability_score": mask_data["stability_score"][idx].item(),
267                "crop_box": amg_utils.box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
268            }
269            # the data from embedding based segmentation doesn't have the points
270            # so we skip if the corresponding key can't be found
271            try:
272                ann["point_coords"] = [mask_data["points"][idx].tolist()]
273            except KeyError:
274                pass
275            curr_anns.append(ann)
276
277        return curr_anns
278
279    def _to_mask_data(self, masks, iou_preds, crop_box, original_size, points=None):
280        orig_h, orig_w = original_size
281
282        # serialize predictions and store in MaskData
283        data = amg_utils.MaskData(masks=masks.flatten(0, 1), iou_preds=iou_preds.flatten(0, 1))
284        if points is not None:
285            data["points"] = torch.as_tensor(points.repeat(masks.shape[1], axis=0), dtype=torch.float)
286
287        del masks
288
289        # calculate the stability scores
290        data["stability_score"] = amg_utils.calculate_stability_score(
291            data["masks"], self._predictor.model.mask_threshold, self._stability_score_offset
292        )
293
294        # threshold masks and calculate boxes
295        data["masks"] = data["masks"] > self._predictor.model.mask_threshold
296        data["masks"] = data["masks"].type(torch.bool)
297        data["boxes"] = batched_mask_to_box(data["masks"])
298
299        # compress to RLE
300        data["masks"] = amg_utils.uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
301        # data["rles"] = amg_utils.mask_to_rle_pytorch(data["masks"])
302        data["rles"] = mask_to_rle_pytorch(data["masks"])
303        del data["masks"]
304
305        return data
306
307    def get_state(self) -> Dict[str, Any]:
308        """Get the initialized state of the mask generator.
309
310        Returns:
311            State of the mask generator.
312        """
313        if not self.is_initialized:
314            raise RuntimeError("The state has not been computed yet. Call initialize first.")
315
316        return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}
317
318    def set_state(self, state: Dict[str, Any]) -> None:
319        """Set the state of the mask generator.
320
321        Args:
322            state: The state of the mask generator, e.g. from serialized state.
323        """
324        self._crop_list = state["crop_list"]
325        self._crop_boxes = state["crop_boxes"]
326        self._original_size = state["original_size"]
327        self._is_initialized = True
328
329    def clear_state(self):
330        """Clear the state of the mask generator.
331        """
332        self._crop_list = None
333        self._crop_boxes = None
334        self._original_size = None
335        self._is_initialized = False

Base class for the automatic mask generators.

is_initialized
128    @property
129    def is_initialized(self):
130        """Whether the mask generator has already been initialized.
131        """
132        return self._is_initialized

Whether the mask generator has already been initialized.

crop_list
134    @property
135    def crop_list(self):
136        """The list of mask data after initialization.
137        """
138        return self._crop_list

The list of mask data after initialization.

crop_boxes
140    @property
141    def crop_boxes(self):
142        """The list of crop boxes.
143        """
144        return self._crop_boxes

The list of crop boxes.

original_size
146    @property
147    def original_size(self):
148        """The original image size.
149        """
150        return self._original_size

The original image size.

def get_state(self) -> Dict[str, Any]:
307    def get_state(self) -> Dict[str, Any]:
308        """Get the initialized state of the mask generator.
309
310        Returns:
311            State of the mask generator.
312        """
313        if not self.is_initialized:
314            raise RuntimeError("The state has not been computed yet. Call initialize first.")
315
316        return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}

Get the initialized state of the mask generator.

Returns:

State of the mask generator.

def set_state(self, state: Dict[str, Any]) -> None:
318    def set_state(self, state: Dict[str, Any]) -> None:
319        """Set the state of the mask generator.
320
321        Args:
322            state: The state of the mask generator, e.g. from serialized state.
323        """
324        self._crop_list = state["crop_list"]
325        self._crop_boxes = state["crop_boxes"]
326        self._original_size = state["original_size"]
327        self._is_initialized = True

Set the state of the mask generator.

Arguments:
  • state: The state of the mask generator, e.g. from serialized state.
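
A sketch of how the state accessors can be used to cache the expensive initialization, assuming the state (a list of mask data plus crop boxes) is picklable; the path and model type are placeholders:

```python
import pickle

import numpy as np
from micro_sam import util
from micro_sam.instance_segmentation import AutomaticMaskGenerator

predictor = util.get_sam_model(model_type="vit_b")  # example model type
image = np.random.randint(0, 255, size=(512, 512), dtype="uint8")  # placeholder image

amg = AutomaticMaskGenerator(predictor)
amg.initialize(image)  # the expensive part

# Cache the initialized state so that initialize does not have to be re-run.
with open("amg_state.pkl", "wb") as f:  # example cache path
    pickle.dump(amg.get_state(), f)

# Later: restore the state on a fresh generator and go straight to generate.
amg2 = AutomaticMaskGenerator(predictor)
with open("amg_state.pkl", "rb") as f:
    amg2.set_state(pickle.load(f))
masks = amg2.generate(pred_iou_thresh=0.88)
```
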
def clear_state(self):
329    def clear_state(self):
330        """Clear the state of the mask generator.
331        """
332        self._crop_list = None
333        self._crop_boxes = None
334        self._original_size = None
335        self._is_initialized = False

Clear the state of the mask generator.

class AutomaticMaskGenerator(AMGBase):
338class AutomaticMaskGenerator(AMGBase):
339    """Generates an instance segmentation without prompts, using a point grid.
340
341    This class implements the same logic as
342    https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
343    It decouples the computationally expensive steps of generating masks from the cheap post-processing operation
344    to filter these masks to enable grid search and interactively changing the post-processing.
345
346    Use this class as follows:
347    ```python
348    amg = AutomaticMaskGenerator(predictor)
349    amg.initialize(image)  # Initialize the masks, this takes care of all expensive computations.
350    masks = amg.generate(pred_iou_thresh=0.8)  # Generate the masks. This is fast and enables testing parameters
351    ```
352
353    Args:
354        predictor: The segment anything predictor.
355        points_per_side: The number of points to be sampled along one side of the image.
356            If None, `point_grids` must provide explicit point sampling.
357        points_per_batch: The number of points run simultaneously by the model.
358            Higher numbers may be faster but use more GPU memory.
359        crop_n_layers: If >0, the mask prediction will be run again on crops of the image.
360        crop_overlap_ratio: Sets the degree to which crops overlap.
361        crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops.
362        point_grids: A list of explicit grids of points used for sampling masks.
363            Normalized to [0, 1] with respect to the image coordinate system.
364        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
365    """
366    def __init__(
367        self,
368        predictor: SamPredictor,
369        points_per_side: Optional[int] = 32,
370        points_per_batch: Optional[int] = None,
371        crop_n_layers: int = 0,
372        crop_overlap_ratio: float = 512 / 1500,
373        crop_n_points_downscale_factor: int = 1,
374        point_grids: Optional[List[np.ndarray]] = None,
375        stability_score_offset: float = 1.0,
376    ):
377        super().__init__()
378
379        if points_per_side is not None:
380            self.point_grids = amg_utils.build_all_layer_point_grids(
381                points_per_side,
382                crop_n_layers,
383                crop_n_points_downscale_factor,
384            )
385        elif point_grids is not None:
386            self.point_grids = point_grids
387        else:
388            raise ValueError("Can't have both points_per_side and point_grid be None or not None.")
389
390        self._predictor = predictor
391        self._points_per_side = points_per_side
392
393        # we set the points per batch to 16 for mps for performance reasons
394        # and otherwise keep them at the default of 64
395        if points_per_batch is None:
396            points_per_batch = 16 if str(predictor.device) == "mps" else 64
397        self._points_per_batch = points_per_batch
398
399        self._crop_n_layers = crop_n_layers
400        self._crop_overlap_ratio = crop_overlap_ratio
401        self._crop_n_points_downscale_factor = crop_n_points_downscale_factor
402        self._stability_score_offset = stability_score_offset
403
404    def _process_batch(self, points, im_size, crop_box, original_size):
405        # run model on this batch
406        transformed_points = self._predictor.transform.apply_coords(points, im_size)
407        in_points = torch.as_tensor(transformed_points, device=self._predictor.device, dtype=torch.float)
408        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
409        masks, iou_preds, _ = self._predictor.predict_torch(
410            in_points[:, None, :],
411            in_labels[:, None],
412            multimask_output=True,
413            return_logits=True,
414        )
415        data = self._to_mask_data(masks, iou_preds, crop_box, original_size, points=points)
416        del masks
417        return data
418
419    def _process_crop(self, image, crop_box, crop_layer_idx, precomputed_embeddings, pbar_init=None, pbar_update=None):
420        # Crop the image and calculate embeddings.
421        x0, y0, x1, y1 = crop_box
422        cropped_im = image[y0:y1, x0:x1, :]
423        cropped_im_size = cropped_im.shape[:2]
424
425        if not precomputed_embeddings:
426            self._predictor.set_image(cropped_im)
427
428        # Get the points for this crop.
429        points_scale = np.array(cropped_im_size)[None, ::-1]
430        points_for_image = self.point_grids[crop_layer_idx] * points_scale
431
432        # Generate masks for this crop in batches.
433        data = amg_utils.MaskData()
434        n_batches = len(points_for_image) // self._points_per_batch +\
435            int(len(points_for_image) % self._points_per_batch != 0)
436        if pbar_init is not None:
437            pbar_init(n_batches, "Predict masks for point grid prompts")
438
439        for (points,) in amg_utils.batch_iterator(self._points_per_batch, points_for_image):
440            batch_data = self._process_batch(points, cropped_im_size, crop_box, self.original_size)
441            data.cat(batch_data)
442            del batch_data
443            if pbar_update is not None:
444                pbar_update(1)
445
446        if not precomputed_embeddings:
447            self._predictor.reset_image()
448
449        return data
450
451    @torch.no_grad()
452    def initialize(
453        self,
454        image: np.ndarray,
455        image_embeddings: Optional[util.ImageEmbeddings] = None,
456        i: Optional[int] = None,
457        verbose: bool = False,
458        pbar_init: Optional[callable] = None,
459        pbar_update: Optional[callable] = None,
460    ) -> None:
461        """Initialize image embeddings and masks for an image.
462
463        Args:
464            image: The input image, volume or timeseries.
465            image_embeddings: Optional precomputed image embeddings.
466                See `util.precompute_image_embeddings` for details.
467            i: Index for the image data. Required if `image` has three spatial dimensions
468                or a time dimension and two spatial dimensions.
469            verbose: Whether to print computation progress.
470            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
471                Can be used together with pbar_update to handle the napari progress bar in another thread.
472                This enables using this function within a threadworker.
473            pbar_update: Callback to update an external progress bar.
474        """
475        original_size = image.shape[:2]
476        self._original_size = original_size
477
478        crop_boxes, layer_idxs = amg_utils.generate_crop_boxes(
479            original_size, self._crop_n_layers, self._crop_overlap_ratio
480        )
481
482        # We can set fixed image embeddings if we only have a single crop box (the default setting).
483        # Otherwise we have to recompute the embeddings for each crop and can't precompute.
484        if len(crop_boxes) == 1:
485            if image_embeddings is None:
486                image_embeddings = util.precompute_image_embeddings(self._predictor, image)
487            util.set_precomputed(self._predictor, image_embeddings, i=i)
488            precomputed_embeddings = True
489        else:
490            precomputed_embeddings = False
491
492        # we need to cast to the image representation that is compatible with SAM
493        image = util._to_image(image)
494
495        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
496
497        crop_list = []
498        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
499            crop_data = self._process_crop(
500                image, crop_box, layer_idx,
501                precomputed_embeddings=precomputed_embeddings,
502                pbar_init=pbar_init, pbar_update=pbar_update,
503            )
504            crop_list.append(crop_data)
505        pbar_close()
506
507        self._is_initialized = True
508        self._crop_list = crop_list
509        self._crop_boxes = crop_boxes
510
511    @torch.no_grad()
512    def generate(
513        self,
514        pred_iou_thresh: float = 0.88,
515        stability_score_thresh: float = 0.95,
516        box_nms_thresh: float = 0.7,
517        crop_nms_thresh: float = 0.7,
518        min_mask_region_area: int = 0,
519        output_mode: str = "binary_mask",
520    ) -> List[Dict[str, Any]]:
521        """Generate instance segmentation for the currently initialized image.
522
523        Args:
524            pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
525            stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
526                under changes to the cutoff used to binarize the model prediction.
527            box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
528            crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
529            min_mask_region_area: Minimal size for the predicted masks.
530            output_mode: The form masks are returned in.
531
532        Returns:
533            The instance segmentation masks.
534        """
535        if not self.is_initialized:
536            raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.")
537
538        data = amg_utils.MaskData()
539        for data_, crop_box in zip(self.crop_list, self.crop_boxes):
540            crop_data = self._postprocess_batch(
541                data=deepcopy(data_),
542                crop_box=crop_box, original_size=self.original_size,
543                pred_iou_thresh=pred_iou_thresh,
544                stability_score_thresh=stability_score_thresh,
545                box_nms_thresh=box_nms_thresh
546            )
547            data.cat(crop_data)
548
549        if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0:
550            # Prefer masks from smaller crops
551            scores = 1 / box_area(data["crop_boxes"])
552            scores = scores.to(data["boxes"].device)
553            keep_by_nms = batched_nms(
554                data["boxes"].float(),
555                scores,
556                torch.zeros_like(data["boxes"][:, 0]),  # categories
557                iou_threshold=crop_nms_thresh,
558            )
559            data.filter(keep_by_nms)
560
561        data.to_numpy()
562        masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode)
563        return masks

Generates an instance segmentation without prompts, using a point grid.

This class implements the same logic as https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py It decouples the computationally expensive steps of generating masks from the cheap post-processing operation to filter these masks to enable grid search and interactively changing the post-processing.

Use this class as follows:

amg = AutomaticMaskGenerator(predictor)
amg.initialize(image)  # Initialize the masks, this takes care of all expensive computations.
masks = amg.generate(pred_iou_thresh=0.8)  # Generate the masks. This is fast and enables testing parameters
Arguments:
  • predictor: The segment anything predictor.
  • points_per_side: The number of points to be sampled along one side of the image. If None, point_grids must provide explicit point sampling.
  • points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory.
  • crop_n_layers: If >0, the mask prediction will be run again on crops of the image.
  • crop_overlap_ratio: Sets the degree to which crops overlap.
  • crop_n_points_downscale_factor: How the number of points is downsampled when predicting with crops.
  • point_grids: A list of explicit grids of points used for sampling masks. Normalized to [0, 1] with respect to the image coordinate system.
  • stability_score_offset: The amount to shift the cutoff when calculating the stability score.
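
A short sketch of how the constructor arguments are typically adjusted; the model type and the chosen values are examples only:

```python
from micro_sam import util
from micro_sam.instance_segmentation import AutomaticMaskGenerator

predictor = util.get_sam_model(model_type="vit_b")  # example model type

# A denser prompt grid with a smaller batch size to limit GPU memory usage.
amg = AutomaticMaskGenerator(
    predictor,
    points_per_side=64,        # 64 x 64 point grid instead of the default 32 x 32
    points_per_batch=32,       # fewer prompts per forward pass -> lower peak memory
    stability_score_offset=1.0,
)
```
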
AutomaticMaskGenerator( predictor: segment_anything.predictor.SamPredictor, points_per_side: Optional[int] = 32, points_per_batch: Optional[int] = None, crop_n_layers: int = 0, crop_overlap_ratio: float = 0.3413333333333333, crop_n_points_downscale_factor: int = 1, point_grids: Optional[List[numpy.ndarray]] = None, stability_score_offset: float = 1.0)
366    def __init__(
367        self,
368        predictor: SamPredictor,
369        points_per_side: Optional[int] = 32,
370        points_per_batch: Optional[int] = None,
371        crop_n_layers: int = 0,
372        crop_overlap_ratio: float = 512 / 1500,
373        crop_n_points_downscale_factor: int = 1,
374        point_grids: Optional[List[np.ndarray]] = None,
375        stability_score_offset: float = 1.0,
376    ):
377        super().__init__()
378
379        if points_per_side is not None:
380            self.point_grids = amg_utils.build_all_layer_point_grids(
381                points_per_side,
382                crop_n_layers,
383                crop_n_points_downscale_factor,
384            )
385        elif point_grids is not None:
386            self.point_grids = point_grids
387        else:
388            raise ValueError("Can't have both points_per_side and point_grid be None or not None.")
389
390        self._predictor = predictor
391        self._points_per_side = points_per_side
392
393        # we set the points per batch to 16 for mps for performance reasons
394        # and otherwise keep them at the default of 64
395        if points_per_batch is None:
396            points_per_batch = 16 if str(predictor.device) == "mps" else 64
397        self._points_per_batch = points_per_batch
398
399        self._crop_n_layers = crop_n_layers
400        self._crop_overlap_ratio = crop_overlap_ratio
401        self._crop_n_points_downscale_factor = crop_n_points_downscale_factor
402        self._stability_score_offset = stability_score_offset
@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, verbose: bool = False, pbar_init: Optional[<built-in function callable>] = None, pbar_update: Optional[<built-in function callable>] = None) -> None:
451    @torch.no_grad()
452    def initialize(
453        self,
454        image: np.ndarray,
455        image_embeddings: Optional[util.ImageEmbeddings] = None,
456        i: Optional[int] = None,
457        verbose: bool = False,
458        pbar_init: Optional[callable] = None,
459        pbar_update: Optional[callable] = None,
460    ) -> None:
461        """Initialize image embeddings and masks for an image.
462
463        Args:
464            image: The input image, volume or timeseries.
465            image_embeddings: Optional precomputed image embeddings.
466                See `util.precompute_image_embeddings` for details.
467            i: Index for the image data. Required if `image` has three spatial dimensions
468                or a time dimension and two spatial dimensions.
469            verbose: Whether to print computation progress.
470            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
471                Can be used together with pbar_update to handle the napari progress bar in another thread.
472                This enables using this function within a threadworker.
473            pbar_update: Callback to update an external progress bar.
474        """
475        original_size = image.shape[:2]
476        self._original_size = original_size
477
478        crop_boxes, layer_idxs = amg_utils.generate_crop_boxes(
479            original_size, self._crop_n_layers, self._crop_overlap_ratio
480        )
481
482        # We can set fixed image embeddings if we only have a single crop box (the default setting).
483        # Otherwise we have to recompute the embeddings for each crop and can't precompute.
484        if len(crop_boxes) == 1:
485            if image_embeddings is None:
486                image_embeddings = util.precompute_image_embeddings(self._predictor, image)
487            util.set_precomputed(self._predictor, image_embeddings, i=i)
488            precomputed_embeddings = True
489        else:
490            precomputed_embeddings = False
491
492        # we need to cast to the image representation that is compatible with SAM
493        image = util._to_image(image)
494
495        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
496
497        crop_list = []
498        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
499            crop_data = self._process_crop(
500                image, crop_box, layer_idx,
501                precomputed_embeddings=precomputed_embeddings,
502                pbar_init=pbar_init, pbar_update=pbar_update,
503            )
504            crop_list.append(crop_data)
505        pbar_close()
506
507        self._is_initialized = True
508        self._crop_list = crop_list
509        self._crop_boxes = crop_boxes

Initialize image embeddings and masks for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • verbose: Whether to print computation progress.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
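
A sketch of initializing with precomputed embeddings, which is useful when the same embeddings are shared with interactive segmentation; the image path and model type are placeholders:

```python
import imageio.v3 as imageio

from micro_sam import util
from micro_sam.instance_segmentation import AutomaticMaskGenerator

predictor = util.get_sam_model(model_type="vit_b")  # example model type
image = imageio.imread("example_image.tif")         # placeholder input

# Compute the embeddings once; they can be reused by other mask generators
# or by interactive segmentation.
image_embeddings = util.precompute_image_embeddings(predictor, image)

amg = AutomaticMaskGenerator(predictor)
amg.initialize(image, image_embeddings=image_embeddings, verbose=True)
masks = amg.generate()
```
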
@torch.no_grad()
def generate( self, pred_iou_thresh: float = 0.88, stability_score_thresh: float = 0.95, box_nms_thresh: float = 0.7, crop_nms_thresh: float = 0.7, min_mask_region_area: int = 0, output_mode: str = 'binary_mask') -> List[Dict[str, Any]]:
511    @torch.no_grad()
512    def generate(
513        self,
514        pred_iou_thresh: float = 0.88,
515        stability_score_thresh: float = 0.95,
516        box_nms_thresh: float = 0.7,
517        crop_nms_thresh: float = 0.7,
518        min_mask_region_area: int = 0,
519        output_mode: str = "binary_mask",
520    ) -> List[Dict[str, Any]]:
521        """Generate instance segmentation for the currently initialized image.
522
523        Args:
524            pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
525            stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask
526                under changes to the cutoff used to binarize the model prediction.
527            box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
528            crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
529            min_mask_region_area: Minimal size for the predicted masks.
530            output_mode: The form masks are returned in.
531
532        Returns:
533            The instance segmentation masks.
534        """
535        if not self.is_initialized:
536            raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.")
537
538        data = amg_utils.MaskData()
539        for data_, crop_box in zip(self.crop_list, self.crop_boxes):
540            crop_data = self._postprocess_batch(
541                data=deepcopy(data_),
542                crop_box=crop_box, original_size=self.original_size,
543                pred_iou_thresh=pred_iou_thresh,
544                stability_score_thresh=stability_score_thresh,
545                box_nms_thresh=box_nms_thresh
546            )
547            data.cat(crop_data)
548
549        if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0:
550            # Prefer masks from smaller crops
551            scores = 1 / box_area(data["crop_boxes"])
552            scores = scores.to(data["boxes"].device)
553            keep_by_nms = batched_nms(
554                data["boxes"].float(),
555                scores,
556                torch.zeros_like(data["boxes"][:, 0]),  # categories
557                iou_threshold=crop_nms_thresh,
558            )
559            data.filter(keep_by_nms)
560
561        data.to_numpy()
562        masks = self._postprocess_masks(data, min_mask_region_area, box_nms_thresh, crop_nms_thresh, output_mode)
563        return masks

Generate instance segmentation for the currently initialized image.

Arguments:
  • pred_iou_thresh: Filter threshold in [0, 1], using the mask quality predicted by the model.
  • stability_score_thresh: Filter threshold in [0, 1], using the stability of the mask under changes to the cutoff used to binarize the model prediction.
  • box_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks.
  • crop_nms_thresh: The IoU threshold used by nonmax suppression to filter duplicate masks between crops.
  • min_mask_region_area: Minimal size for the predicted masks.
  • output_mode: The form masks are returned in.
Returns:

The instance segmentation masks.
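
Because all expensive computation happens in initialize, generate can be called repeatedly to explore the post-processing parameters. A sketch of such a sweep (the threshold values are examples; amg is assumed to be an initialized AutomaticMaskGenerator):

```python
# Assumes `amg` was created and `amg.initialize(image)` was already called once.
# generate only runs the (cheap) post-processing, so different filter settings
# can be compared without recomputing any masks.
results = {}
for iou_thresh in (0.75, 0.80, 0.88):
    for stability_thresh in (0.90, 0.95):
        masks = amg.generate(
            pred_iou_thresh=iou_thresh,
            stability_score_thresh=stability_thresh,
        )
        results[(iou_thresh, stability_thresh)] = len(masks)
```
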

class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
591class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
592    """Generates an instance segmentation without prompts, using a point grid.
593
594    Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings.
595
596    Args:
597        predictor: The segment anything predictor.
598        points_per_side: The number of points to be sampled along one side of the image.
599            If None, `point_grids` must provide explicit point sampling.
600        points_per_batch: The number of points run simultaneously by the model.
601            Higher numbers may be faster but use more GPU memory.
602        point_grids: A list of explicit grids of points used for sampling masks.
603            Normalized to [0, 1] with respect to the image coordinate system.
604        stability_score_offset: The amount to shift the cutoff when calculating the stability score.
605    """
606
607    # We only expose the arguments that make sense for the tiled mask generator.
608    # Anything related to crops doesn't make sense, because we re-use that functionality
609    # for tiling, so these parameters wouldn't have any effect.
610    def __init__(
611        self,
612        predictor: SamPredictor,
613        points_per_side: Optional[int] = 32,
614        points_per_batch: int = 64,
615        point_grids: Optional[List[np.ndarray]] = None,
616        stability_score_offset: float = 1.0,
617    ) -> None:
618        super().__init__(
619            predictor=predictor,
620            points_per_side=points_per_side,
621            points_per_batch=points_per_batch,
622            point_grids=point_grids,
623            stability_score_offset=stability_score_offset,
624        )
625
626    @torch.no_grad()
627    def initialize(
628        self,
629        image: np.ndarray,
630        image_embeddings: Optional[util.ImageEmbeddings] = None,
631        i: Optional[int] = None,
632        tile_shape: Optional[Tuple[int, int]] = None,
633        halo: Optional[Tuple[int, int]] = None,
634        verbose: bool = False,
635        pbar_init: Optional[callable] = None,
636        pbar_update: Optional[callable] = None,
637    ) -> None:
638        """Initialize image embeddings and masks for an image.
639
640        Args:
641            image: The input image, volume or timeseries.
642            image_embeddings: Optional precomputed image embeddings.
643                See `util.precompute_image_embeddings` for details.
644            i: Index for the image data. Required if `image` has three spatial dimensions
645                or a time dimension and two spatial dimensions.
646            tile_shape: The tile shape for embedding prediction.
647            halo: The overlap between tiles.
648            verbose: Whether to print computation progress.
649            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
650                Can be used together with pbar_update to handle the napari progress bar in another thread.
651                This enables using this function within a threadworker.
652            pbar_update: Callback to update an external progress bar.
653        """
654        original_size = image.shape[:2]
655        self._original_size = original_size
656
657        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
658            self._predictor, image, image_embeddings, tile_shape, halo
659        )
660
661        tiling = blocking([0, 0], original_size, tile_shape)
662        n_tiles = tiling.numberOfBlocks
663
664        # The crop box is always the full local tile.
665        tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)]
666        crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles]
667
668        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
669        pbar_init(n_tiles, "Compute masks for tile")
670
671        # We need to cast to the image representation that is compatible with SAM.
672        image = util._to_image(image)
673
674        mask_data = []
675        for tile_id in range(n_tiles):
676            # set the pre-computed embeddings for this tile
677            features = image_embeddings["features"][tile_id]
678            tile_embeddings = {
679                "features": features,
680                "input_size": features.attrs["input_size"],
681                "original_size": features.attrs["original_size"],
682            }
683            util.set_precomputed(self._predictor, tile_embeddings, i)
684
685            # compute the mask data for this tile and append it
686            this_mask_data = self._process_crop(
687                image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True
688            )
689            mask_data.append(this_mask_data)
690            pbar_update(1)
691        pbar_close()
692
693        # set the initialized data
694        self._is_initialized = True
695        self._crop_list = mask_data
696        self._crop_boxes = crop_boxes

Generates an instance segmentation without prompts, using a point grid.

Implements the same functionality as AutomaticMaskGenerator but for tiled embeddings.

Arguments:
  • predictor: The segment anything predictor.
  • points_per_side: The number of points to be sampled along one side of the image. If None, point_grids must provide explicit point sampling.
  • points_per_batch: The number of points run simultaneously by the model. Higher numbers may be faster but use more GPU memory.
  • point_grids: A list of explicit grids of points used for sampling masks. Normalized to [0, 1] with respect to the image coordinate system.
  • stability_score_offset: The amount to shift the cutoff when calculating the stability score.
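
A sketch of the tiled workflow; tile_shape and halo are passed to initialize, and the path, model type and parameter values are placeholders:

```python
import imageio.v3 as imageio

from micro_sam import util
from micro_sam.instance_segmentation import TiledAutomaticMaskGenerator, mask_data_to_segmentation

predictor = util.get_sam_model(model_type="vit_b")   # example model type
large_image = imageio.imread("large_image.tif")      # placeholder for a large input

amg = TiledAutomaticMaskGenerator(predictor)
# The embeddings are computed per tile; halo is the overlap between neighboring tiles.
amg.initialize(large_image, tile_shape=(1024, 1024), halo=(128, 128), verbose=True)
masks = amg.generate(pred_iou_thresh=0.88)
segmentation = mask_data_to_segmentation(masks, with_background=True, min_object_size=50)
```
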
TiledAutomaticMaskGenerator( predictor: segment_anything.predictor.SamPredictor, points_per_side: Optional[int] = 32, points_per_batch: int = 64, point_grids: Optional[List[numpy.ndarray]] = None, stability_score_offset: float = 1.0)
610    def __init__(
611        self,
612        predictor: SamPredictor,
613        points_per_side: Optional[int] = 32,
614        points_per_batch: int = 64,
615        point_grids: Optional[List[np.ndarray]] = None,
616        stability_score_offset: float = 1.0,
617    ) -> None:
618        super().__init__(
619            predictor=predictor,
620            points_per_side=points_per_side,
621            points_per_batch=points_per_batch,
622            point_grids=point_grids,
623            stability_score_offset=stability_score_offset,
624        )
@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = False, pbar_init: Optional[<built-in function callable>] = None, pbar_update: Optional[<built-in function callable>] = None) -> None:
626    @torch.no_grad()
627    def initialize(
628        self,
629        image: np.ndarray,
630        image_embeddings: Optional[util.ImageEmbeddings] = None,
631        i: Optional[int] = None,
632        tile_shape: Optional[Tuple[int, int]] = None,
633        halo: Optional[Tuple[int, int]] = None,
634        verbose: bool = False,
635        pbar_init: Optional[callable] = None,
636        pbar_update: Optional[callable] = None,
637    ) -> None:
638        """Initialize image embeddings and masks for an image.
639
640        Args:
641            image: The input image, volume or timeseries.
642            image_embeddings: Optional precomputed image embeddings.
643                See `util.precompute_image_embeddings` for details.
644            i: Index for the image data. Required if `image` has three spatial dimensions
645                or a time dimension and two spatial dimensions.
646            tile_shape: The tile shape for embedding prediction.
647            halo: The overlap between tiles.
648            verbose: Whether to print computation progress.
649            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
650                Can be used together with pbar_update to handle the napari progress bar in another thread.
651                This enables using this function within a threadworker.
652            pbar_update: Callback to update an external progress bar.
653        """
654        original_size = image.shape[:2]
655        self._original_size = original_size
656
657        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
658            self._predictor, image, image_embeddings, tile_shape, halo
659        )
660
661        tiling = blocking([0, 0], original_size, tile_shape)
662        n_tiles = tiling.numberOfBlocks
663
664        # The crop box is always the full local tile.
665        tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)]
666        crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles]
667
668        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
669        pbar_init(n_tiles, "Compute masks for tile")
670
671        # We need to cast to the image representation that is compatible with SAM.
672        image = util._to_image(image)
673
674        mask_data = []
675        for tile_id in range(n_tiles):
676            # set the pre-computed embeddings for this tile
677            features = image_embeddings["features"][tile_id]
678            tile_embeddings = {
679                "features": features,
680                "input_size": features.attrs["input_size"],
681                "original_size": features.attrs["original_size"],
682            }
683            util.set_precomputed(self._predictor, tile_embeddings, i)
684
685            # compute the mask data for this tile and append it
686            this_mask_data = self._process_crop(
687                image, crop_box=crop_boxes[tile_id], crop_layer_idx=0, precomputed_embeddings=True
688            )
689            mask_data.append(this_mask_data)
690            pbar_update(1)
691        pbar_close()
692
693        # set the initialized data
694        self._is_initialized = True
695        self._crop_list = mask_data
696        self._crop_boxes = crop_boxes

Initialize image embeddings and masks for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • tile_shape: The tile shape for embedding prediction.
  • halo: The overlap between tiles.
  • verbose: Whether to print computation progress.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
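
If the tiled embeddings should be reused, e.g. between interactive annotation and automatic segmentation, they can be computed up front and passed to initialize. This sketch assumes that util.precompute_image_embeddings accepts tile_shape, halo and save_path keyword arguments and that initialize can recover tile_shape and halo from such embeddings; it continues from the example above:

```python
# Compute (and cache) tiled embeddings once; the save_path is an example location.
image_embeddings = util.precompute_image_embeddings(
    predictor, large_image,
    tile_shape=(1024, 1024), halo=(128, 128),
    save_path="embeddings.zarr",
)

amg = TiledAutomaticMaskGenerator(predictor)
# tile_shape and halo are assumed to be taken from the precomputed embeddings here.
amg.initialize(large_image, image_embeddings=image_embeddings)
masks = amg.generate()
```
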
class DecoderAdapter(torch.nn.modules.module.Module):
704class DecoderAdapter(torch.nn.Module):
705    """Adapter to contain the UNETR decoder in a single module.
706
707    To apply the decoder on top of pre-computed embeddings for
708    the segmentation functionality.
709    See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py
710    """
711    def __init__(self, unetr):
712        super().__init__()
713
714        self.base = unetr.base
715        self.out_conv = unetr.out_conv
716        self.deconv_out = unetr.deconv_out
717        self.decoder_head = unetr.decoder_head
718        self.final_activation = unetr.final_activation
719        self.postprocess_masks = unetr.postprocess_masks
720
721        self.decoder = unetr.decoder
722        self.deconv1 = unetr.deconv1
723        self.deconv2 = unetr.deconv2
724        self.deconv3 = unetr.deconv3
725        self.deconv4 = unetr.deconv4
726
727    def forward(self, input_, input_shape, original_shape):
728        z12 = input_
729
730        z9 = self.deconv1(z12)
731        z6 = self.deconv2(z9)
732        z3 = self.deconv3(z6)
733        z0 = self.deconv4(z3)
734
735        updated_from_encoder = [z9, z6, z3]
736
737        x = self.base(z12)
738        x = self.decoder(x, encoder_inputs=updated_from_encoder)
739        x = self.deconv_out(x)
740
741        x = torch.cat([x, z0], dim=1)
742        x = self.decoder_head(x)
743
744        x = self.out_conv(x)
745        if self.final_activation is not None:
746            x = self.final_activation(x)
747
748        x = self.postprocess_masks(x, input_shape, original_shape)
749        return x

Adapter to contain the UNETR decoder in a single module.

To apply the decoder on top of pre-computed embeddings for the segmentation functionality. See also: https://github.com/constantinpape/torch-em/blob/main/torch_em/model/unetr.py
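
As a minimal sketch, this is how the adapter is applied to SAM image embeddings, mirroring what InstanceSegmentationWithDecoder.initialize does internally; predictor is assumed to be a SamPredictor with embeddings set (e.g. via util.set_precomputed) and decoder a DecoderAdapter (e.g. from get_decoder below):

```python
import torch

with torch.no_grad():
    embeddings = predictor.features                   # SAM image embeddings
    input_shape = tuple(predictor.input_size)
    original_shape = tuple(predictor.original_size)
    # The output has three channels: foreground, center distances, boundary distances.
    output = decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)

foreground, center_distances, boundary_distances = output
```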

DecoderAdapter(unetr)
711    def __init__(self, unetr):
712        super().__init__()
713
714        self.base = unetr.base
715        self.out_conv = unetr.out_conv
716        self.deconv_out = unetr.deconv_out
717        self.decoder_head = unetr.decoder_head
718        self.final_activation = unetr.final_activation
719        self.postprocess_masks = unetr.postprocess_masks
720
721        self.decoder = unetr.decoder
722        self.deconv1 = unetr.deconv1
723        self.deconv2 = unetr.deconv2
724        self.deconv3 = unetr.deconv3
725        self.deconv4 = unetr.deconv4

Initialize internal Module state, shared by both nn.Module and ScriptModule.

base
out_conv
deconv_out
decoder_head
final_activation
postprocess_masks
decoder
deconv1
deconv2
deconv3
deconv4
def forward(self, input_, input_shape, original_shape):
727    def forward(self, input_, input_shape, original_shape):
728        z12 = input_
729
730        z9 = self.deconv1(z12)
731        z6 = self.deconv2(z9)
732        z3 = self.deconv3(z6)
733        z0 = self.deconv4(z3)
734
735        updated_from_encoder = [z9, z6, z3]
736
737        x = self.base(z12)
738        x = self.decoder(x, encoder_inputs=updated_from_encoder)
739        x = self.deconv_out(x)
740
741        x = torch.cat([x, z0], dim=1)
742        x = self.decoder_head(x)
743
744        x = self.out_conv(x)
745        if self.final_activation is not None:
746            x = self.final_activation(x)
747
748        x = self.postprocess_masks(x, input_shape, original_shape)
749        return x

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Inherited Members
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
def get_unetr( image_encoder: torch.nn.modules.module.Module, decoder_state: Optional[collections.OrderedDict[str, torch.Tensor]] = None, device: Union[str, torch.device, NoneType] = None) -> torch.nn.modules.module.Module:
752def get_unetr(
753    image_encoder: torch.nn.Module,
754    decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None,
755    device: Optional[Union[str, torch.device]] = None,
756) -> torch.nn.Module:
757    """Get UNETR model for automatic instance segmentation.
758
759    Args:
760        image_encoder: The image encoder of the SAM model.
761            This is used as encoder by the UNETR too.
762        decoder_state: Optional decoder state to initialize the weights
763            of the UNETR decoder.
764        device: The device.
765    Returns:
766        The UNETR model.
767    """
768    device = util.get_device(device)
769
770    unetr = UNETR(
771        backbone="sam",
772        encoder=image_encoder,
773        out_channels=3,
774        use_sam_stats=True,
775        final_activation="Sigmoid",
776        use_skip_connection=False,
777        resize_input=True,
778    )
779    if decoder_state is not None:
780        unetr_state_dict = unetr.state_dict()
781        for k, v in unetr_state_dict.items():
782            if not k.startswith("encoder"):
783                unetr_state_dict[k] = decoder_state[k]
784        unetr.load_state_dict(unetr_state_dict)
785
786    unetr.to(device)
787    return unetr

Get UNETR model for automatic instance segmentation.

Arguments:
  • image_encoder: The image encoder of the SAM model. This is used as encoder by the UNETR too.
  • decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
  • device: The device.
Returns:

The UNETR model.
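
A sketch of building the UNETR from a SAM image encoder and a saved decoder state; the checkpoint path is a placeholder and the "decoder_state" key follows the layout expected by get_predictor_and_decoder below:

```python
import torch

from micro_sam import util
from micro_sam.instance_segmentation import get_unetr

predictor = util.get_sam_model(model_type="vit_b")  # example model type

# Example checkpoint layout with the decoder weights stored under "decoder_state".
checkpoint = torch.load("checkpoint.pt", map_location="cpu")
decoder_state = checkpoint["decoder_state"]

unetr = get_unetr(predictor.model.image_encoder, decoder_state=decoder_state, device="cpu")
```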

def get_decoder( image_encoder: torch.nn.modules.module.Module, decoder_state: collections.OrderedDict[str, torch.Tensor], device: Union[str, torch.device, NoneType] = None) -> DecoderAdapter:
790def get_decoder(
791    image_encoder: torch.nn.Module,
792    decoder_state: OrderedDict[str, torch.Tensor],
793    device: Optional[Union[str, torch.device]] = None,
794) -> DecoderAdapter:
795    """Get decoder to predict outputs for automatic instance segmentation
796
797    Args:
798        image_encoder: The image encoder of the SAM model.
799        decoder_state: State to initialize the weights of the UNETR decoder.
800        device: The device.
801    Returns:
802        The decoder for instance segmentation.
803    """
804    unetr = get_unetr(image_encoder, decoder_state, device)
805    return DecoderAdapter(unetr)

Get the decoder to predict outputs for automatic instance segmentation.

Arguments:
  • image_encoder: The image encoder of the SAM model.
  • decoder_state: State to initialize the weights of the UNETR decoder.
  • device: The device.
Returns:

The decoder for instance segmentation.

def get_predictor_and_decoder( model_type: str, checkpoint_path: Union[str, os.PathLike], device: Union[str, torch.device, NoneType] = None, peft_kwargs: Optional[Dict] = None) -> Tuple[segment_anything.predictor.SamPredictor, DecoderAdapter]:
808def get_predictor_and_decoder(
809    model_type: str,
810    checkpoint_path: Union[str, os.PathLike],
811    device: Optional[Union[str, torch.device]] = None,
812    peft_kwargs: Optional[Dict] = None,
813) -> Tuple[SamPredictor, DecoderAdapter]:
814    """Load the SAM model (predictor) and instance segmentation decoder.
815
816    This requires a checkpoint that contains the state for both predictor
817    and decoder.
818
819    Args:
820        model_type: The type of the image encoder used in the SAM model.
821        checkpoint_path: Path to the checkpoint from which to load the data.
822        device: The device.
823        peft_kwargs: Keyword arguments for parameter efficient fine-tuning, e.g. the rank for low rank adaptation of the attention layers.
824
825    Returns:
826        The SAM predictor.
827        The decoder for instance segmentation.
828    """
829    device = util.get_device(device)
830    predictor, state = util.get_sam_model(
831        model_type=model_type,
832        checkpoint_path=checkpoint_path,
833        device=device,
834        return_state=True,
835        peft_kwargs=peft_kwargs,
836    )
837    if "decoder_state" not in state:
838        raise ValueError(
839            f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state"
840        )
841    decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device)
842    return predictor, decoder

Load the SAM model (predictor) and instance segmentation decoder.

This requires a checkpoint that contains the state for both predictor and decoder.

Arguments:
  • model_type: The type of the image encoder used in the SAM model.
  • checkpoint_path: Path to the checkpoint from which to load the data.
  • device: The device.
  • peft_kwargs: Keyword arguments for parameter efficient fine-tuning, e.g. the rank for low rank adaptation of the attention layers.
Returns:

The SAM predictor. The decoder for instance segmentation.
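
A sketch of the typical entry point for decoder-based instance segmentation; the model type and checkpoint path are placeholders:

```python
from micro_sam.instance_segmentation import (
    InstanceSegmentationWithDecoder, get_predictor_and_decoder
)

# The checkpoint must contain both the SAM weights and a "decoder_state".
predictor, decoder = get_predictor_and_decoder(
    model_type="vit_b_lm",                     # example model type
    checkpoint_path="path/to/checkpoint.pt",   # placeholder path
)
segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
```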

class InstanceSegmentationWithDecoder:
 845class InstanceSegmentationWithDecoder:
 846    """Generates an instance segmentation without prompts, using a decoder.
 847
 848    Implements the same interface as `AutomaticMaskGenerator`.
 849
 850    Use this class as follows:
 851    ```python
 852    segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
 853    segmenter.initialize(image)   # Predict the image embeddings and decoder outputs.
 854    masks = segmenter.generate(center_distance_threshold=0.75)  # Generate the instance segmentation.
 855    ```
 856
 857    Args:
 858        predictor: The segment anything predictor.
 859        decoder: The decoder to predict intermediate representations
 860            for instance segmentation.
 861    """
 862    def __init__(
 863        self,
 864        predictor: SamPredictor,
 865        decoder: torch.nn.Module,
 866    ) -> None:
 867        self._predictor = predictor
 868        self._decoder = decoder
 869
 870        # The decoder outputs.
 871        self._foreground = None
 872        self._center_distances = None
 873        self._boundary_distances = None
 874
 875        self._is_initialized = False
 876
 877    @property
 878    def is_initialized(self):
 879        """Whether the mask generator has already been initialized.
 880        """
 881        return self._is_initialized
 882
 883    @torch.no_grad()
 884    def initialize(
 885        self,
 886        image: np.ndarray,
 887        image_embeddings: Optional[util.ImageEmbeddings] = None,
 888        i: Optional[int] = None,
 889        verbose: bool = False,
 890        pbar_init: Optional[callable] = None,
 891        pbar_update: Optional[callable] = None,
 892    ) -> None:
 893        """Initialize image embeddings and decoder predictions for an image.
 894
 895        Args:
 896            image: The input image, volume or timeseries.
 897            image_embeddings: Optional precomputed image embeddings.
 898                See `util.precompute_image_embeddings` for details.
 899            i: Index for the image data. Required if `image` has three spatial dimensions
 900                or a time dimension and two spatial dimensions.
 901            verbose: Whether to be verbose.
 902            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 903                Can be used together with pbar_update to handle the napari progress bar in another thread.
 904                This enables using this function within a threadworker.
 905            pbar_update: Callback to update an external progress bar.
 906        """
 907        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
 908        pbar_init(1, "Initialize instance segmentation with decoder")
 909
 910        if image_embeddings is None:
 911            image_embeddings = util.precompute_image_embeddings(self._predictor, image)
 912
 913        # Get the image embeddings from the predictor.
 914        self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i)
 915        embeddings = self._predictor.features
 916        input_shape = tuple(self._predictor.input_size)
 917        original_shape = tuple(self._predictor.original_size)
 918
 919        # Run prediction with the UNETR decoder.
 920        output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
 921        assert output.shape[0] == 3, f"{output.shape}"
 922        pbar_update(1)
 923        pbar_close()
 924
 925        # Set the state.
 926        self._foreground = output[0]
 927        self._center_distances = output[1]
 928        self._boundary_distances = output[2]
 929        self._is_initialized = True
 930
 931    def _to_masks(self, segmentation, output_mode):
 932        if output_mode != "binary_mask":
 933            raise NotImplementedError
 934
 935        props = regionprops(segmentation)
 936        ndim = segmentation.ndim
 937        assert ndim in (2, 3)
 938
 939        shape = segmentation.shape
 940        if ndim == 2:
 941            crop_box = [0, shape[1], 0, shape[0]]
 942        else:
 943            crop_box = [0, shape[2], 0, shape[1], 0, shape[0]]
 944
 945        # go from skimage bbox in format [y0, x0, y1, x1] to SAM format [x0, w, y0, h]
 946        def to_bbox_2d(bbox):
 947            y0, x0 = bbox[0], bbox[1]
 948            w = bbox[3] - x0
 949            h = bbox[2] - y0
 950            return [x0, w, y0, h]
 951
 952        def to_bbox_3d(bbox):
 953            z0, y0, x0 = bbox[0], bbox[1], bbox[2]
 954            w = bbox[5] - x0
 955            h = bbox[4] - y0
 956            d = bbox[3] - y0
 957            return [x0, w, y0, h, z0, d]
 958
 959        to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d
 960        masks = [
 961            {
 962                "segmentation": segmentation == prop.label,
 963                "area": prop.area,
 964                "bbox": to_bbox(prop.bbox),
 965                "crop_box": crop_box,
 966                "seg_id": prop.label,
 967            } for prop in props
 968        ]
 969        return masks
 970
 971    def generate(
 972        self,
 973        center_distance_threshold: float = 0.5,
 974        boundary_distance_threshold: float = 0.5,
 975        foreground_threshold: float = 0.5,
 976        foreground_smoothing: float = 1.0,
 977        distance_smoothing: float = 1.6,
 978        min_size: int = 0,
 979        output_mode: Optional[str] = "binary_mask",
 980    ) -> List[Dict[str, Any]]:
 981        """Generate instance segmentation for the currently initialized image.
 982
 983        Args:
 984            center_distance_threshold: Center distance predictions below this value will be
 985                used to find seeds (intersected with thresholded boundary distance predictions).
 986            boundary_distance_threshold: Boundary distance predictions below this value will be
 987                used to find seeds (intersected with thresholded center distance predictions).
 988            foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid
 989                checkerboard artifacts in the prediction.
 990            foreground_threshold: Foreground predictions above this value will be used as foreground mask.
 991            distance_smoothing: Sigma value for smoothing the distance predictions.
 992            min_size: Minimal object size in the segmentation result.
 993            output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
 994
 995        Returns:
 996            The instance segmentation masks.
 997        """
 998        if not self.is_initialized:
 999            raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.")
1000
1001        if foreground_smoothing > 0:
1002            foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing)
1003        else:
1004            foreground = self._foreground
1005        # Further optimization: parallel implementation using elf.parallel functionality.
1006        # (Make sure to expose n_threads to avoid over-subscription in case of outer parallelization)
1007        segmentation = watershed_from_center_and_boundary_distances(
1008            self._center_distances, self._boundary_distances, foreground,
1009            center_distance_threshold=center_distance_threshold,
1010            boundary_distance_threshold=boundary_distance_threshold,
1011            foreground_threshold=foreground_threshold,
1012            distance_smoothing=distance_smoothing,
1013            min_size=min_size,
1014        )
1015        if output_mode is not None:
1016            segmentation = self._to_masks(segmentation, output_mode)
1017        return segmentation
1018
1019    def get_state(self) -> Dict[str, Any]:
1020        """Get the initialized state of the instance segmenter.
1021
1022        Returns:
1023            Instance segmentation state.
1024        """
1025        if not self.is_initialized:
1026            raise RuntimeError("The state has not been computed yet. Call initialize first.")
1027
1028        return {
1029            "foreground": self._foreground,
1030            "center_distances": self._center_distances,
1031            "boundary_distances": self._boundary_distances,
1032        }
1033
1034    def set_state(self, state: Dict[str, Any]) -> None:
1035        """Set the state of the instance segmenter.
1036
1037        Args:
1038            state: The instance segmentation state.
1039        """
1040        self._foreground = state["foreground"]
1041        self._center_distances = state["center_distances"]
1042        self._boundary_distances = state["boundary_distances"]
1043        self._is_initialized = True
1044
1045    def clear_state(self):
1046        """Clear the state of the instance segmenter.
1047        """
1048        self._foreground = None
1049        self._center_distances = None
1050        self._boundary_distances = None
1051        self._is_initialized = False

Generates an instance segmentation without prompts, using a decoder.

Implements the same interface as AutomaticMaskGenerator.

Use this class as follows:

segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image)   # Predict the image embeddings and decoder outputs.
masks = segmenter.generate(center_distance_threshold=0.75)  # Generate the instance segmentation.
Arguments:
  • predictor: The segment anything predictor.
  • decoder: The decoder to predict intermediate representations for instance segmentation.
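A slightly fuller sketch of the same workflow, converting the generated masks into a label image with mask_data_to_segmentation from this module. It assumes that predictor, decoder and a 2d image have already been loaded elsewhere; the parameter values are illustrative only:

segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image)  # Predict the image embeddings and decoder outputs.
masks = segmenter.generate(center_distance_threshold=0.75, min_size=50)
# Convert the mask dictionaries into a single instance segmentation (label image).
segmentation = mask_data_to_segmentation(masks, with_background=True, min_object_size=50)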
InstanceSegmentationWithDecoder( predictor: segment_anything.predictor.SamPredictor, decoder: torch.nn.modules.module.Module)
862    def __init__(
863        self,
864        predictor: SamPredictor,
865        decoder: torch.nn.Module,
866    ) -> None:
867        self._predictor = predictor
868        self._decoder = decoder
869
870        # The decoder outputs.
871        self._foreground = None
872        self._center_distances = None
873        self._boundary_distances = None
874
875        self._is_initialized = False
is_initialized
877    @property
878    def is_initialized(self):
879        """Whether the mask generator has already been initialized.
880        """
881        return self._is_initialized

Whether the mask generator has already been initialized.

@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, verbose: bool = False, pbar_init: Optional[callable] = None, pbar_update: Optional[callable] = None) -> None:
883    @torch.no_grad()
884    def initialize(
885        self,
886        image: np.ndarray,
887        image_embeddings: Optional[util.ImageEmbeddings] = None,
888        i: Optional[int] = None,
889        verbose: bool = False,
890        pbar_init: Optional[callable] = None,
891        pbar_update: Optional[callable] = None,
892    ) -> None:
893        """Initialize image embeddings and decoder predictions for an image.
894
895        Args:
896            image: The input image, volume or timeseries.
897            image_embeddings: Optional precomputed image embeddings.
898                See `util.precompute_image_embeddings` for details.
899            i: Index for the image data. Required if `image` has three spatial dimensions
900                or a time dimension and two spatial dimensions.
901            verbose: Whether to be verbose.
902            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 903                Can be used together with pbar_update to handle the napari progress bar in another thread.
 904                This enables using this function within a threadworker.
905            pbar_update: Callback to update an external progress bar.
906        """
907        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
908        pbar_init(1, "Initialize instance segmentation with decoder")
909
910        if image_embeddings is None:
911            image_embeddings = util.precompute_image_embeddings(self._predictor, image)
912
913        # Get the image embeddings from the predictor.
914        self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i)
915        embeddings = self._predictor.features
916        input_shape = tuple(self._predictor.input_size)
917        original_shape = tuple(self._predictor.original_size)
918
919        # Run prediction with the UNETR decoder.
920        output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
921        assert output.shape[0] == 3, f"{output.shape}"
922        pbar_update(1)
923        pbar_close()
924
925        # Set the state.
926        self._foreground = output[0]
927        self._center_distances = output[1]
928        self._boundary_distances = output[2]
929        self._is_initialized = True

Initialize image embeddings and decoder predictions for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • verbose: Whether to be verbose.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
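A hedged sketch of initializing from precomputed embeddings, e.g. for a timeseries where the embeddings are computed once and a frame is selected via i. It assumes that predictor, decoder and a timeseries data with shape (T, Y, X) are already available:

from micro_sam import util

segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
embeddings = util.precompute_image_embeddings(predictor, data)  # Computed once for all frames.

segmentations = []
for t in range(data.shape[0]):
    # Passing the frame index via i selects the embeddings for this frame,
    # so only the decoder prediction is run per frame.
    segmenter.initialize(data, image_embeddings=embeddings, i=t)
    segmentations.append(segmenter.generate(output_mode=None))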
def generate( self, center_distance_threshold: float = 0.5, boundary_distance_threshold: float = 0.5, foreground_threshold: float = 0.5, foreground_smoothing: float = 1.0, distance_smoothing: float = 1.6, min_size: int = 0, output_mode: Optional[str] = 'binary_mask') -> List[Dict[str, Any]]:
 971    def generate(
 972        self,
 973        center_distance_threshold: float = 0.5,
 974        boundary_distance_threshold: float = 0.5,
 975        foreground_threshold: float = 0.5,
 976        foreground_smoothing: float = 1.0,
 977        distance_smoothing: float = 1.6,
 978        min_size: int = 0,
 979        output_mode: Optional[str] = "binary_mask",
 980    ) -> List[Dict[str, Any]]:
 981        """Generate instance segmentation for the currently initialized image.
 982
 983        Args:
 984            center_distance_threshold: Center distance predictions below this value will be
 985                used to find seeds (intersected with thresholded boundary distance predictions).
 986            boundary_distance_threshold: Boundary distance predictions below this value will be
 987                used to find seeds (intersected with thresholded center distance predictions).
 988            foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid
 989                checkerboard artifacts in the prediction.
 990            foreground_threshold: Foreground predictions above this value will be used as foreground mask.
 991            distance_smoothing: Sigma value for smoothing the distance predictions.
 992            min_size: Minimal object size in the segmentation result.
 993            output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
 994
 995        Returns:
 996            The instance segmentation masks.
 997        """
 998        if not self.is_initialized:
 999            raise RuntimeError("InstanceSegmentationWithDecoder has not been initialized. Call initialize first.")
1000
1001        if foreground_smoothing > 0:
1002            foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing)
1003        else:
1004            foreground = self._foreground
1005        # Further optimization: parallel implementation using elf.parallel functionality.
1006        # (Make sure to expose n_threads to avoid over-subscription in case of outer parallelization)
1007        segmentation = watershed_from_center_and_boundary_distances(
1008            self._center_distances, self._boundary_distances, foreground,
1009            center_distance_threshold=center_distance_threshold,
1010            boundary_distance_threshold=boundary_distance_threshold,
1011            foreground_threshold=foreground_threshold,
1012            distance_smoothing=distance_smoothing,
1013            min_size=min_size,
1014        )
1015        if output_mode is not None:
1016            segmentation = self._to_masks(segmentation, output_mode)
1017        return segmentation

Generate instance segmentation for the currently initialized image.

Arguments:
  • center_distance_threshold: Center distance predictions below this value will be used to find seeds (intersected with thresholded boundary distance predictions).
  • boundary_distance_threshold: Boundary distance predictions below this value will be used to find seeds (intersected with thresholded center distance predictions).
  • foreground_smoothing: Sigma value for smoothing the foreground predictions, to avoid checkerboard artifacts in the prediction.
  • foreground_threshold: Foreground predictions above this value will be used as foreground mask.
  • distance_smoothing: Sigma value for smoothing the distance predictions.
  • min_size: Minimal object size in the segmentation result.
  • output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
Returns:

The instance segmentation masks.
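A short sketch of the two output modes; it assumes that segmenter has been initialized as above and the parameter values are illustrative:

masks = segmenter.generate(min_size=100)  # List of mask dictionaries, as for AutomaticMaskGenerator.
label_image = segmenter.generate(min_size=100, output_mode=None)  # The instance segmentation as a label image.

# Each mask dictionary contains the binary mask plus some metadata, e.g.:
# masks[0]["segmentation"]  # Boolean mask of the object.
# masks[0]["area"]          # Object size in pixels.
# masks[0]["bbox"]          # Bounding box of the object.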

def get_state(self) -> Dict[str, Any]:
1019    def get_state(self) -> Dict[str, Any]:
1020        """Get the initialized state of the instance segmenter.
1021
1022        Returns:
1023            Instance segmentation state.
1024        """
1025        if not self.is_initialized:
1026            raise RuntimeError("The state has not been computed yet. Call initialize first.")
1027
1028        return {
1029            "foreground": self._foreground,
1030            "center_distances": self._center_distances,
1031            "boundary_distances": self._boundary_distances,
1032        }

Get the initialized state of the instance segmenter.

Returns:

Instance segmentation state.

def set_state(self, state: Dict[str, Any]) -> None:
1034    def set_state(self, state: Dict[str, Any]) -> None:
1035        """Set the state of the instance segmenter.
1036
1037        Args:
1038            state: The instance segmentation state.
1039        """
1040        self._foreground = state["foreground"]
1041        self._center_distances = state["center_distances"]
1042        self._boundary_distances = state["boundary_distances"]
1043        self._is_initialized = True

Set the state of the instance segmenter.

Arguments:
  • state: The instance segmentation state.
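A hedged sketch of how get_state and set_state can be used to cache the decoder predictions, so that generate can be re-run with different parameters without recomputing the initialization. It assumes predictor, decoder and an already initialized segmenter exist:

state = segmenter.get_state()  # E.g. store this together with the image embeddings.

# Later, restore the state in a fresh segmenter for the same image:
new_segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
new_segmenter.set_state(state)
masks = new_segmenter.generate(center_distance_threshold=0.4)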
def clear_state(self):
1045    def clear_state(self):
1046        """Clear the state of the instance segmenter.
1047        """
1048        self._foreground = None
1049        self._center_distances = None
1050        self._boundary_distances = None
1051        self._is_initialized = False

Clear the state of the instance segmenter.

class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder):
1054class TiledInstanceSegmentationWithDecoder(InstanceSegmentationWithDecoder):
1055    """Same as `InstanceSegmentationWithDecoder` but for tiled image embeddings.
1056    """
1057
1058    @torch.no_grad()
1059    def initialize(
1060        self,
1061        image: np.ndarray,
1062        image_embeddings: Optional[util.ImageEmbeddings] = None,
1063        i: Optional[int] = None,
1064        tile_shape: Optional[Tuple[int, int]] = None,
1065        halo: Optional[Tuple[int, int]] = None,
1066        verbose: bool = False,
1067        pbar_init: Optional[callable] = None,
1068        pbar_update: Optional[callable] = None,
1069    ) -> None:
1070        """Initialize image embeddings and decoder predictions for an image.
1071
1072        Args:
1073            image: The input image, volume or timeseries.
1074            image_embeddings: Optional precomputed image embeddings.
1075                See `util.precompute_image_embeddings` for details.
1076            i: Index for the image data. Required if `image` has three spatial dimensions
1077                or a time dimension and two spatial dimensions.
1078            verbose: Dummy input to be compatible with other function signatures.
1079            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
1080                Can be used together with pbar_update to handle the napari progress bar in another thread.
1081                This enables using this function within a threadworker.
1082            pbar_update: Callback to update an external progress bar.
1083        """
1084        original_size = image.shape[:2]
1085        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
1086            self._predictor, image, image_embeddings, tile_shape, halo
1087        )
1088        tiling = blocking([0, 0], original_size, tile_shape)
1089
1090        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
1091        pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder")
1092
1093        foreground = np.zeros(original_size, dtype="float32")
1094        center_distances = np.zeros(original_size, dtype="float32")
1095        boundary_distances = np.zeros(original_size, dtype="float32")
1096
1097        for tile_id in range(tiling.numberOfBlocks):
1098
1099            # Get the image embeddings from the predictor for this tile.
1100            self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id)
1101            embeddings = self._predictor.features
1102            input_shape = tuple(self._predictor.input_size)
1103            original_shape = tuple(self._predictor.original_size)
1104
1105            # Predict with the UNETR decoder for this tile.
1106            output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
1107            assert output.shape[0] == 3, f"{output.shape}"
1108
1109            # Set the predictions in the output for this tile.
1110            block = tiling.getBlockWithHalo(tile_id, halo=list(halo))
1111            local_bb = tuple(
1112                slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end)
1113            )
1114            inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end))
1115
1116            foreground[inner_bb] = output[0][local_bb]
1117            center_distances[inner_bb] = output[1][local_bb]
1118            boundary_distances[inner_bb] = output[2][local_bb]
1119        pbar_close()
1120
1121        # Set the state.
1122        self._foreground = foreground
1123        self._center_distances = center_distances
1124        self._boundary_distances = boundary_distances
1125        self._is_initialized = True

Same as InstanceSegmentationWithDecoder but for tiled image embeddings.

@torch.no_grad()
def initialize( self, image: numpy.ndarray, image_embeddings: Optional[Dict[str, Any]] = None, i: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = False, pbar_init: Optional[callable] = None, pbar_update: Optional[callable] = None) -> None:
1058    @torch.no_grad()
1059    def initialize(
1060        self,
1061        image: np.ndarray,
1062        image_embeddings: Optional[util.ImageEmbeddings] = None,
1063        i: Optional[int] = None,
1064        tile_shape: Optional[Tuple[int, int]] = None,
1065        halo: Optional[Tuple[int, int]] = None,
1066        verbose: bool = False,
1067        pbar_init: Optional[callable] = None,
1068        pbar_update: Optional[callable] = None,
1069    ) -> None:
1070        """Initialize image embeddings and decoder predictions for an image.
1071
1072        Args:
1073            image: The input image, volume or timeseries.
1074            image_embeddings: Optional precomputed image embeddings.
1075                See `util.precompute_image_embeddings` for details.
1076            i: Index for the image data. Required if `image` has three spatial dimensions
1077                or a time dimension and two spatial dimensions.
1078            verbose: Dummy input to be compatible with other function signatures.
1079            pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
1080                Can be used together with pbar_update to handle the napari progress bar in another thread.
1081                This enables using this function within a threadworker.
1082            pbar_update: Callback to update an external progress bar.
1083        """
1084        original_size = image.shape[:2]
1085        image_embeddings, tile_shape, halo = _process_tiled_embeddings(
1086            self._predictor, image, image_embeddings, tile_shape, halo
1087        )
1088        tiling = blocking([0, 0], original_size, tile_shape)
1089
1090        _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
1091        pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder")
1092
1093        foreground = np.zeros(original_size, dtype="float32")
1094        center_distances = np.zeros(original_size, dtype="float32")
1095        boundary_distances = np.zeros(original_size, dtype="float32")
1096
1097        for tile_id in range(tiling.numberOfBlocks):
1098
1099            # Get the image embeddings from the predictor for this tile.
1100            self._predictor = util.set_precomputed(self._predictor, image_embeddings, i=i, tile_id=tile_id)
1101            embeddings = self._predictor.features
1102            input_shape = tuple(self._predictor.input_size)
1103            original_shape = tuple(self._predictor.original_size)
1104
1105            # Predict with the UNETR decoder for this tile.
1106            output = self._decoder(embeddings, input_shape, original_shape).cpu().numpy().squeeze(0)
1107            assert output.shape[0] == 3, f"{output.shape}"
1108
1109            # Set the predictions in the output for this tile.
1110            block = tiling.getBlockWithHalo(tile_id, halo=list(halo))
1111            local_bb = tuple(
1112                slice(beg, end) for beg, end in zip(block.innerBlockLocal.begin, block.innerBlockLocal.end)
1113            )
1114            inner_bb = tuple(slice(beg, end) for beg, end in zip(block.innerBlock.begin, block.innerBlock.end))
1115
1116            foreground[inner_bb] = output[0][local_bb]
1117            center_distances[inner_bb] = output[1][local_bb]
1118            boundary_distances[inner_bb] = output[2][local_bb]
1119        pbar_close()
1120
1121        # Set the state.
1122        self._foreground = foreground
1123        self._center_distances = center_distances
1124        self._boundary_distances = boundary_distances
1125        self._is_initialized = True

Initialize image embeddings and decoder predictions for an image.

Arguments:
  • image: The input image, volume or timeseries.
  • image_embeddings: Optional precomputed image embeddings. See util.precompute_image_embeddings for details.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • verbose: Dummy input to be compatible with other function signatures.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
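A minimal sketch for segmenting a large image with tiled embeddings. It assumes that predictor, decoder and a large 2d image are available; the tile shape and halo are illustrative values, not recommendations:

segmenter = TiledInstanceSegmentationWithDecoder(predictor, decoder)
segmenter.initialize(image, tile_shape=(1024, 1024), halo=(256, 256))
masks = segmenter.generate()  # Same interface and output format as the non-tiled class.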
def get_amg( predictor: segment_anything.predictor.SamPredictor, is_tiled: bool, decoder: Optional[torch.nn.modules.module.Module] = None, **kwargs) -> Union[AMGBase, InstanceSegmentationWithDecoder]:
1128def get_amg(
1129    predictor: SamPredictor,
1130    is_tiled: bool,
1131    decoder: Optional[torch.nn.Module] = None,
1132    **kwargs,
1133) -> Union[AMGBase, InstanceSegmentationWithDecoder]:
1134    """Get the automatic mask generator class.
1135
1136    Args:
1137        predictor: The segment anything predictor.
1138        is_tiled: Whether tiled embeddings are used.
1139        decoder: Decoder to predict the instance segmentation.
1140        kwargs: The keyword arguments for the amg class.
1141
1142    Returns:
1143        The automatic mask generator.
1144    """
1145    if decoder is None:
1146        segmenter = TiledAutomaticMaskGenerator(predictor, **kwargs) if is_tiled else\
1147            AutomaticMaskGenerator(predictor, **kwargs)
1148    else:
1149        segmenter = TiledInstanceSegmentationWithDecoder(predictor, decoder, **kwargs) if is_tiled else\
1150            InstanceSegmentationWithDecoder(predictor, decoder, **kwargs)
1151    return segmenter

Get the automatic mask generator class.

Arguments:
  • predictor: The segment anything predictor.
  • is_tiled: Whether tiled embeddings are used.
  • decoder: Decoder to predict the instance segmentation.
  • kwargs: The keyword arguments for the amg class.
Returns:

The automatic mask generator.
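A short sketch of how get_amg dispatches to the different classes, following the function body above; predictor and decoder are assumed to be loaded already:

amg = get_amg(predictor, is_tiled=False)                         # AutomaticMaskGenerator
tiled_amg = get_amg(predictor, is_tiled=True)                    # TiledAutomaticMaskGenerator
ais = get_amg(predictor, is_tiled=False, decoder=decoder)        # InstanceSegmentationWithDecoder
tiled_ais = get_amg(predictor, is_tiled=True, decoder=decoder)   # TiledInstanceSegmentationWithDecoder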