micro_sam.util

Helper functions for downloading Segment Anything models and predicting image embeddings.

   1"""
   2Helper functions for downloading Segment Anything models and predicting image embeddings.
   3"""
   4
   5import os
   6import multiprocessing as mp
   7import pickle
   8import hashlib
   9import warnings
  10from concurrent import futures
  11from pathlib import Path
  12from collections import OrderedDict
  13from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Callable
  14
  15import elf.parallel as parallel_impl
  16import imageio.v3 as imageio
  17import numpy as np
  18import pooch
  19import segment_anything.utils.amg as amg_utils
  20import torch
  21import xxhash
  22import zarr
  23
  24from elf.io import open_file
  25from bioimage_cpp.utils import Blocking
  26from bioimage_cpp.distance import distance_transform
  27from bioimage_cpp.segmentation import relabel_sequential
  28from skimage.measure import regionprops
  29from skimage.segmentation import find_boundaries
  30from torchvision.ops.boxes import batched_nms
  31
  32from .__version__ import __version__
  33from . import models as custom_models
  34
  35try:
  36    # Avoid import warnings from mobile_sam
  37    with warnings.catch_warnings():
  38        warnings.simplefilter("ignore")
  39        from mobile_sam import sam_model_registry, SamPredictor
  40    VIT_T_SUPPORT = True
  41except ImportError:
  42    from segment_anything import sam_model_registry, SamPredictor
  43    VIT_T_SUPPORT = False
  44
  45try:
  46    from napari.utils import progress as tqdm
  47except ImportError:
  48    from tqdm import tqdm
  49
# This is the default model used in micro_sam, i.e. the default value
# of the 'model_type' argument of 'get_sam_model'.
# Currently it is set to vit_b_lm
_DEFAULT_MODEL = "vit_b_lm"

# The valid model types. Each type corresponds to the architecture of the
# vision transformer used within SAM.
_MODEL_TYPES = ("vit_l", "vit_b", "vit_h", "vit_t")


# Type alias for image embeddings: a dictionary with string keys.
ImageEmbeddings = Dict[str, Any]
"""@private"""
  61
  62
  63def get_cache_directory() -> None:
  64    """Get micro-sam cache directory location.
  65
  66    Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.
  67    """
  68    default_cache_directory = os.path.expanduser(pooch.os_cache("micro_sam"))
  69    cache_directory = Path(os.environ.get("MICROSAM_CACHEDIR", default_cache_directory))
  70    return cache_directory
  71
  72
  73#
  74# Functionality for model download and export
  75#
  76
  77
  78def microsam_cachedir() -> None:
  79    """Return the micro-sam cache directory.
  80
  81    Returns the top level cache directory for micro-sam models and sample data.
  82
  83    Every time this function is called, we check for any user updates made to
  84    the MICROSAM_CACHEDIR os environment variable since the last time.
  85    """
  86    cache_directory = os.environ.get("MICROSAM_CACHEDIR") or pooch.os_cache("micro_sam")
  87    return cache_directory
  88
  89
def models():
    """Return the segmentation models registry.

    We recreate the model registry every time this function is called,
    so any user changes to the default micro-sam cache directory location
    are respected.

    Returns:
        The object created by `pooch.create`. Its storage location is the 'models'
        folder inside the micro-sam cache directory, and it knows the hash and
        download URL for each encoder ('<name>') and decoder ('<name>_decoder') checkpoint.
    """

    # We use xxhash to compute the hash of the models, see
    # https://github.com/computational-cell-analytics/micro-sam/issues/283
    # (It is now a dependency, so we don't provide the sha256 fallback anymore.)
    # To generate the xxh128 hash:
    #     xxh128sum filename
    encoder_registry = {
        # The default segment anything models:
        "vit_l": "xxh128:a82beb3c660661e3dd38d999cc860e9a",
        "vit_h": "xxh128:97698fac30bd929c2e6d8d8cc15933c2",
        "vit_b": "xxh128:6923c33df3637b6a922d7682bfc9a86b",
        # The model with vit tiny backend from https://github.com/ChaoningZhang/MobileSAM.
        "vit_t": "xxh128:8eadbc88aeb9d8c7e0b4b60c3db48bd0",
        # The current version of our models in the modelzoo.
        # LM generalist models:
        "vit_l_lm": "xxh128:017f20677997d628426dec80a8018f9d",
        "vit_b_lm": "xxh128:fe9252a29f3f4ea53c15a06de471e186",
        "vit_t_lm": "xxh128:72ec5074774761a6e5c05a08942f981e",
        # EM models:
        "vit_l_em_organelles": "xxh128:810b084b6e51acdbf760a993d8619f2d",
        "vit_b_em_organelles": "xxh128:f3bf2ed83d691456bae2c3f9a05fb438",
        "vit_t_em_organelles": "xxh128:253474720c497cce605e57c9b1d18fd9",
        # Histopathology models:
        "vit_b_histopathology": "xxh128:ffd1a2cd84570458b257bd95fdd8f974",
        "vit_l_histopathology": "xxh128:b591833c89754271023e901281dee3f2",
        "vit_h_histopathology": "xxh128:bd1856dafc156a43fb3aa705f1a6e92e",
        # Medical Imaging models:
        "vit_b_medical_imaging": "xxh128:40169f1e3c03a4b67bff58249c176d92",
    }
    # Additional decoders for instance segmentation.
    # The key of each decoder is the key of its encoder with the suffix '_decoder'.
    decoder_registry = {
        # LM generalist models:
        "vit_l_lm_decoder": "xxh128:2faeafa03819dfe03e7c46a44aaac64a",
        "vit_b_lm_decoder": "xxh128:708b15ac620e235f90bb38612c4929ba",
        "vit_t_lm_decoder": "xxh128:3e914a5f397b0312cdd36813031f8823",
        # EM models:
        "vit_l_em_organelles_decoder": "xxh128:334877640bfdaaabce533e3252a17294",
        "vit_b_em_organelles_decoder": "xxh128:bb6398956a6b0132c26b631c14f95ce2",
        "vit_t_em_organelles_decoder": "xxh128:8f897c7bb93174a4d1638827c4dd6f44",
        # Histopathology models:
        "vit_b_histopathology_decoder": "xxh128:6a66194dcb6e36199cbee2214ecf7213",
        "vit_l_histopathology_decoder": "xxh128:46aab7765d4400e039772d5a50b55c04",
        "vit_h_histopathology_decoder": "xxh128:3ed9f87e46ad5e16935bd8d722c8dc47",
        # Medical Imaging models:
        "vit_b_medical_imaging_decoder": "xxh128:9e498b12f526f119b96c88be76e3b2ed",
    }
    # Encoders and decoders share a single registry.
    registry = {**encoder_registry, **decoder_registry}

    # The download location for each encoder checkpoint.
    encoder_urls = {
        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
        "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download",
        "vit_l_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.2/files/vit_l.pt",
        "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.2/files/vit_b.pt",
        "vit_t_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t.pt",
        "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1.2/files/vit_l.pt",  # noqa
        "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1.2/files/vit_b.pt",  # noqa
        "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt",  # noqa
        "vit_b_histopathology": "https://owncloud.gwdg.de/index.php/s/sBB4H8CTmIoBZsQ/download",
        "vit_l_histopathology": "https://owncloud.gwdg.de/index.php/s/IZgnn1cpBq2PHod/download",
        "vit_h_histopathology": "https://owncloud.gwdg.de/index.php/s/L7AcvVz7DoWJ2RZ/download",
        "vit_b_medical_imaging": "https://owncloud.gwdg.de/index.php/s/f5Ol4FrjPQWfjUF/download",
    }

    # The download location for each decoder checkpoint.
    decoder_urls = {
        "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.2/files/vit_l_decoder.pt",  # noqa
        "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.2/files/vit_b_decoder.pt",  # noqa
        "vit_t_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t_decoder.pt",  # noqa
        "vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1.2/files/vit_l_decoder.pt",  # noqa
        "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1.2/files/vit_b_decoder.pt",  # noqa
        "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt",  # noqa
        "vit_b_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/KO9AWqynI7SFOBj/download",
        "vit_l_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/oIs6VSmkOp7XrKF/download",
        "vit_h_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/1qAKxy5H0jgwZvM/download",
        "vit_b_medical_imaging_decoder": "https://owncloud.gwdg.de/index.php/s/ahd3ZhZl2e0RIwz/download",
    }
    urls = {**encoder_urls, **decoder_urls}

    # The base_url is empty because an explicit URL is given for every registry entry.
    models = pooch.create(
        path=os.path.join(microsam_cachedir(), "models"),
        base_url="",
        registry=registry,
        urls=urls,
    )
    return models
 183
 184
 185def _get_default_device():
 186    # check that we're in CI and use the CPU if we are
 187    # otherwise the tests may run out of memory on MAC if MPS is used.
 188    if os.getenv("GITHUB_ACTIONS") == "true":
 189        return "cpu"
 190    # Use cuda enabled gpu if it's available.
 191    if torch.cuda.is_available():
 192        device = "cuda"
 193    # As second priority use mps.
 194    # See https://pytorch.org/docs/stable/notes/mps.html for details
 195    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
 196        print("Using apple MPS device.")
 197        device = "mps"
 198    # Use the CPU as fallback.
 199    else:
 200        device = "cpu"
 201    return device
 202
 203
 204def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
 205    """Get the torch device.
 206
 207    If no device is passed the default device for your system is used.
 208    Else it will be checked if the device you have passed is supported.
 209
 210    Args:
 211        device: The input device. By default, selects the best available device supports.
 212
 213    Returns:
 214        The device.
 215    """
 216    if device is None or device == "auto":
 217        device = _get_default_device()
 218    else:
 219        device_type = device if isinstance(device, str) else device.type
 220        if device_type.lower() == "cuda":
 221            if not torch.cuda.is_available():
 222                raise RuntimeError("PyTorch CUDA backend is not available.")
 223        elif device_type.lower() == "mps":
 224            if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
 225                raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
 226        elif device_type.lower() == "cpu":
 227            pass  # cpu is always available
 228        else:
 229            raise RuntimeError(f"Unsupported device: '{device}'. Please choose from 'cpu', 'cuda', or 'mps'.")
 230
 231    return device
 232
 233
 234def _available_devices():
 235    available_devices = []
 236    for i in ["cuda", "mps", "cpu"]:
 237        try:
 238            device = get_device(i)
 239        except RuntimeError:
 240            pass
 241        else:
 242            available_devices.append(device)
 243    return available_devices
 244
 245
 246# We write a custom unpickler that skips objects that cannot be found instead of
 247# throwing an AttributeError or ModueNotFoundError.
 248# NOTE: since we just want to unpickle the model to load its weights these errors don't matter.
 249# See also https://stackoverflow.com/questions/27732354/unable-to-load-files-using-pickle-and-multiple-modules
 250class _CustomUnpickler(pickle.Unpickler):
 251    def find_class(self, module, name):
 252        try:
 253            return super().find_class(module, name)
 254        except (AttributeError, ModuleNotFoundError) as e:
 255            warnings.warn(f"Did not find {module}:{name} and will skip it, due to error {e}")
 256            return None
 257
 258
 259def _compute_hash(path, chunk_size=8192):
 260    hash_obj = xxhash.xxh128()
 261    with open(path, "rb") as f:
 262        chunk = f.read(chunk_size)
 263        while chunk:
 264            hash_obj.update(chunk)
 265            chunk = f.read(chunk_size)
 266    hash_val = hash_obj.hexdigest()
 267    return f"xxh128:{hash_val}"
 268
 269
 270# Load the state from a checkpoint.
 271# The checkpoint can either contain a sam encoder state
 272# or it can be a checkpoint for model finetuning.
 273def _load_checkpoint(checkpoint_path):
 274    # Over-ride the unpickler with our custom one.
 275    # This enables imports from torch_em checkpoints even if it cannot be fully unpickled.
 276    custom_pickle = pickle
 277    custom_pickle.Unpickler = _CustomUnpickler
 278
 279    state = torch.load(checkpoint_path, map_location="cpu", pickle_module=custom_pickle)
 280    if "model_state" in state:
 281        # Copy the model weights from torch_em's training format.
 282        model_state = state["model_state"]
 283        sam_prefix = "sam."
 284        model_state = OrderedDict(
 285            [(k[len(sam_prefix):] if k.startswith(sam_prefix) else k, v) for k, v in model_state.items()]
 286        )
 287    else:
 288        model_state = state
 289
 290    return state, model_state
 291
 292
 293def _download_sam_model(model_type, progress_bar_factory=None):
 294    model_registry = models()
 295
 296    progress_bar = True
 297    # Check if we have to download the model.
 298    # If we do and have a progress bar factory, then we over-write the progress bar.
 299    if not os.path.exists(os.path.join(get_cache_directory(), model_type)) and progress_bar_factory is not None:
 300        progress_bar = progress_bar_factory(model_type)
 301
 302    checkpoint_path = model_registry.fetch(model_type, progressbar=progress_bar)
 303    if not isinstance(progress_bar, bool):  # Close the progress bar when the task finishes.
 304        progress_bar.close()
 305
 306    model_hash = model_registry.registry[model_type]
 307
 308    # If we have a custom model then we may also have a decoder checkpoint.
 309    # Download it here, so that we can add it to the state.
 310    decoder_name = f"{model_type}_decoder"
 311    decoder_path = model_registry.fetch(
 312        decoder_name, progressbar=True
 313    ) if decoder_name in model_registry.registry else None
 314
 315    return checkpoint_path, model_hash, decoder_path
 316
 317
 318def get_sam_model(
 319    model_type: str = _DEFAULT_MODEL,
 320    device: Optional[Union[str, torch.device]] = None,
 321    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
 322    return_sam: bool = False,
 323    return_state: bool = False,
 324    peft_kwargs: Optional[Dict] = None,
 325    flexible_load_checkpoint: bool = False,
 326    progress_bar_factory: Optional[Callable] = None,
 327    decoder_path: Optional[Union[str, os.PathLike]] = None,
 328    **model_kwargs,
 329) -> SamPredictor:
 330    r"""Get the Segment Anything Predictor.
 331
 332    This function will download the required model or load it from the cached weight file.
 333    This location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR.
 334    The name of the requested model can be set via `model_type`.
 335    See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models
 336    for an overview of the available models
 337
 338    Alternatively this function can also load a model from weights stored in a local filepath.
 339    The corresponding file path is given via `checkpoint_path`. In this case `model_type`
 340    must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for
 341    a SAM model with vit_b encoder.
 342
 343    By default the models are downloaded to a folder named 'micro_sam/models'
 344    inside your default cache directory, eg:
 345    * Mac: ~/Library/Caches/<AppName>
 346    * Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined.
 347    * Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache
 348    See the pooch.os_cache() documentation for more details:
 349    https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html
 350
 351    Args:
 352        model_type: The Segment Anything model to use. Will use the 'vit_b_lm' model by default.
 353            To get a list of all available model names you can call `micro_sam.util.get_model_names`.
 354        device: The device for the model. If 'None' is provided, will use GPU if available.
 355        checkpoint_path: The path to a file with weights that should be used instead of using the
 356            weights corresponding to `model_type`. If given, `model_type` must match the architecture
 357            corresponding to the weight file. e.g. if you use weights for SAM with `vit_b` encoder
 358            then `model_type` must be given as 'vit_b'.
 359        return_sam: Return the sam model object as well as the predictor. By default, set to 'False'.
 360        return_state: Return the unpickled checkpoint state. By default, set to 'False'.
 361        peft_kwargs: Keyword arguments for th PEFT wrapper class.
 362            If passed 'None', it does not initialize any parameter efficient finetuning.
 363        flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
 364            By default, set to 'False'.
 365        progress_bar_factory: A function to create a progress bar for the model download.
 366        decoder_path: Optional path to weights for a segmentation decoder. If given and
 367            `return_state=True`, the decoder state is added to the returned state as
 368            'decoder_state'. This can be used to provide decoder-only weights that are
 369            separate from the encoder checkpoint.
 370        model_kwargs: Additional parameters necessary to initialize the Segment Anything model.
 371
 372    Returns:
 373        The Segment Anything predictor.
 374    """
 375    device = get_device(device)
 376
 377    # We support passing a local filepath to a checkpoint.
 378    # In this case we do not download any weights but just use the local weight file,
 379    # as it is, without copying it over anywhere or checking it's hashes.
 380
 381    # checkpoint_path has not been passed, we download a known model and derive the correct
 382    # URL from the model_type. If the model_type is invalid pooch will raise an error.
 383    _provided_checkpoint_path = checkpoint_path is not None
 384    if checkpoint_path is None:
 385        checkpoint_path, model_hash, downloaded_decoder_path = _download_sam_model(model_type, progress_bar_factory)
 386        if decoder_path is None:
 387            decoder_path = downloaded_decoder_path
 388
 389    # checkpoint_path has been passed, we use it instead of downloading a model.
 390    else:
 391        # Check if the file exists and raise an error otherwise.
 392        # We can't check any hashes here, and we don't check if the file is actually a valid weight file.
 393        # (If it isn't the model creation will fail below.)
 394        if not os.path.exists(checkpoint_path):
 395            raise ValueError(f"Checkpoint at '{checkpoint_path}' could not be found.")
 396        model_hash = _compute_hash(checkpoint_path)
 397
 398    if decoder_path is not None and not os.path.exists(decoder_path):
 399        raise ValueError(f"Decoder checkpoint at '{decoder_path}' could not be found.")
 400
 401    # Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped
 402    # before calling sam_model_registry.
 403    abbreviated_model_type = model_type[:5]
 404    if abbreviated_model_type not in _MODEL_TYPES:
 405        raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}")
 406    if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT:
 407        raise RuntimeError(
 408            "'mobile_sam' is required for the vit-tiny. "
 409            "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'"
 410        )
 411
 412    state, model_state = _load_checkpoint(checkpoint_path)
 413
 414    if _provided_checkpoint_path:
 415        # To get the model weights, we prioritize having the correct 'checkpoint_path' over 'model_type'
 416        # It is done to avoid strange parameter mismatch issues while incompatible model type and weights combination.
 417        from micro_sam.models.build_sam import _validate_model_type
 418        _provided_model_type = _validate_model_type(model_state)
 419
 420        # Verify whether the 'abbreviated_model_type' matches the '_provided_model_type'
 421        # Otherwise replace 'abbreviated_model_type' with the later.
 422        if abbreviated_model_type != _provided_model_type:
 423            # Printing the message below to avoid any filtering of warnings on user's end.
 424            print(
 425                f"CRITICAL WARNING: The chosen 'model_type' is '{abbreviated_model_type}', "
 426                f"however the model checkpoint provided correspond to '{_provided_model_type}', which does not match. "
 427                f"We internally switch the model type to the expected value, i.e. '{_provided_model_type}'. "
 428                "However, please avoid mismatching combination of 'model_type' and 'checkpoint_path' in future."
 429            )
 430
 431        # Replace the extracted 'abbreviated_model_type' subjected to the model weights.
 432        abbreviated_model_type = _provided_model_type
 433
 434    # Whether to update parameters necessary to initialize the model
 435    if model_kwargs:  # Checks whether model_kwargs have been provided or not
 436        if abbreviated_model_type == "vit_t":
 437            raise ValueError("'micro-sam' does not support changing the model parameters for 'mobile-sam'.")
 438        sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs)
 439
 440    else:
 441        sam = sam_model_registry[abbreviated_model_type]()
 442
 443    # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
 444    # Overwrites the SAM model by freezing the backbone and allow PEFT.
 445    if peft_kwargs and isinstance(peft_kwargs, dict):
 446        # NOTE: We bump out 'quantize' parameter, if found, as we do not quantize in inference.
 447        peft_kwargs.pop("quantize", None)
 448
 449        if abbreviated_model_type == "vit_t":
 450            raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")
 451
 452        sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam
 453    # In case the model checkpoints have some issues when it is initialized with different parameters than default.
 454    if flexible_load_checkpoint:
 455        sam = _handle_checkpoint_loading(sam, model_state)
 456    else:
 457        sam.load_state_dict(model_state)
 458    sam.to(device=device)
 459
 460    predictor = SamPredictor(sam)
 461    predictor.model_type = abbreviated_model_type
 462    predictor._hash = model_hash
 463    predictor.model_name = model_type
 464    predictor.checkpoint_path = checkpoint_path
 465
 466    # Add the decoder to the state if we have one and if the state is returned.
 467    if decoder_path is not None and return_state:
 468        state["decoder_state"] = torch.load(decoder_path, map_location=device, weights_only=False)
 469
 470    if return_sam and return_state:
 471        return predictor, sam, state
 472    if return_sam:
 473        return predictor, sam
 474    if return_state:
 475        return predictor, state
 476    return predictor
 477
 478
 479def _handle_checkpoint_loading(sam, model_state):
 480    # Whether to handle the mismatch issues in a bit more elegant way.
 481    # eg. while training for multi-class semantic segmentation in the mask encoder,
 482    # parameters are updated - leading to "size mismatch" errors
 483
 484    new_state_dict = {}  # for loading matching parameters
 485    mismatched_layers = []  # for tracking mismatching parameters
 486
 487    reference_state = sam.state_dict()
 488
 489    for k, v in model_state.items():
 490        if k in reference_state:  # This is done to get rid of unwanted layers from pretrained SAM.
 491            if reference_state[k].size() == v.size():
 492                new_state_dict[k] = v
 493            else:
 494                mismatched_layers.append(k)
 495
 496    reference_state.update(new_state_dict)
 497
 498    if len(mismatched_layers) > 0:
 499        warnings.warn(f"The layers with size mismatch: {mismatched_layers}")
 500
 501    for mlayer in mismatched_layers:
 502        if 'weight' in mlayer:
 503            torch.nn.init.kaiming_uniform_(reference_state[mlayer])
 504        elif 'bias' in mlayer:
 505            reference_state[mlayer].zero_()
 506
 507    sam.load_state_dict(reference_state)
 508
 509    return sam
 510
 511
 512def export_custom_sam_model(
 513    checkpoint_path: Union[str, os.PathLike],
 514    model_type: str,
 515    save_path: Union[str, os.PathLike],
 516    with_segmentation_decoder: bool = False,
 517    prefix: str = "sam.",
 518) -> None:
 519    """Export a finetuned Segment Anything Model to the standard model format.
 520
 521    The exported model can be used by the interactive annotation tools in `micro_sam.annotator`.
 522
 523    Args:
 524        checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
 525        model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
 526        save_path: Where to save the exported model.
 527        with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well.
 528            If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'.
 529        prefix: The prefix to remove from the model parameter keys.
 530    """
 531    state, model_state = _load_checkpoint(checkpoint_path=checkpoint_path)
 532    model_state = OrderedDict([(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()])
 533
 534    # Store the 'decoder_state' as well, if desired.
 535    if with_segmentation_decoder:
 536        if "decoder_state" not in state:
 537            raise RuntimeError(f"'decoder_state' is not found in the model at '{checkpoint_path}'.")
 538        decoder_state = state["decoder_state"]
 539        save_state = {"model_state": model_state, "decoder_state": decoder_state}
 540    else:
 541        save_state = model_state
 542
 543    torch.save(save_state, save_path)
 544
 545
 546def export_custom_qlora_model(
 547    checkpoint_path: Optional[Union[str, os.PathLike]],
 548    finetuned_path: Union[str, os.PathLike],
 549    model_type: str,
 550    save_path: Union[str, os.PathLike],
 551) -> None:
 552    """Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format.
 553
 554    The exported model can be used with the LoRA backbone by passing the relevant `peft_kwargs` to `get_sam_model`.
 555
 556    Args:
 557        checkpoint_path: The path to the base foundation model from which the new model has been finetuned.
 558        finetuned_path: The path to the new finetuned model, using QLoRA.
 559        model_type: The Segment Anything Model type corresponding to the checkpoint.
 560        save_path: Where to save the exported model.
 561    """
 562    # Step 1: Get the base SAM model: used to start finetuning from.
 563    _, sam = get_sam_model(
 564        model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True,
 565    )
 566
 567    # Step 2: Load the QLoRA-style finetuned model.
 568    ft_state, ft_model_state = _load_checkpoint(finetuned_path)
 569
 570    # Step 3: Identify LoRA layers from QLoRA model.
 571    # - differentiate between LoRA applied to the attention matrices and LoRA applied to the MLP layers.
 572    # - then copy the LoRA layers from the QLoRA model to the new state dict
 573    updated_model_state = {}
 574
 575    modified_attn_layers = set()
 576    modified_mlp_layers = set()
 577
 578    for k, v in ft_model_state.items():
 579        if "blocks." in k:
 580            layer_id = int(k.split("blocks.")[1].split(".")[0])
 581        if k.find("qkv.w_a_linear") != -1 or k.find("qkv.w_b_linear") != -1:
 582            modified_attn_layers.add(layer_id)
 583            updated_model_state[k] = v
 584        if k.find("mlp.w_a_linear") != -1 or k.find("mlp.w_b_linear") != -1:
 585            modified_mlp_layers.add(layer_id)
 586            updated_model_state[k] = v
 587
 588    # Step 4: Next, we get all the remaining parameters from the base SAM model.
 589    for k, v in sam.state_dict().items():
 590        if "blocks." in k:
 591            layer_id = int(k.split("blocks.")[1].split(".")[0])
 592        if k.find("attn.qkv.") != -1:
 593            if layer_id in modified_attn_layers:  # We have LoRA in QKV layers, so we need to modify the key
 594                k = k.replace("qkv", "qkv.qkv_proj")
 595        elif k.find("mlp") != -1 and k.find("image_encoder") != -1:
 596            if layer_id in modified_mlp_layers:  # We have LoRA in MLP layers, so we need to modify the key
 597                k = k.replace("mlp.", "mlp.mlp_layer.")
 598        updated_model_state[k] = v
 599
 600    # Step 5: Finally, we replace the old model state with the new one (to retain other relevant stuff)
 601    ft_state['model_state'] = updated_model_state
 602
 603    # Step 6: Store the new "state" to "save_path"
 604    torch.save(ft_state, save_path)
 605
 606
 607def get_model_names() -> Iterable:
 608    model_registry = models()
 609    model_names = model_registry.registry.keys()
 610    return model_names
 611
 612
 613#
 614# Functionality for precomputing image embeddings.
 615#
 616
 617
 618def _to_image(image):
 619    input_ = image
 620    ndim = input_.ndim
 621    n_channels = 1 if ndim == 2 else input_.shape[-1]
 622
 623    # Map the input to three channels.
 624    if ndim == 2:  # Grayscale image -> replicate channels.
 625        input_ = np.concatenate([input_[..., None]] * 3, axis=-1)
 626    elif ndim == 3 and n_channels == 1:  # Grayscale image -> replicate channels.
 627        input_ = np.concatenate([input_] * 3, axis=-1)
 628    elif ndim == 3 and n_channels == 2:  # Two channels -> add a zero channel.
 629        zero_channel = np.zeros(input_.shape[:2] + (1,), dtype=input_.dtype)
 630        input_ = np.concatenate([input_, zero_channel], axis=-1)
 631    elif input_.ndim == 3 and n_channels == 3:  # RGB input -> do nothing.
 632        pass
 633    elif input_.ndim == 3 and n_channels > 3:  # More than three channels -> select first three.
 634        warnings.warn(f"You provided an input with {n_channels} channels. Only the first three will be used.")
 635        input_ = input_[..., :3]
 636    else:
 637        raise ValueError(
 638            f"Invalid input dimensionality {ndim}. Expect either a 2D input (=grayscale image) "
 639            "or a 3D input (= image with channels)."
 640        )
 641    assert input_.ndim == 3 and input_.shape[-1] == 3
 642
 643    # Normalize the input per channel and bring it to uint8.
 644    input_ = input_.astype("float32")
 645    input_ -= input_.min(axis=(0, 1))[None, None]
 646    input_ /= (input_.max(axis=(0, 1))[None, None] + 1e-7)
 647    input_ = (input_ * 255).astype("uint8")
 648
 649    # Explicitly return a numpy array for compatibility with torchvision
 650    # because the input_ array could be something like dask array.
 651    return np.array(input_)
 652
 653
 654@torch.no_grad
 655def _compute_embeddings_batched(predictor, batched_images):
 656    predictor.reset_image()
 657    batched_tensors, original_sizes, input_sizes = [], [], []
 658
 659    # Apply proeprocessing to all images in the batch, and then stack them.
 660    # Note: after the transformation the images are all of the same size,
 661    # so they can be stacked and processed as a batch, even if the input images were of different size.
 662    for image in batched_images:
 663        tensor = predictor.transform.apply_image(image)
 664        tensor = torch.as_tensor(tensor, device=predictor.device)
 665        tensor = tensor.permute(2, 0, 1).contiguous()[None, :, :, :]
 666
 667        original_sizes.append(image.shape[:2])
 668        input_sizes.append(tensor.shape[-2:])
 669
 670        tensor = predictor.model.preprocess(tensor)
 671        batched_tensors.append(tensor)
 672
 673    batched_tensors = torch.cat(batched_tensors)
 674    features = predictor.model.image_encoder(batched_tensors)
 675
 676    predictor.original_size = original_sizes[-1]
 677    predictor.input_size = input_sizes[-1]
 678    predictor.features = features[-1]
 679    predictor.is_image_set = True
 680
 681    return features, original_sizes, input_sizes
 682
 683
 684# Wrapper of zarr.create dataset to support zarr v2 and zarr v3.
 685def _create_dataset_with_data(group, name, data, chunks=None):
 686    zarr_major_version = int(zarr.__version__.split(".")[0])
 687    if chunks is None:
 688        chunks = data.shape
 689    if zarr_major_version == 2:
 690        ds = group.create_dataset(name, data=data, shape=data.shape, chunks=chunks)
 691    elif zarr_major_version == 3:
 692        ds = group.create_array(name, shape=data.shape, chunks=chunks, dtype=data.dtype)
 693        ds[:] = data
 694    else:
 695        raise RuntimeError(f"Unsupported zarr version: {zarr_major_version}")
 696    return ds
 697
 698
 699def _create_dataset_without_data(group, name, shape, dtype, chunks):
 700    zarr_major_version = int(zarr.__version__.split(".")[0])
 701    if zarr_major_version == 2:
 702        ds = group.create_dataset(name, shape=shape, dtype=dtype, chunks=chunks)
 703    elif zarr_major_version == 3:
 704        ds = group.create_array(name, shape=shape, chunks=chunks, dtype=dtype)
 705    else:
 706        raise RuntimeError(f"Unsupported zarr version: {zarr_major_version}")
 707    return ds
 708
 709
 710def _write_batch(features, tile_ids, batched_embeddings, original_sizes, input_sizes, slices=None, n_slices=None):
 711
 712    # Pre-create / pre-fetch the datasets if we have slices.
 713    # (Dataset creation is not thread-safe)
 714    if slices is not None:
 715        datasets = {}
 716        for tile_id, tile_embeddings, original_size, input_size in zip(
 717            tile_ids, batched_embeddings, original_sizes, input_sizes
 718        ):
 719            ds_name = str(tile_id)
 720            if ds_name in datasets:
 721                continue
 722            if ds_name in features:
 723                datasets[ds_name] = features[ds_name]
 724                continue
 725            shape = (n_slices, 1) + tile_embeddings.shape
 726            chunks = (1, 1) + tile_embeddings.shape
 727            ds = _create_dataset_without_data(features, ds_name, shape=shape, dtype="float32", chunks=chunks)
 728            ds.attrs["original_size"] = original_size
 729            ds.attrs["input_size"] = input_size
 730            datasets[ds_name] = ds
 731
 732    def _write_embed(i):
 733        ds_name = str(tile_ids[i])
 734        tile_embeddings = batched_embeddings[i].unsqueeze(0)
 735        if slices is None:
 736            ds = _create_dataset_with_data(features, ds_name, data=tile_embeddings.cpu().numpy())
 737            ds.attrs["original_size"] = original_sizes[i]
 738            ds.attrs["input_size"] = input_sizes[i]
 739        elif ds_name in features:
 740            ds = datasets[ds_name]
 741            z = slices[i]
 742            ds[z] = tile_embeddings.cpu().numpy()
 743
 744    n_tiles = len(tile_ids)
 745    n_workers = min(mp.cpu_count(), n_tiles)
 746    with futures.ThreadPoolExecutor(n_workers) as tp:
 747        list(tp.map(_write_embed, range(n_tiles)))
 748
 749
 750def _get_tiles_in_mask(mask, tiling, halo, z=None):
 751    def _check_mask(tile_id):
 752        tile = tiling.get_block_with_halo(tile_id, list(halo))
 753        outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outer_block.begin, tile.outer_block.end))
 754        if z is not None:
 755            outer_tile = (z,) + outer_tile
 756        tile_mask = mask[outer_tile].astype("bool")
 757        return None if tile_mask.sum() == 0 else tile_id
 758
 759    n_threads = mp.cpu_count()
 760    with futures.ThreadPoolExecutor(n_threads) as tp:
 761        tiles_in_mask = tp.map(_check_mask, range(tiling.number_of_blocks))
 762    return sorted([tile_id for tile_id in tiles_in_mask if tile_id is not None])
 763
 764
 765def _compute_tiled_features_2d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask):
 766    tiling = Blocking([0, 0], input_.shape[:2], tile_shape)
 767    n_tiles = tiling.number_of_blocks
 768
 769    features = f.require_group("features")
 770    features.attrs["shape"] = input_.shape[:2]
 771    features.attrs["tile_shape"] = tile_shape
 772    features.attrs["halo"] = halo
 773
 774    n_batches = int(np.ceil(n_tiles / batch_size))
 775    if mask is None:
 776        tile_ids_for_batches = [
 777            range(batch_id * batch_size, min((batch_id + 1) * batch_size, n_tiles))
 778            for batch_id in range(n_batches)
 779        ]
 780        pbar_init(n_tiles, "Compute Image Embeddings 2D tiled")
 781    else:
 782        tiles_in_mask = _get_tiles_in_mask(mask, tiling, halo)
 783        pbar_init(len(tiles_in_mask), "Compute Image Embeddings 2D tiled with mask")
 784        tile_ids_for_batches = np.array_split(tiles_in_mask, n_batches)
 785        assert len(tile_ids_for_batches) == n_batches
 786
 787    for tile_ids in tile_ids_for_batches:
 788        batched_images = []
 789        for tile_id in tile_ids:
 790            tile = tiling.get_block_with_halo(tile_id, list(halo))
 791            outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outer_block.begin, tile.outer_block.end))
 792            tile_input = _to_image(input_[outer_tile])
 793            batched_images.append(tile_input)
 794
 795        batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images)
 796        _write_batch(features, tile_ids, batched_embeddings, original_sizes, input_sizes)
 797        pbar_update(len(tile_ids))
 798
 799    _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size=None, original_size=None)
 800    if mask is not None:
 801        features.attrs["tiles_in_mask"] = tiles_in_mask
 802
 803    return features
 804
 805
 806class _BatchProvider:
 807    def __init__(self, n_slices, n_tiles_per_plane, tiles_in_mask_per_slice, batch_size):
 808        if tiles_in_mask_per_slice is None:
 809            self.n_tiles_total = n_slices * n_tiles_per_plane
 810        else:
 811            self.n_tiles_total = sum(len(val) for val in tiles_in_mask_per_slice.values())
 812
 813        self.n_batches = int(np.ceil(self.n_tiles_total / batch_size))
 814        self.n_slices = n_slices
 815        self.n_tiles_per_plane = n_tiles_per_plane
 816        self.tiles_in_mask_per_slice = tiles_in_mask_per_slice
 817        self.batch_size = batch_size
 818
 819        # Iter variables.
 820        self.batch_id = 0
 821        self.z = 0
 822        self.tile_id = 0
 823
 824    def __iter__(self):
 825        return self
 826
 827    def __next__(self):
 828        if self.batch_id >= self.n_batches:
 829            raise StopIteration
 830
 831        z_list = list(range(self.n_tiles_per_plane))
 832        z_tiles = z_list if self.tiles_in_mask_per_slice is None else self.tiles_in_mask_per_slice[self.z]
 833
 834        slices, tile_ids = [], []
 835        this_batch_size = 0
 836        while this_batch_size < self.batch_size:
 837            if self.tile_id == len(z_tiles):
 838                self.z += 1
 839                self.tile_id = 0
 840                if self.z >= self.n_slices:
 841                    break
 842                z_tiles = z_list if self.tiles_in_mask_per_slice is None else self.tiles_in_mask_per_slice[self.z]
 843                continue
 844
 845            slices.append(self.z), tile_ids.append(z_tiles[self.tile_id])
 846            self.tile_id += 1
 847            this_batch_size += 1
 848
 849        self.batch_id += 1
 850        return slices, tile_ids
 851
 852
 853def _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask):
 854    assert input_.ndim == 3
 855
 856    shape = input_.shape[1:]
 857    tiling = Blocking([0, 0], shape, tile_shape)
 858    features = f.require_group("features")
 859    features.attrs["shape"] = shape
 860    features.attrs["tile_shape"] = tile_shape
 861    features.attrs["halo"] = halo
 862
 863    n_tiles_per_plane = tiling.number_of_blocks
 864    n_slices = input_.shape[0]
 865
 866    msg = "Compute Image Embeddings 3D tiled"
 867    if mask is None:
 868        n_tiles_total = n_slices * n_tiles_per_plane
 869        tiles_in_mask_per_slice = None
 870    else:
 871        tiles_in_mask_per_slice = {}
 872        for z in range(n_slices):
 873            tiles_in_mask_per_slice[z] = _get_tiles_in_mask(mask, tiling, halo, z=z)
 874        n_tiles_total = sum(len(val) for val in tiles_in_mask_per_slice.values())
 875        msg += " masked"
 876    pbar_init(n_tiles_total, msg)
 877
 878    batch_provider = _BatchProvider(n_slices, n_tiles_per_plane, tiles_in_mask_per_slice, batch_size)
 879    for slices, tile_ids in batch_provider:
 880        batched_images = []
 881        for z, tile_id in zip(slices, tile_ids):
 882            tile = tiling.get_block_with_halo(tile_id, list(halo))
 883            outer_tile = (z,) + tuple(
 884                slice(beg, end) for beg, end in zip(tile.outer_block.begin, tile.outer_block.end)
 885            )
 886            tile_input = _to_image(input_[outer_tile])
 887            batched_images.append(tile_input)
 888
 889        batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images)
 890        _write_batch(
 891            features, tile_ids, batched_embeddings, original_sizes, input_sizes, slices=slices, n_slices=n_slices
 892        )
 893        pbar_update(len(tile_ids))
 894
 895    if mask is not None:
 896        features.attrs["tiles_in_mask"] = {str(z): per_slice for z, per_slice in tiles_in_mask_per_slice.items()}
 897
 898    _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size=None, original_size=None)
 899    return features
 900
 901
 902def _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update):
 903    # Check if the embeddings are already cached.
 904    if save_path is not None and "input_size" in f.attrs:
 905        # In this case we load the embeddings.
 906        features = f["features"][:]
 907        original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
 908        image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size}
 909        # Also set the embeddings.
 910        set_precomputed(predictor, image_embeddings)
 911        return image_embeddings
 912
 913    pbar_init(1, "Compute Image Embeddings 2D")
 914    # Otherwise we have to compute the embeddings.
 915    predictor.reset_image()
 916    predictor.set_image(_to_image(input_))
 917    features = predictor.get_image_embedding().cpu().numpy()
 918    original_size = predictor.original_size
 919    input_size = predictor.input_size
 920    pbar_update(1)
 921
 922    # Save the embeddings if we have a save_path.
 923    if save_path is not None:
 924        _create_dataset_with_data(f, "features", data=features)
 925        _write_embedding_signature(
 926            f, input_, predictor, tile_shape=None, halo=None, input_size=input_size, original_size=original_size,
 927        )
 928
 929    image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size}
 930    return image_embeddings
 931
 932
 933def _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask):
 934    # Check if the features are already computed.
 935    if "input_size" in f.attrs:
 936        features = f["features"]
 937        original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
 938        image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size}
 939        return image_embeddings
 940
 941    # Otherwise compute them. Note: saving happens automatically because we
 942    # always write the features to zarr. If no save path is given we use an in-memory zarr.
 943    features = _compute_tiled_features_2d(
 944        predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask=mask
 945    )
 946    image_embeddings = {"features": features, "input_size": None, "original_size": None}
 947    return image_embeddings
 948
 949
 950def _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update, batch_size):
 951    # Check if the embeddings are already fully cached.
 952    if save_path is not None and "input_size" in f.attrs:
 953        # In this case we load the embeddings.
 954        features = f["features"] if lazy_loading else f["features"][:]
 955        original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
 956        image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size}
 957        return image_embeddings
 958
 959    # Otherwise we have to compute the embeddings.
 960
 961    # First check if we have a save path or not and set things up accordingly.
 962    if save_path is None:
 963        features = []
 964        save_features = False
 965        partial_features = False
 966    else:
 967        save_features = True
 968        embed_shape = (1, 256, 64, 64)
 969        shape = (input_.shape[0],) + embed_shape
 970        chunks = (1,) + embed_shape
 971        if "features" in f:
 972            partial_features = True
 973            features = f["features"]
 974            if features.shape != shape or features.chunks != chunks:
 975                raise RuntimeError("Invalid partial features")
 976        else:
 977            partial_features = False
 978            features = _create_dataset_without_data(f, "features", shape=shape, chunks=chunks, dtype="float32")
 979
 980    # Initialize the pbar and batches.
 981    n_slices = input_.shape[0]
 982    pbar_init(n_slices, "Compute Image Embeddings 3D")
 983    n_batches = int(np.ceil(n_slices / batch_size))
 984
 985    for batch_id in range(n_batches):
 986        z_start = batch_id * batch_size
 987        z_stop = min(z_start + batch_size, n_slices)
 988
 989        batched_images, batched_z = [], []
 990        for z in range(z_start, z_stop):
 991            # Skip feature computation in case of partial features in non-zero slice.
 992            if partial_features and np.count_nonzero(features[z]) != 0:
 993                continue
 994            tile_input = _to_image(input_[z])
 995            batched_images.append(tile_input)
 996            batched_z.append(z)
 997
 998        batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images)
 999
1000        for z, embedding in zip(batched_z, batched_embeddings):
1001            embedding = embedding.unsqueeze(0)
1002            if save_features:
1003                features[z] = embedding.cpu().numpy()
1004            else:
1005                features.append(embedding.unsqueeze(0))
1006            pbar_update(1)
1007
1008    if save_features:
1009        _write_embedding_signature(
1010            f, input_, predictor, tile_shape=None, halo=None,
1011            input_size=input_sizes[-1], original_size=original_sizes[-1],
1012        )
1013    else:
1014        # Concatenate across the z axis.
1015        features = torch.cat(features).cpu().numpy()
1016
1017    image_embeddings = {"features": features, "input_size": input_sizes[-1], "original_size": original_sizes[-1]}
1018    return image_embeddings
1019
1020
def _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask):
    """Compute or load the tiled embeddings for a 3D input.

    Args:
        input_: The 3D input data.
        predictor: The Segment Anything predictor.
        tile_shape: The shape of the tiles.
        halo: The overlap of the tiles.
        f: The zarr container for the embeddings.
        pbar_init: Callback to initialize the progress bar.
        pbar_update: Callback to update the progress bar.
        batch_size: The batch size for computing the embeddings over tiles.
        mask: Optional mask to restrict the computation to tiles with mask foreground.

    Returns:
        The image embeddings. The features are given by the group that holds the tile embeddings.
    """
    if "input_size" not in f.attrs:
        # The features are not computed yet. They are always written to zarr in the computation,
        # so no extra saving step is needed. (An in-memory zarr is used if no save path is given.)
        features = _compute_tiled_features_3d(
            predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask
        )
        return {"features": features, "input_size": None, "original_size": None}

    # The features are already computed.
    return {
        "features": f["features"],
        "input_size": f.attrs["input_size"],
        "original_size": f.attrs["original_size"],
    }
1036
1037
def _compute_data_signature(input_):
    """Compute the signature of the input data, given by the SHA1 hex digest of its raw bytes."""
    raw_bytes = np.asarray(input_).tobytes()
    return hashlib.sha1(raw_bytes).hexdigest()
1041
1042
def _get_embedding_signature(input_, predictor, tile_shape, halo, data_signature=None):
    """Create all metadata that is stored along with the embeddings.

    Args:
        input_: The input data. It is only used to compute the data signature if that is not given.
        predictor: The Segment Anything predictor. Its model type, model name and
            (if available) model hash are part of the signature.
        tile_shape: The tile shape, or None for embeddings without tiling.
        halo: The halo, or None for embeddings without tiling.
        data_signature: Optional precomputed signature of the input data.

    Returns:
        The dictionary with the signature.
    """
    def _to_list(value):
        return None if value is None else list(value)

    if data_signature is None:
        data_signature = _compute_data_signature(input_)

    return {
        "data_signature": data_signature,
        "tile_shape": _to_list(tile_shape),
        "halo": _to_list(halo),
        "model_type": predictor.model_type,
        "model_name": predictor.model_name,
        "micro_sam_version": __version__,
        "model_hash": getattr(predictor, "_hash", None),
    }
1058
1059
def _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size, original_size):
    """Write the embedding signature, extended by the input size and original size, to the attributes of 'f'.

    Note: the input size and original size are different if embeddings are tiled or not.
    That's why they are not included in the main signature that is being checked
    (see '_get_embedding_signature'), but are just added for serialization here.
    """
    signature = _get_embedding_signature(input_, predictor, tile_shape, halo)
    signature["input_size"] = input_size
    signature["original_size"] = original_size
    for name in signature:
        f.attrs[name] = signature[name]
1068
1069
def _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo):
    """Check that the signature of saved embeddings matches the data, model and tiling parameters.

    Args:
        input_: The input data.
        predictor: The Segment Anything predictor.
        f: The zarr container with the saved embeddings.
        save_path: The path of the zarr container. It is only used for the messages.
        tile_shape: The tile shape.
        halo: The halo.

    Raises:
        RuntimeError: If a key of the signature, except for the keys that only lead to a warning,
            is missing in the saved embeddings or does not match.
    """
    # We may have an empty zarr file that was already created to save the embeddings in.
    # In this case the embeddings will be computed and there is nothing to check.
    if "input_size" not in f.attrs:
        return

    # These keys were recently added. A mismatch for them only leads to a warning in order to
    # not invalidate previous embedding files. (For the version we probably also don't want
    # to fail in the future, since it should not invalidate the embeddings.)
    keys_with_warning = ("micro_sam_version", "model_hash", "model_name")

    signature = _get_embedding_signature(input_, predictor, tile_shape, halo)
    for key, val in signature.items():
        has_mismatch = key not in f.attrs or f.attrs[key] != val
        if not has_mismatch:
            continue

        if key not in keys_with_warning:
            raise RuntimeError(
                f"Embeddings file {save_path} is invalid due to mismatch in {key}: "
                f"{f.attrs.get(key)} != {val}. Please recompute embeddings in a new file."
            )

        warnings.warn(
            f"The signature for {key} in embeddings file {save_path} has a mismatch: "
            f"{f.attrs.get(key)} != {val}. This key was recently added, so your embeddings are likely correct. "
            "But please recompute them if model predictions don't look as expected."
        )
1095
1096
def handle_pbar(verbose, pbar_init, pbar_update):
    """@private

    Helper function for optional external progress bars.

    Args:
        verbose: Whether to show progress. If not, all returned callbacks do nothing.
        pbar_init: Callback to initialize an external progress bar, or None.
        pbar_update: Callback to update an external progress bar, or None.

    Returns:
        The internal progress bar. It is None unless an internal progress bar was created.
        The callback to initialize the progress bar.
        The callback to update the progress bar.
        The callback to close the progress bar.
    """

    # Noop to provide dummy functions.
    def noop(*args):
        pass

    # We are not verbose: nothing is done in any of the callbacks.
    if not verbose:
        return None, noop, noop, noop

    # We have an external progress bar: its callbacks are used, and we don't have to close it.
    if pbar_init is not None:
        assert pbar_update is not None
        return None, pbar_init, pbar_update, noop

    # We are verbose and don't have an external progress bar:
    # create our own progress bar and the callbacks for it.
    assert pbar_update is None  # avoid inconsistent state of callbacks
    pbar = tqdm()

    def _init(total, description):
        pbar.total = total
        pbar.set_description(description)

    def _update(update):
        pbar.update(update)

    def _close():
        pbar.close()

    return pbar, _init, _update, _close
1131
1132
def precompute_image_embeddings(
    predictor: SamPredictor,
    input_: np.ndarray,
    save_path: Optional[Union[str, os.PathLike]] = None,
    lazy_loading: bool = False,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    batch_size: int = 1,
    mask: Optional[np.typing.ArrayLike] = None,
    pbar_init: Optional[Callable] = None,
    pbar_update: Optional[Callable] = None,
) -> ImageEmbeddings:
    """Compute the image embeddings (output of the encoder) for the input.

    If 'save_path' is given the embeddings will be loaded/saved in a zarr container.

    Args:
        predictor: The Segment Anything predictor.
        input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
        save_path: Path to save the embeddings in a zarr container.
            By default, set to 'None', i.e. the computed embeddings will not be stored locally.
        lazy_loading: Whether to load all embeddings into memory or return an
            object to load them on demand when required. This only has an effect if 'save_path' is given
            and if the input is 3 dimensional. By default, set to 'False'.
        ndim: The dimensionality of the data. If not given will be deduced from the input data.
            By default, set to 'None', i.e. will be computed from the provided `input_`.
        tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. Must be given if 'tile_shape' is given.
            By default prediction is run without tiling.
        verbose: Whether to be verbose in the computation. By default, set to 'True'.
        batch_size: The batch size for precomputing image embeddings over tiles (or planes). By default, set to '1'.
        mask: An optional mask to restrict the computation to the areas of interest.
            The mask will be used within tiled embedding computation and tiles that don't contain any foreground
            in the mask will be excluded from the computation. It does not have any effect for non-tiled embeddings.
        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
            Can be used together with pbar_update to handle napari progress bar in other thread.
            To enable using this function within a threadworker.
        pbar_update: Callback to update an external progress bar.

    Returns:
        The image embeddings.

    Raises:
        ValueError: If the dimensionality is not 2 or 3, or if 'tile_shape' is given without 'halo'.
        RuntimeError: If 'save_path' contains embeddings that don't match the input data or parameters.
    """
    ndim = input_.ndim if ndim is None else ndim

    # Validate the arguments first, so that no zarr container is created for an invalid call.
    if ndim not in (2, 3):
        raise ValueError(f"Invalid dimensionality {ndim}, expect 2 or 3 dim data.")
    if tile_shape is not None and halo is None:
        raise ValueError("The halo must be given if a tile_shape is given for tiled prediction.")

    # Handle the embedding save_path.
    # We don't have a save path, open in memory zarr file to hold tiled embeddings.
    if save_path is None:
        f = zarr.group()
    else:
        already_exists = os.path.exists(save_path)
        # The container is opened in append mode: it is created if it does not exist yet.
        f = zarr.open(save_path, mode="a")
        # If the container existed already, the embeddings will be loaded from it. In this case,
        # check that the saved embeddings in there match the parameters of the function call.
        if already_exists:
            _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo)

    _, pbar_init, pbar_update, pbar_close = handle_pbar(verbose, pbar_init, pbar_update)

    # Make sure that the progress bar is also closed if the computation fails.
    try:
        if ndim == 2 and tile_shape is None:
            embeddings = _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update)
        elif ndim == 2:
            embeddings = _compute_tiled_2d(
                input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask=mask
            )
        elif tile_shape is None:
            embeddings = _compute_3d(
                input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update, batch_size
            )
        else:
            embeddings = _compute_tiled_3d(
                input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask=mask
            )
    finally:
        pbar_close()

    return embeddings
1213
1214
def set_precomputed(
    predictor: SamPredictor, image_embeddings: ImageEmbeddings, i: Optional[int] = None, tile_id: Optional[int] = None,
) -> SamPredictor:
    """Set the precomputed image embeddings for a predictor.

    Args:
        predictor: The Segment Anything predictor.
        image_embeddings: The precomputed image embeddings computed by `precompute_image_embeddings`.
        i: Index for the image data. Required if `image` has three spatial dimensions
            or a time dimension and two spatial dimensions.
        tile_id: Index for the tile. This is required if the embeddings are tiled.

    Returns:
        The predictor with set features.

    Raises:
        ValueError: If the index `i` is missing for 3D data or given for 2D data.
    """
    # For tiled embeddings, select the embeddings of the tile (with the sizes
    # from its attributes) and then proceed as for non-tiled embeddings.
    if tile_id is not None:
        tile_features = image_embeddings["features"][str(tile_id)]
        tile_attrs = tile_features.attrs
        return set_precomputed(
            predictor,
            {
                "features": tile_features,
                "input_size": tile_attrs["input_size"],
                "original_size": tile_attrs["original_size"],
            },
            i=i,
        )

    device = predictor.device
    features = image_embeddings["features"]
    assert features.ndim in (4, 5), f"{features.ndim}"

    # Features with five dimensions have a leading axis for the index.
    if features.ndim == 5 and i is None:
        raise ValueError("The data is 3D so an index i is needed.")
    if features.ndim == 4 and i is not None:
        raise ValueError("The data is 2D so an index is not needed.")

    # The features are either given as a tensor or as an array(-like) that is converted via numpy.
    if torch.is_tensor(features):
        selected = features if i is None else features[i]
    else:
        selected = torch.from_numpy(features[:] if i is None else features[i])

    predictor.features = selected.to(device)
    predictor.original_size = image_embeddings["original_size"]
    predictor.input_size = image_embeddings["input_size"]
    predictor.is_image_set = True

    return predictor
1259
1260
1261#
1262# Misc functionality
1263#
1264
1265
def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
    """Compute the intersection over union of two masks.

    Only the elements with value 1 are treated as foreground.

    Args:
        mask1: The first mask.
        mask2: The second mask.

    Returns:
        The intersection over union of the two masks.
    """
    foreground1 = mask1 == 1
    foreground2 = mask2 == 1
    n_overlap = float(np.logical_and(foreground1, foreground2).sum())
    n_union = float(np.logical_or(foreground1, foreground2).sum())
    # The small epsilon avoids a division by zero if both masks are empty.
    return n_overlap / (n_union + 1e-7)
1281
1282
def get_centers_and_bounding_boxes(
    segmentation: np.ndarray, mode: str = "v"
) -> Tuple[Dict[int, np.ndarray], Dict[int, tuple]]:
    """Returns the center coordinates and the bounding boxes of the foreground instances in the ground-truth.

    Args:
        segmentation: The segmentation.
        mode: Determines the functionality used for computing the centers.
            If 'v', the point of maximal distance to the object boundary is used as center.
            This center is guaranteed to lie inside the object, also for concave shapes.
            If 'p' the object's centroids computed by skimage are used.

    Returns:
        A dictionary that maps object ids to the corresponding centroid.
        A dictionary that maps object_ids to the corresponding bounding box.
    """
    assert mode in ["p", "v"], "Choose either 'p' for regionprops centroids or 'v' for distance-based centers"

    properties = regionprops(segmentation)
    bbox_coordinates = {prop.label: prop.bbox for prop in properties}

    if mode == "p":
        center_coordinates = {prop.label: prop.centroid for prop in properties}
    elif mode == "v":
        # The center is the point of maximal distance to the object boundary. In contrast to
        # the centroid it is guaranteed to lie inside the object, also for concave shapes.
        # (This replaces vigra.filters.eccentricityCenters.)
        # The boundaries and the distance transform are computed once for the whole segmentation
        # instead of once per object. The segmentation is padded, so that objects touching the
        # image border also get a boundary there, matching a per-object padded distance transform.
        ndim = segmentation.ndim
        padded = np.pad(segmentation, 1)
        boundaries = find_boundaries(padded, mode="inner")
        distances = distance_transform(boundaries == 0)

        center_coordinates = {}
        for prop in properties:
            bbox_start, bbox_stop = prop.bbox[:ndim], prop.bbox[ndim:]
            # Cut out the bounding box of the object from the distances (shifted by one due to the padding)
            # and find the maximum, considering only the pixels that belong to the object.
            window = tuple(slice(start + 1, stop + 1) for start, stop in zip(bbox_start, bbox_stop))
            object_distances = np.where(prop.image, distances[window], -1.0)
            local_center = np.unravel_index(int(np.argmax(object_distances)), object_distances.shape)
            # Map the center from the bounding box back to the coordinates of the segmentation.
            center_coordinates[prop.label] = tuple(
                int(coord + offset) for coord, offset in zip(local_center, bbox_start)
            )

    assert len(bbox_coordinates) == len(center_coordinates), f"{len(bbox_coordinates)}, {len(center_coordinates)}"
    return center_coordinates, bbox_coordinates
1332
1333
def load_image_data(path: str, key: Optional[str] = None, lazy_loading: bool = False) -> np.ndarray:
    """Load image data from a file.

    Files without an internal key are read with imageio. Otherwise the file is opened
    with `open_file` and the dataset stored under `key` is accessed.

    Args:
        path: The filepath to the image data.
        key: The internal filepath for complex data formats like hdf5.
            If None, the file is read as a regular image.
        lazy_loading: Whether to return the dataset without reading it into memory.
            Only supported for n5 and zarr data.

    Returns:
        The image data.
    """
    if key is None:
        return imageio.imread(path)

    with open_file(path, mode="r") as container:
        dataset = container[key]
        # Unless lazy loading is requested, the full dataset is read
        # while the container is still open.
        return dataset if lazy_loading else dataset[:]
1354
1355
def segmentation_to_one_hot(segmentation: np.ndarray, segmentation_ids: Optional[np.ndarray] = None) -> torch.Tensor:
    """Convert the segmentation to one-hot encoded masks.

    The label 0 does not get a channel in the output.

    Args:
        segmentation: The segmentation. Any integer dtype is accepted.
        segmentation_ids: Optional subset of ids that will be used to subsample the masks.
            Must not be empty and must not contain 0.
            By default, computes the number of ids from the provided `segmentation` masks.

    Returns:
        The one-hot encoded masks, a float tensor of shape NUM_OBJECTS x 1 x H x W.

    Raises:
        RuntimeError: If `segmentation_ids` is empty or contains 0.
    """
    masks = segmentation.copy()
    if segmentation_ids is None:
        n_ids = int(segmentation.max())

    else:
        # The list must neither be empty nor contain 'zero' as a value.
        if len(segmentation_ids) == 0 or 0 in segmentation_ids:
            raise RuntimeError("No foreground objects were found.")

        # the segmentation ids have to be sorted
        segmentation_ids = np.sort(segmentation_ids)

        # set the non selected objects to zero and relabel sequentially
        masks[~np.isin(masks, segmentation_ids)] = 0
        masks = relabel_sequential(masks)[0]
        n_ids = len(segmentation_ids)

    # `scatter_` only accepts int64 indices, so label arrays with any other integer dtype
    # (e.g. int32 or uint32) are converted first. int64 inputs are passed through unchanged.
    masks = torch.from_numpy(np.asarray(masks).astype("int64", copy=False))

    one_hot_shape = (n_ids + 1,) + tuple(masks.shape)
    # Add a leading dimension to scatter along, then drop the channel of label 0.
    one_hot = torch.zeros(one_hot_shape).scatter_(0, masks.unsqueeze(0), 1)[1:]

    # add the extra singleton dimension to get shape NUM_OBJECTS x 1 x H x W
    return one_hot.unsqueeze(1)
1396
1397
def get_block_shape(shape: Tuple[int]) -> Tuple[int]:
    """Get a suitable block shape for chunking a given shape.

    The block shape is used to determine chunk sizes for zarr arrays
    or block shapes for parallelization. It is capped at (1024, 1024) in 2D
    and at (32, 256, 256) in 3D and never exceeds the given shape.

    Args:
        shape: The image or volume shape.

    Returns:
        The block shape.

    Raises:
        ValueError: If the shape is neither 2 nor 3 dimensional.
    """
    max_block_shapes = {2: (1024, 1024), 3: (32, 256, 256)}

    ndim = len(shape)
    if ndim not in max_block_shapes:
        raise ValueError(f"Only 2 or 3 dimensional shapes are supported, got {ndim}D.")

    return tuple(min(limit, extent) for limit, extent in zip(max_block_shapes[ndim], shape))
1419
1420
def micro_sam_info() -> None:
    """Display μSAM information using a rich console.

    Prints the version, the documentation and publication links, the cache directory
    (which is created if it does not exist), the list of supported models, system
    information and the available compute device.

    The command line is parsed as well: with `--download models [NAME ...]` the listed
    models are downloaded, or all listed models if no name is given. Models can be
    specified by their plain name (e.g. `vit_b_lm`) or by the label shown in the model
    list (e.g. `vit_b_lm (v4)`).
    """
    import psutil
    import platform
    import argparse
    from rich import progress
    from rich.panel import Panel
    from rich.table import Table
    from rich.console import Console

    import torch
    import micro_sam

    parser = argparse.ArgumentParser(description="μSAM Information Booth")
    parser.add_argument(
        "--download", nargs="+", metavar=("WHAT", "KIND"),
        help="Downloads the pretrained SAM models."
        "'--download models' -> downloads all pretrained models; "
        "'--download models vit_b_lm vit_b_em_organelles' -> downloads the listed models; "
        "'--download model/models vit_b_lm' -> downloads a single specified model."
    )
    args = parser.parse_args()

    # Open up a new console.
    console = Console()

    # The header for information CLI.
    console.print("[bold #0072B2]μSAM Information Booth[/bold #0072B2]", justify="center")
    console.print("-" * console.width)

    # μSAM version panel.
    console.print(
        Panel(f"[bold #F0E442]Version:[/bold #F0E442] {micro_sam.__version__}", title="μSAM Version", expand=True)
    )

    # The documentation link panel.
    console.print(
        Panel(
            "[bold #CC79A7]Tools documented at:[/bold #CC79A7]\n"
            "https://computational-cell-analytics.github.io/micro-sam", title="Documentation"
        )
    )

    # The publication panel.
    console.print(
        Panel(
            "[bold #E69F00]Published in Nature Methods:[/bold #E69F00]\n"
            "https://www.nature.com/articles/s41592-024-02580-4", title="Publication"
        )
    )

    # Creating a cache directory when users' run `micro_sam.info`.
    cache_dir = get_cache_directory()
    os.makedirs(cache_dir, exist_ok=True)

    # The cache directory panel.
    console.print(
        Panel(f"[bold #009E73]Cache Directory:[/bold #009E73]\n{cache_dir}", title="Cache Directory")
    )

    # The version labels are derived from where a model is hosted.
    # `available_models` holds the labels shown in the overview, while `model_names` maps every
    # accepted spelling (plain model name or display label) to the name used in the model registry.
    # Keeping the two apart matters: the labels carry a version suffix like ' (v4)' that is
    # not part of the registry name.
    available_models = []
    model_names = {}
    for model_name, model_path in models().urls.items():  # We filter out the decoder models.
        if model_name.endswith("decoder"):
            continue

        labels = []
        if "https://dl.fbaipublicfiles.com/segment_anything/" in model_path:  # Valid v1 SAM models.
            labels.append(model_name)

        if "https://owncloud.gwdg.de/" in model_path:  # Our own hosted models (in their v1 mode quite often)
            # The MobileSAM model is listed without a version suffix.
            labels.append(model_name if model_name == "vit_t" else f"{model_name} (v1)")

        # Now for our models, the BioImageIO ModelZoo upload structure is such that:
        # '/1/files' corresponds to v2 models.
        # '/1.1/files' corresponds to v3 models.
        # '/1.2/files' corresponds to v4 models.
        if "/1/files" in model_path:
            labels.append(f"{model_name} (v2)")
        if "/1.1/files" in model_path:
            labels.append(f"{model_name} (v3)")
        if "/1.2/files" in model_path:
            labels.append(f"{model_name} (v4)")

        if labels:
            available_models.extend(labels)
            model_names[model_name] = model_name
            model_names.update({label: model_name for label in labels})

    model_list = "\n".join(available_models)

    # The available models panel.
    console.print(
        Panel(f"[bold #D55E00]Available Models:[/bold #D55E00]\n{model_list}", title="List of Supported Models")
    )

    # The system information table.
    total_memory = psutil.virtual_memory().total / (1024 ** 3)
    table = Table(title="System Information", show_header=True, header_style="bold #0072B2", expand=True)
    table.add_column("Property")
    table.add_column("Value", style="bold #56B4E9")
    table.add_row("System", platform.system())
    table.add_row("Node Name", platform.node())
    table.add_row("Release", platform.release())
    table.add_row("Version", platform.version())
    table.add_row("Machine", platform.machine())
    table.add_row("Processor", platform.processor())
    table.add_row("Platform", platform.platform())
    table.add_row("Total RAM (GB)", f"{total_memory:.2f}")
    console.print(table)

    # The device information and check for available GPU acceleration.
    default_device = _get_default_device()

    if default_device == "cuda":
        device_index = torch.cuda.current_device()
        device_name = torch.cuda.get_device_name(device_index)
        console.print(Panel(f"[bold #000000]CUDA Device:[/bold #000000] {device_name}", title="GPU Information"))
    elif default_device == "mps":
        console.print(Panel("[bold #000000]MPS Device is available[/bold #000000]", title="GPU Information"))
    else:
        console.print(
            Panel(
                "[bold #000000]No GPU acceleration device detected. Running on CPU.[/bold #000000]",
                title="Device Information"
            )
        )

    # The section allowing to download models.
    # NOTE: In future, can be extended to download sample data.
    if args.download:
        mode, *model_types = [t.lower() for t in args.download]

        if mode not in {"models", "model"}:
            console.print(f"[red]Unknown option for --download: {mode}[/]")
            return

        if model_types:
            incorrect_models = [m for m in model_types if m not in model_names]
            if incorrect_models:
                console.print(Panel("[red]Unknown model(s):[/] " + ", ".join(incorrect_models), title="Download Error"))
                return
            # Resolve the user input to the names used in the model registry.
            download_list = [model_names[m] for m in model_types]
        else:  # If user did not specify, we will download all models (each registry name once).
            download_list = list(dict.fromkeys(model_names.values()))

        with progress.Progress(
            progress.SpinnerColumn(),
            progress.TextColumn("[progress.description]{task.description}"),
            progress.BarColumn(bar_width=None),
            "[progress.percentage]{task.percentage:>3.0f}%",
            progress.TimeRemainingColumn(),
            console=console,
        ) as prog:
            task = prog.add_task("[green]Downloading μSAM models…", total=len(download_list))
            for model_type in download_list:
                prog.update(task, description=f"Downloading [cyan]{model_type}[/]…")
                _download_sam_model(model_type=model_type)
                prog.advance(task)

        console.print(Panel("[bold green] Downloads complete![/]", title="Finished"))
1580
1581
1582#
1583# Functionality to convert mask predictions to an instance segmentation via non-maximum suppression.
1584# The functionality for computing NMS for masks is taken from CellSeg1:
1585# https://github.com/Nuisal/cellseg1/blob/1c027c2568b83494d2662d1fbecec9aafb478ee0/mask_nms.py
1586#
1587
1588
1589def _overlap_matrix(boxes):
1590    x1 = torch.max(boxes[:, None, 0], boxes[:, 0])
1591    y1 = torch.max(boxes[:, None, 1], boxes[:, 1])
1592    x2 = torch.min(boxes[:, None, 2], boxes[:, 2])
1593    y2 = torch.min(boxes[:, None, 3], boxes[:, 3])
1594
1595    w = torch.clamp(x2 - x1, min=0)
1596    h = torch.clamp(y2 - y1, min=0)
1597
1598    return (w * h) > 0
1599
1600
def _calculate_ious_between_pred_masks(masks, boxes, diagonal_value=1):
    """Compute the pairwise intersection over union between masks.

    The IoU is only computed for mask pairs whose boxes overlap, all other entries are zero.

    Args:
        masks: The masks, a tensor of shape (N, H, W).
        boxes: The boxes of the masks, passed to `_overlap_matrix`.
        diagonal_value: The value written to the diagonal of the result.

    Returns:
        The symmetric IoU matrix of shape (N, N).
    """
    n_masks = masks.shape[0]
    iou_matrix = torch.zeros((n_masks, n_masks))
    candidates = _overlap_matrix(boxes)

    for idx in range(n_masks):
        # Only visit each pair once: restrict to partners with a larger index.
        partners = torch.where(candidates[idx])[0]
        partners = partners[partners > idx].to(masks.device)
        if len(partners) == 0:
            continue

        reference, others = masks[idx], masks[partners]
        n_shared = torch.logical_and(reference, others).sum(dim=(1, 2))
        n_combined = torch.logical_or(reference, others).sum(dim=(1, 2))
        iou_matrix[idx, partners] = n_shared / n_combined

    # Mirror the upper triangle to obtain the full symmetric matrix.
    iou_matrix = iou_matrix + iou_matrix.T
    iou_matrix.fill_diagonal_(diagonal_value)
    return iou_matrix
1620
1621
def _calculate_iomin_between_pred_masks(masks, boxes, eps=1e-6):
    """Compute the pairwise intersection over the smaller mask area.

    Args:
        masks: The masks, a tensor whose first dimension enumerates the N masks.
        boxes: The boxes of the masks, passed to `_overlap_matrix`.
        eps: Small constant added to the denominator.

    Returns:
        The (N, N) matrix of intersection divided by min(area_i, area_j).
        Entries of pairs whose boxes do not overlap are set to zero.
    """
    boxes_overlap = _overlap_matrix(boxes)

    # One row per mask with all spatial dimensions flattened.
    flat_masks = masks.reshape(masks.shape[0], -1).float()
    sizes = flat_masks.sum(dim=1)

    # The product of two flattened masks sums up the pixels that are set in both,
    # so a single matrix multiplication yields all pairwise intersections.
    shared = flat_masks @ flat_masks.t()
    smaller_size = torch.minimum(sizes[:, None], sizes[None, :])

    scores = shared / (smaller_size + eps)
    scores[~boxes_overlap] = 0
    return scores
1645
1646
def _batched_mask_nms(masks, boxes, scores, nms_thresh, intersection_over_min):
    """Run greedy non-maximum suppression based on the overlap of masks.

    Args:
        masks: The masks, a tensor or array-like.
        boxes: The boxes of the masks, a tensor or array-like.
        scores: The score for each mask, a tensor or array-like.
        nms_thresh: Masks whose overlap with an already kept mask exceeds this value are dropped.
        intersection_over_min: Whether to measure overlap as intersection over the smaller mask
            instead of intersection over union.

    Returns:
        Tensor with the indices of the masks that are kept, ordered by descending score.
    """
    def _to_cpu_tensor(values):
        tensor = values.detach() if isinstance(values, torch.Tensor) else torch.tensor(values)
        return tensor.cpu()

    boxes = _to_cpu_tensor(boxes)
    scores = _to_cpu_tensor(scores)
    masks = _to_cpu_tensor(masks)

    if intersection_over_min:
        overlaps = _calculate_iomin_between_pred_masks(masks, boxes)
    else:
        overlaps = _calculate_ious_between_pred_masks(masks, boxes)

    remaining = torch.argsort(scores, descending=True)
    selected = []
    while len(remaining) > 0:
        # The best remaining mask is always kept ...
        best, rest = remaining[0], remaining[1:]
        selected.append(best)
        if len(rest) == 0:
            break
        # ... and suppresses all masks that overlap too much with it.
        remaining = rest[overlaps[best, rest] <= nms_thresh]

    return torch.tensor(selected)
1677
1678
1679def _xywh_to_xyxy(boxes):
1680    boxes = boxes.clone() if isinstance(boxes, torch.Tensor) else torch.tensor(boxes)
1681    boxes = boxes.to(torch.float32)
1682    boxes[:, 2] += boxes[:, 0]
1683    boxes[:, 3] += boxes[:, 1]
1684    return boxes
1685
1686
1687def _infer_tiled_shape(predictions):
1688    shape = [0, 0]
1689    for pred in predictions:
1690        bbox, global_bbox = pred["bbox"], pred["global_bbox"]
1691        offset = (global_bbox[0] - bbox[0], global_bbox[1] - bbox[1])
1692        mask_shape = pred["segmentation"].shape
1693        shape[0] = max(shape[0], offset[1] + mask_shape[0])
1694        shape[1] = max(shape[1], offset[0] + mask_shape[1])
1695    return tuple(shape)
1696
1697
def _calculate_tiled_mask_overlap_matrix(masks, boxes, global_boxes, intersection_over_min):
    """Compute the pairwise overlap between masks that are stored per tile.

    Each mask is given in its own local coordinates. The offset between local and global
    coordinates is derived from the difference of the first two entries of the global and
    the local box. The overlap is only computed for pairs whose global boxes overlap.

    Args:
        masks: The list of masks, tensors or array-likes with two dimensions.
        boxes: The boxes in the local coordinates of the masks. Only the first two entries are used.
        global_boxes: The boxes in global coordinates, in (x, y, width, height) format.
        intersection_over_min: Whether to divide the intersection by the smaller mask area
            instead of the area of the union.

    Returns:
        The symmetric overlap matrix of shape (N, N) with ones on the diagonal.
    """
    def _as_long_cpu(values):
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().to(dtype=torch.long)
        return torch.tensor(values, dtype=torch.long)

    n_masks = len(masks)
    boxes = _as_long_cpu(boxes)
    global_boxes = _as_long_cpu(global_boxes)
    corners = _xywh_to_xyxy(global_boxes).to(torch.long)
    candidates = _overlap_matrix(corners)

    masks = [mask.detach().cpu() if isinstance(mask, torch.Tensor) else torch.tensor(mask) for mask in masks]
    areas = torch.tensor([mask.sum() for mask in masks], dtype=torch.float32)

    # Offset (x, y) that maps the local coordinates of each mask to global coordinates.
    offsets = global_boxes[:, :2] - boxes[:, :2]

    scores = torch.zeros((n_masks, n_masks))
    for i in range(n_masks):
        partners = torch.where(candidates[i])[0]
        # Only visit each pair once: restrict to partners with a larger index.
        for j in partners[partners > i]:
            # The intersection of the two global boxes.
            x0 = max(corners[i, 0], corners[j, 0])
            y0 = max(corners[i, 1], corners[j, 1])
            x1 = min(corners[i, 2], corners[j, 2])
            y1 = min(corners[i, 3], corners[j, 3])

            # Cut the box intersection out of both masks, using their local coordinates.
            window_i = masks[i][
                y0 - offsets[i, 1]:y1 - offsets[i, 1],
                x0 - offsets[i, 0]:x1 - offsets[i, 0],
            ]
            window_j = masks[j][
                y0 - offsets[j, 1]:y1 - offsets[j, 1],
                x0 - offsets[j, 0]:x1 - offsets[j, 0],
            ]

            shared = torch.logical_and(window_i, window_j).sum()
            if intersection_over_min:
                denominator = torch.minimum(areas[i], areas[j])
            else:
                denominator = areas[i] + areas[j] - shared
            scores[i, j] = shared / denominator

    # Mirror the upper triangle to obtain the full symmetric matrix.
    scores = scores + scores.T
    scores.fill_diagonal_(1)
    return scores
1748
1749
def _batched_tiled_mask_nms(masks, boxes, global_boxes, scores, nms_thresh, intersection_over_min):
    """Run greedy non-maximum suppression for masks from a tiled prediction.

    Args:
        masks: The list of masks, each given in its local coordinates.
        boxes: The boxes in the local coordinates of the masks.
        global_boxes: The boxes in global coordinates.
        scores: The score for each mask, a tensor or array-like.
        nms_thresh: Masks whose overlap with an already kept mask exceeds this value are dropped.
        intersection_over_min: Whether to measure overlap as intersection over the smaller mask
            instead of intersection over union.

    Returns:
        Tensor with the indices of the masks that are kept, ordered by descending score.
    """
    if isinstance(scores, torch.Tensor):
        scores = scores.detach()
    else:
        scores = torch.tensor(scores)
    scores = scores.cpu()

    overlaps = _calculate_tiled_mask_overlap_matrix(masks, boxes, global_boxes, intersection_over_min)
    order = torch.argsort(scores, descending=True)

    kept = []
    while len(order) > 0:
        # The mask with the highest remaining score is always kept.
        current = order[0]
        kept.append(current)
        if len(order) == 1:
            break

        # Drop all remaining masks that overlap too much with the mask that was just kept.
        candidates = order[1:]
        order = candidates[overlaps[current, candidates] <= nms_thresh]

    return torch.tensor(kept)
1771
1772
def mask_data_to_segmentation(
    masks: List[Dict[str, Any]],
    shape: Optional[Tuple[int, int]] = None,
    min_object_size: int = 0,
    max_object_size: Optional[int] = None,
    label_masks: bool = True,
    with_background: bool = False,
    merge_exclusively: bool = True,
) -> np.ndarray:
    """Convert the output of the automatic mask generation to an instance segmentation.

    The masks are written into the segmentation in order of decreasing area.

    Args:
        masks: The outputs generated by `AutomaticMaskGenerator`, other classes from `micro_sam.instance_segmentation`,
            or from `micro_sam.inference` functions. Only supported for output_mode=binary_mask.
            Each entry needs the keys 'segmentation' and 'area'. Entries with a 'global_bbox' also
            need a 'bbox'; both are interpreted as (x, y, width, height). An optional 'seg_id'
            sets the id that is written for the mask.
        shape: The shape of the output segmentation. If None, it will be derived from the mask input.
            If the masks were predicted with tiling then the shape must be given.
        min_object_size: The minimal size of an object in pixels. By default, set to '0'.
        max_object_size: The maximal size of an object in pixels.
        label_masks: Whether to apply connected components to the result before removing small objects.
            By default, set to 'True'.
        with_background: Whether to remove the largest object, which often covers the background for AMG.
        merge_exclusively: Whether to only write a mask to pixels that are not yet covered
            by a previously merged mask.

    Returns:
        The instance segmentation.
    """
    def _to_numpy(mask):
        if torch.is_tensor(mask):
            return mask.cpu().numpy()
        return mask

    def _xywh_to_slices(box):
        # Boxes are stored as (x, y, width, height), arrays are indexed as [y, x].
        return np.s_[box[1]:box[1] + box[3], box[0]:box[0] + box[2]]

    masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
    if shape is None:
        shape = next(iter(masks))["segmentation"].shape
    segmentation = np.zeros(shape, dtype="uint32")

    next_id = 1
    for entry in masks:
        size = entry["area"]
        too_small = size < min_object_size
        too_large = max_object_size is not None and size > max_object_size
        if too_small or too_large:
            continue

        object_mask = _to_numpy(entry["segmentation"])
        object_id = entry.get("seg_id", next_id)

        if "global_bbox" in entry:
            # The mask is stored in local coordinates: cut out the object's box from it
            # and write it to the corresponding global box of the segmentation.
            local_roi = _xywh_to_slices(entry["bbox"])
            global_roi = _xywh_to_slices(entry["global_bbox"])
            object_mask = object_mask[local_roi]
            if merge_exclusively:
                object_mask = np.logical_and(object_mask, segmentation[global_roi] == 0)
            segmentation[global_roi][object_mask] = object_id
        else:
            if merge_exclusively:
                object_mask = np.logical_and(object_mask, segmentation == 0)
            segmentation[object_mask] = object_id

        next_id = object_id + 1

    block_shape = (512, 512)
    if label_masks:
        segmentation = parallel_impl.label(
            segmentation, out=np.zeros_like(segmentation), block_shape=block_shape
        )

    # Collect the ids of all objects that are removed from the result.
    seg_ids, sizes = parallel_impl.unique(segmentation, return_counts=True, block_shape=block_shape)
    filter_ids = seg_ids[sizes < min_object_size]
    if with_background:
        largest_id = seg_ids[np.argmax(sizes)]
        filter_ids = np.concatenate([filter_ids, [largest_id]])

    filter_mask = parallel_impl.isin(
        segmentation, filter_ids, out=np.zeros(segmentation.shape, dtype="bool"), block_shape=block_shape
    )
    segmentation[filter_mask] = 0

    # NOTE(review): the return value is discarded, so this relies on `relabel_consecutive`
    # modifying `segmentation` in place - confirm against the elf documentation.
    parallel_impl.relabel_consecutive(segmentation, block_shape=block_shape)[0]

    return segmentation
1849
1850
def apply_nms(
    predictions: List[Dict[str, Any]],
    min_size: int,
    shape: Optional[Tuple[int, int]] = None,
    perform_box_nms: bool = False,
    nms_thresh: float = 0.9,
    max_size: Optional[int] = None,
    intersection_over_min: bool = False,
) -> np.ndarray:
    """Apply non-maximum suppression to mask predictions from a segment anything model.

    Args:
        predictions: The mask predictions from SAM. Each prediction needs the keys 'segmentation',
            'bbox', 'predicted_iou' and 'stability_score'. Predictions from tiled inference
            additionally have the key 'global_bbox'.
        min_size: The minimum mask size to keep in the output.
        shape: The shape of the output segmentation.
            For tiled predictions this is inferred from the tile-local mask shapes if it is not passed.
        perform_box_nms: Whether to perform NMS on the box coordinates or on the masks.
        nms_thresh: The threshold for filtering out objects in NMS.
        max_size: The maximum mask size to keep in the output.
        intersection_over_min: Whether to perform intersection over the minimum overlap shape
            or to perform intersection over union.

    Returns:
        The segmentation obtained from merging the masks left after NMS.

    Raises:
        IndexError: If `predictions` is empty and no `shape` is given.
    """
    # Without predictions the result is empty. It can only be created if the output shape is known.
    if len(predictions) == 0 and shape is not None:
        return np.zeros(shape, dtype="uint32")

    # Check if the input comes with a 'global_bbox' attribute. If it does, then the predictions are from
    # a tiled prediction. In this case, we have to take the coordinates w.r.t. the tiling into account.
    is_tiled = "global_bbox" in predictions[0]
    if is_tiled and shape is None:
        shape = _infer_tiled_shape(predictions)

    masks = [pred["segmentation"] for pred in predictions]
    # For tiled predictions the masks are kept as a list, otherwise they are stacked into one tensor.
    nms_masks = None if is_tiled else torch.cat([mask[None] for mask in masks], dim=0)
    data = amg_utils.MaskData(masks=masks, iou_preds=torch.tensor([pred["predicted_iou"] for pred in predictions]))
    data["boxes"] = torch.tensor(np.array([pred["bbox"] for pred in predictions]))
    data["area"] = [int(mask.sum()) for mask in data["masks"]]
    data["stability_scores"] = torch.tensor([pred["stability_score"] for pred in predictions])
    if is_tiled:
        data["global_boxes"] = torch.tensor(np.array([pred["global_bbox"] for pred in predictions]))

    # The index tensors for size filtering need an explicit integer dtype: an empty selection
    # would otherwise become a float tensor, which cannot be used for indexing.
    if min_size > 0:
        keep_by_size = torch.tensor(
            [i for i, area in enumerate(data["area"]) if area > min_size], dtype=torch.long,
        )
        data.filter(keep_by_size)
        if nms_masks is not None:
            nms_masks = nms_masks[keep_by_size]

    if max_size is not None:
        keep_by_size = torch.tensor(
            [i for i, area in enumerate(data["area"]) if area < max_size], dtype=torch.long,
        )
        data.filter(keep_by_size)
        if nms_masks is not None:
            nms_masks = nms_masks[keep_by_size]

    # All objects have been filtered out due to size filtering.
    if len(data["masks"]) == 0:
        if shape is None:
            shape = predictions[0]["segmentation"].shape
        return np.zeros(shape, dtype="uint32")

    scores = data["iou_preds"] * data["stability_scores"]
    boxes = _xywh_to_xyxy(data["global_boxes"] if is_tiled else data["boxes"])
    if perform_box_nms:
        assert not intersection_over_min  # not implemented
        keep_by_nms = batched_nms(
            boxes,
            scores,
            torch.zeros_like(data["boxes"][:, 0]),  # categories
            iou_threshold=nms_thresh,
        )
    elif is_tiled:
        keep_by_nms = _batched_tiled_mask_nms(
            masks=data["masks"],
            boxes=data["boxes"],
            global_boxes=data["global_boxes"],
            scores=scores,
            nms_thresh=nms_thresh,
            intersection_over_min=intersection_over_min,
        )
    else:
        keep_by_nms = _batched_mask_nms(
            masks=nms_masks,
            boxes=boxes,
            scores=scores,
            nms_thresh=nms_thresh,
            intersection_over_min=intersection_over_min,
        )
    data.filter(keep_by_nms)

    if is_tiled:
        mask_data = [
            {"segmentation": mask, "area": area, "bbox": box, "global_bbox": global_box}
            for mask, area, box, global_box in zip(data["masks"], data["area"], data["boxes"], data["global_boxes"])
        ]
    else:
        mask_data = [
            {"segmentation": mask, "area": area, "bbox": box}
            for mask, area, box in zip(data["masks"], data["area"], data["boxes"])
        ]

    if shape is None:
        shape = predictions[0]["segmentation"].shape
    if mask_data:
        segmentation = mask_data_to_segmentation(mask_data, shape=shape, min_object_size=min_size)
    else:  # In case no object is left.
        segmentation = np.zeros(shape, dtype="uint32")

    return segmentation
def get_cache_directory() -> Path:
    """Get micro-sam cache directory location.

    Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.
    If the variable is not set or empty, the default cache location from
    `pooch.os_cache("micro_sam")` is used.

    Returns:
        The path to the cache directory.
    """
    # An empty value is treated like an unset variable, consistent with `microsam_cachedir`.
    custom_cache_directory = os.environ.get("MICROSAM_CACHEDIR")
    if custom_cache_directory:
        return Path(custom_cache_directory)

    return Path(os.path.expanduser(pooch.os_cache("micro_sam")))

Get micro-sam cache directory location.

Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.

def microsam_cachedir() -> Union[str, os.PathLike]:
    """Return the micro-sam cache directory.

    Returns the top level cache directory for micro-sam models and sample data.

    Every time this function is called, we check for any user updates made to
    the MICROSAM_CACHEDIR os environment variable since the last time.

    Returns:
        The value of MICROSAM_CACHEDIR if it is set and not empty,
        otherwise the result of `pooch.os_cache("micro_sam")`.
    """
    return os.environ.get("MICROSAM_CACHEDIR") or pooch.os_cache("micro_sam")

Return the micro-sam cache directory.

Returns the top level cache directory for micro-sam models and sample data.

Every time this function is called, we check for any user updates made to the MICROSAM_CACHEDIR os environment variable since the last time.

def models():
    """Return the segmentation models registry.

    The registry is a pooch object that knows the hash and the download url of each model.
    It is recreated every time this function is called, so any user changes
    to the default micro-sam cache directory location are respected.
    """
    # Each model maps to its (hash, download url).
    # We use xxhash to compute the hash of the models, see
    # https://github.com/computational-cell-analytics/micro-sam/issues/283
    # (It is now a dependency, so we don't provide the sha256 fallback anymore.)
    # To generate the xxh128 hash:
    #     xxh128sum filename
    model_table = {
        # The default segment anything models:
        "vit_l": (
            "xxh128:a82beb3c660661e3dd38d999cc860e9a",
            "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
        ),
        "vit_h": (
            "xxh128:97698fac30bd929c2e6d8d8cc15933c2",
            "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        ),
        "vit_b": (
            "xxh128:6923c33df3637b6a922d7682bfc9a86b",
            "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
        ),
        # The model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM.
        "vit_t": (
            "xxh128:8eadbc88aeb9d8c7e0b4b60c3db48bd0",
            "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download",
        ),
        # The current version of our models in the modelzoo.
        # LM generalist models:
        "vit_l_lm": (
            "xxh128:017f20677997d628426dec80a8018f9d",
            "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.2/files/vit_l.pt",
        ),
        "vit_b_lm": (
            "xxh128:fe9252a29f3f4ea53c15a06de471e186",
            "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.2/files/vit_b.pt",
        ),
        "vit_t_lm": (
            "xxh128:72ec5074774761a6e5c05a08942f981e",
            "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t.pt",
        ),
        # EM models:
        "vit_l_em_organelles": (
            "xxh128:810b084b6e51acdbf760a993d8619f2d",
            "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1.2/files/vit_l.pt",
        ),
        "vit_b_em_organelles": (
            "xxh128:f3bf2ed83d691456bae2c3f9a05fb438",
            "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1.2/files/vit_b.pt",
        ),
        "vit_t_em_organelles": (
            "xxh128:253474720c497cce605e57c9b1d18fd9",
            "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt",
        ),
        # Histopathology models:
        "vit_b_histopathology": (
            "xxh128:ffd1a2cd84570458b257bd95fdd8f974",
            "https://owncloud.gwdg.de/index.php/s/sBB4H8CTmIoBZsQ/download",
        ),
        "vit_l_histopathology": (
            "xxh128:b591833c89754271023e901281dee3f2",
            "https://owncloud.gwdg.de/index.php/s/IZgnn1cpBq2PHod/download",
        ),
        "vit_h_histopathology": (
            "xxh128:bd1856dafc156a43fb3aa705f1a6e92e",
            "https://owncloud.gwdg.de/index.php/s/L7AcvVz7DoWJ2RZ/download",
        ),
        # Medical Imaging models:
        "vit_b_medical_imaging": (
            "xxh128:40169f1e3c03a4b67bff58249c176d92",
            "https://owncloud.gwdg.de/index.php/s/f5Ol4FrjPQWfjUF/download",
        ),
        # Additional decoders for instance segmentation.
        # LM generalist models:
        "vit_l_lm_decoder": (
            "xxh128:2faeafa03819dfe03e7c46a44aaac64a",
            "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.2/files/vit_l_decoder.pt",
        ),
        "vit_b_lm_decoder": (
            "xxh128:708b15ac620e235f90bb38612c4929ba",
            "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.2/files/vit_b_decoder.pt",
        ),
        "vit_t_lm_decoder": (
            "xxh128:3e914a5f397b0312cdd36813031f8823",
            "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t_decoder.pt",
        ),
        # EM models:
        "vit_l_em_organelles_decoder": (
            "xxh128:334877640bfdaaabce533e3252a17294",
            "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1.2/files/vit_l_decoder.pt",
        ),
        "vit_b_em_organelles_decoder": (
            "xxh128:bb6398956a6b0132c26b631c14f95ce2",
            "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1.2/files/vit_b_decoder.pt",
        ),
        "vit_t_em_organelles_decoder": (
            "xxh128:8f897c7bb93174a4d1638827c4dd6f44",
            "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt",
        ),
        # Histopathology models:
        "vit_b_histopathology_decoder": (
            "xxh128:6a66194dcb6e36199cbee2214ecf7213",
            "https://owncloud.gwdg.de/index.php/s/KO9AWqynI7SFOBj/download",
        ),
        "vit_l_histopathology_decoder": (
            "xxh128:46aab7765d4400e039772d5a50b55c04",
            "https://owncloud.gwdg.de/index.php/s/oIs6VSmkOp7XrKF/download",
        ),
        "vit_h_histopathology_decoder": (
            "xxh128:3ed9f87e46ad5e16935bd8d722c8dc47",
            "https://owncloud.gwdg.de/index.php/s/1qAKxy5H0jgwZvM/download",
        ),
        # Medical Imaging models:
        "vit_b_medical_imaging_decoder": (
            "xxh128:9e498b12f526f119b96c88be76e3b2ed",
            "https://owncloud.gwdg.de/index.php/s/ahd3ZhZl2e0RIwz/download",
        ),
    }

    # Split the table into the two mappings expected by pooch (the order of the models is kept).
    registry = {name: checksum for name, (checksum, _) in model_table.items()}
    urls = {name: url for name, (_, url) in model_table.items()}

    model_registry = pooch.create(
        path=os.path.join(microsam_cachedir(), "models"),
        base_url="",
        registry=registry,
        urls=urls,
    )
    return model_registry

Return the segmentation models registry.

We recreate the model registry every time this function is called, so any user changes to the default micro-sam cache directory location are respected.

def get_device( device: Union[str, torch.device, NoneType] = None) -> Union[str, torch.device]:
def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
    """Get the torch device.

    If no device is passed (or "auto" is given) the default device for your system is used.
    Otherwise it is checked that the device you have passed is supported and available.

    Args:
        device: The input device, either a string (e.g. "cpu", "cuda", "cuda:0" or "mps")
            or a `torch.device`. By default, the default device for your system is selected.

    Returns:
        The device. A device passed by the user is returned unchanged after validation.

    Raises:
        RuntimeError: If the backend of the requested device is not available,
            or if the device type is not one of 'cpu', 'cuda' or 'mps'.
    """
    if device is None or device == "auto":
        return _get_default_device()

    device_type = device if isinstance(device, str) else device.type
    # A device string may carry a device index (e.g. "cuda:0"). Only the backend name
    # in front of the colon is relevant for the availability checks below.
    backend = device_type.lower().partition(":")[0]

    if backend == "cuda":
        if not torch.cuda.is_available():
            raise RuntimeError("PyTorch CUDA backend is not available.")
    elif backend == "mps":
        if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
            raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
    elif backend == "cpu":
        pass  # cpu is always available
    else:
        raise RuntimeError(f"Unsupported device: '{device}'. Please choose from 'cpu', 'cuda', or 'mps'.")

    return device

Get the torch device.

If no device is passed the default device for your system is used. Else it will be checked if the device you have passed is supported.

Arguments:
  • device: The input device. By default, the best device available on your system is selected.
Returns:

The device.

def get_sam_model( model_type: str = 'vit_b_lm', device: Union[str, torch.device, NoneType] = None, checkpoint_path: Union[str, os.PathLike, NoneType] = None, return_sam: bool = False, return_state: bool = False, peft_kwargs: Optional[Dict] = None, flexible_load_checkpoint: bool = False, progress_bar_factory: Optional[Callable] = None, decoder_path: Union[str, os.PathLike, NoneType] = None, **model_kwargs) -> mobile_sam.predictor.SamPredictor:
def get_sam_model(
    model_type: str = _DEFAULT_MODEL,
    device: Optional[Union[str, torch.device]] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    return_sam: bool = False,
    return_state: bool = False,
    peft_kwargs: Optional[Dict] = None,
    flexible_load_checkpoint: bool = False,
    progress_bar_factory: Optional[Callable] = None,
    decoder_path: Optional[Union[str, os.PathLike]] = None,
    **model_kwargs,
) -> SamPredictor:
    r"""Get the Segment Anything Predictor.

    This function will download the required model or load it from the cached weight file.
    The location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR.
    The name of the requested model can be set via `model_type`.
    See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models
    for an overview of the available models.

    Alternatively this function can also load a model from weights stored in a local filepath.
    The corresponding file path is given via `checkpoint_path`. In this case `model_type`
    must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for
    a SAM model with vit_b encoder. If the two do not match, a message is printed and the
    architecture derived from the weights is used.

    By default the models are downloaded to a folder named 'micro_sam/models'
    inside your default cache directory, e.g.:
    * Mac: ~/Library/Caches/<AppName>
    * Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined.
    * Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache
    See the pooch.os_cache() documentation for more details:
    https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html

    Args:
        model_type: The Segment Anything model to use. Will use the 'vit_b_lm' model by default.
            To get a list of all available model names you can call `micro_sam.util.get_model_names`.
        device: The device for the model. If 'None' is provided, will use GPU if available.
        checkpoint_path: The path to a file with weights that should be used instead of using the
            weights corresponding to `model_type`. If given, `model_type` must match the architecture
            corresponding to the weight file. e.g. if you use weights for SAM with `vit_b` encoder
            then `model_type` must be given as 'vit_b'.
        return_sam: Return the sam model object as well as the predictor. By default, set to 'False'.
        return_state: Return the unpickled checkpoint state. By default, set to 'False'.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
            If passed 'None', it does not initialize any parameter efficient finetuning.
            A 'quantize' entry is ignored. The passed dictionary is not modified.
        flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
            By default, set to 'False'.
        progress_bar_factory: A function to create a progress bar for the model download.
        decoder_path: Optional path to weights for a segmentation decoder. If given and
            `return_state=True`, the decoder state is added to the returned state as
            'decoder_state'. This can be used to provide decoder-only weights that are
            separate from the encoder checkpoint.
        model_kwargs: Additional parameters necessary to initialize the Segment Anything model.

    Returns:
        The Segment Anything predictor.
        If `return_sam` and / or `return_state` are set, a tuple is returned instead:
        `(predictor, sam)`, `(predictor, state)` or `(predictor, sam, state)`.

    Raises:
        ValueError: If `checkpoint_path` or `decoder_path` does not exist, if `model_type` does not
            start with a valid architecture name, or if `model_kwargs` / `peft_kwargs` are
            passed for a 'vit_t' model.
        RuntimeError: If a 'vit_t' model is requested but 'mobile_sam' is not installed.
    """
    device = get_device(device)

    # Remember whether the user has passed a local checkpoint. In that case the architecture
    # is derived from the weights further below.
    _provided_checkpoint_path = checkpoint_path is not None

    # checkpoint_path has not been passed, we download a known model and derive the correct
    # URL from the model_type. If the model_type is invalid pooch will raise an error.
    if checkpoint_path is None:
        checkpoint_path, model_hash, downloaded_decoder_path = _download_sam_model(model_type, progress_bar_factory)
        # An explicitly passed decoder_path takes precedence over the downloaded decoder.
        if decoder_path is None:
            decoder_path = downloaded_decoder_path

    # checkpoint_path has been passed, we use it instead of downloading a model.
    # The local weight file is used as it is, without copying it anywhere.
    else:
        # Check if the file exists and raise an error otherwise.
        # We can't check any hashes here, and we don't check if the file is actually a valid weight file.
        # (If it isn't the model creation will fail below.)
        if not os.path.exists(checkpoint_path):
            raise ValueError(f"Checkpoint at '{checkpoint_path}' could not be found.")
        model_hash = _compute_hash(checkpoint_path)

    if decoder_path is not None and not os.path.exists(decoder_path):
        raise ValueError(f"Decoder checkpoint at '{decoder_path}' could not be found.")

    # Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped
    # before calling sam_model_registry.
    abbreviated_model_type = model_type[:5]
    if abbreviated_model_type not in _MODEL_TYPES:
        raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}")
    if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT:
        raise RuntimeError(
            "'mobile_sam' is required for the vit-tiny. "
            "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'"
        )

    state, model_state = _load_checkpoint(checkpoint_path)

    if _provided_checkpoint_path:
        # To get the model weights, we prioritize having the correct 'checkpoint_path' over 'model_type'.
        # This avoids parameter mismatch issues for incompatible combinations of model type and weights.
        from micro_sam.models.build_sam import _validate_model_type
        _provided_model_type = _validate_model_type(model_state)

        # Verify whether the 'abbreviated_model_type' matches the '_provided_model_type'.
        # Otherwise replace 'abbreviated_model_type' with the latter.
        if abbreviated_model_type != _provided_model_type:
            # Printing the message below to avoid any filtering of warnings on user's end.
            print(
                f"CRITICAL WARNING: The chosen 'model_type' is '{abbreviated_model_type}', "
                f"however the model checkpoint provided correspond to '{_provided_model_type}', which does not match. "
                f"We internally switch the model type to the expected value, i.e. '{_provided_model_type}'. "
                "However, please avoid mismatching combination of 'model_type' and 'checkpoint_path' in future."
            )

        # Replace the extracted 'abbreviated_model_type' with the one derived from the model weights.
        abbreviated_model_type = _provided_model_type

    # Whether to update parameters necessary to initialize the model.
    if model_kwargs:  # Checks whether model_kwargs have been provided or not
        if abbreviated_model_type == "vit_t":
            raise ValueError("'micro-sam' does not support changing the model parameters for 'mobile-sam'.")
        sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs)

    else:
        sam = sam_model_registry[abbreviated_model_type]()

    # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
    # Overwrites the SAM model by freezing the backbone and allow PEFT.
    if peft_kwargs and isinstance(peft_kwargs, dict):
        # NOTE: We drop the 'quantize' parameter, if found, as we do not quantize in inference.
        # This is done on a copy, so that the dictionary passed by the caller is left untouched.
        peft_kwargs = {key: value for key, value in peft_kwargs.items() if key != "quantize"}

        if abbreviated_model_type == "vit_t":
            raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")

        sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam

    # In case the model checkpoints have some issues when it is initialized with different parameters than default.
    if flexible_load_checkpoint:
        sam = _handle_checkpoint_loading(sam, model_state)
    else:
        sam.load_state_dict(model_state)
    sam.to(device=device)

    predictor = SamPredictor(sam)
    predictor.model_type = abbreviated_model_type
    predictor._hash = model_hash
    predictor.model_name = model_type
    predictor.checkpoint_path = checkpoint_path

    # Add the decoder to the state if we have one and if the state is returned.
    if decoder_path is not None and return_state:
        # NOTE: 'weights_only=False' unpickles arbitrary objects. Only load decoder files from trusted sources.
        state["decoder_state"] = torch.load(decoder_path, map_location=device, weights_only=False)

    if return_sam and return_state:
        return predictor, sam, state
    if return_sam:
        return predictor, sam
    if return_state:
        return predictor, state
    return predictor

Get the Segment Anything Predictor.

This function will download the required model or load it from the cached weight file. The location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR. The name of the requested model can be set via model_type. See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models for an overview of the available models.

Alternatively this function can also load a model from weights stored in a local filepath. The corresponding file path is given via checkpoint_path. In this case model_type must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for a SAM model with vit_b encoder.

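By default the models are downloaded to a folder named 'micro_sam/models' inside your default cache directory, e.g.:

  • Mac: ~/Library/Caches/<AppName>
  • Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined.
  • Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache

See the pooch.os_cache() documentation for more details: https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html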

Arguments:
  • model_type: The Segment Anything model to use. Will use the 'vit_b_lm' model by default. To get a list of all available model names you can call micro_sam.util.get_model_names.
  • device: The device for the model. If 'None' is provided, will use GPU if available.
  • checkpoint_path: The path to a file with weights that should be used instead of using the weights corresponding to model_type. If given, model_type must match the architecture corresponding to the weight file. e.g. if you use weights for SAM with vit_b encoder then model_type must be given as 'vit_b'.
  • return_sam: Return the sam model object as well as the predictor. By default, set to 'False'.
  • return_state: Return the unpickled checkpoint state. By default, set to 'False'.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class. If passed 'None', it does not initialize any parameter efficient finetuning.
  • flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. By default, set to 'False'.
  • progress_bar_factory: A function to create a progress bar for the model download.
  • decoder_path: Optional path to weights for a segmentation decoder. If given and return_state=True, the decoder state is added to the returned state as 'decoder_state'. This can be used to provide decoder-only weights that are separate from the encoder checkpoint.
  • model_kwargs: Additional parameters necessary to initialize the Segment Anything model.
Returns:

The Segment Anything predictor.

def export_custom_sam_model( checkpoint_path: Union[str, os.PathLike], model_type: str, save_path: Union[str, os.PathLike], with_segmentation_decoder: bool = False, prefix: str = 'sam.') -> None:
def export_custom_sam_model(
    checkpoint_path: Union[str, os.PathLike],
    model_type: str,
    save_path: Union[str, os.PathLike],
    with_segmentation_decoder: bool = False,
    prefix: str = "sam.",
) -> None:
    """Export a finetuned Segment Anything Model to the standard model format.

    The exported model can be used by the interactive annotation tools in `micro_sam.annotator`.

    Args:
        checkpoint_path: The path to the checkpoint of the finetuned model.
        model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
            This argument is currently not used by the function body.
        save_path: Where to save the exported model.
        with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well.
            If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'.
        prefix: The prefix to remove from the model parameter keys.

    Raises:
        RuntimeError: If `with_segmentation_decoder` is set but the checkpoint does not contain a 'decoder_state'.
    """
    state, model_state = _load_checkpoint(checkpoint_path=checkpoint_path)

    # Remove the prefix from all parameter names that start with it, other names are kept unchanged.
    prefix_length = len(prefix)
    exported_model_state = OrderedDict()
    for name, parameter in model_state.items():
        exported_name = name[prefix_length:] if name.startswith(prefix) else name
        exported_model_state[exported_name] = parameter

    # Without the decoder only the plain model state is stored.
    if not with_segmentation_decoder:
        torch.save(exported_model_state, save_path)
        return

    # Otherwise the model state and the decoder state are stored together in a dictionary.
    if "decoder_state" not in state:
        raise RuntimeError(f"'decoder_state' is not found in the model at '{checkpoint_path}'.")
    torch.save({"model_state": exported_model_state, "decoder_state": state["decoder_state"]}, save_path)

Export a finetuned Segment Anything Model to the standard model format.

The exported model can be used by the interactive annotation tools in micro_sam.annotator.

Arguments:
  • checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
  • model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
  • save_path: Where to save the exported model.
  • with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well. If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'.
  • prefix: The prefix to remove from the model parameter keys.
def export_custom_qlora_model( checkpoint_path: Union[str, os.PathLike, NoneType], finetuned_path: Union[str, os.PathLike], model_type: str, save_path: Union[str, os.PathLike]) -> None:
def export_custom_qlora_model(
    checkpoint_path: Optional[Union[str, os.PathLike]],
    finetuned_path: Union[str, os.PathLike],
    model_type: str,
    save_path: Union[str, os.PathLike],
) -> None:
    """Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format.

    The exported model can be used with the LoRA backbone by passing the relevant `peft_kwargs` to `get_sam_model`.

    Args:
        checkpoint_path: The path to the base foundation model from which the new model has been finetuned.
            It is forwarded to `get_sam_model` together with `model_type`.
        finetuned_path: The path to the new finetuned model, using QLoRA.
        model_type: The Segment Anything Model type corresponding to the checkpoint.
        save_path: Where to save the exported model.
    """
    def _block_index(key: str) -> Optional[int]:
        """Return the index that follows 'blocks.' in a parameter name, or None if there is none."""
        if "blocks." not in key:
            return None
        return int(key.split("blocks.")[1].split(".")[0])

    # Step 1: Get the base SAM model: used to start finetuning from.
    _, sam = get_sam_model(
        model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True,
    )

    # Step 2: Load the QLoRA-style finetuned model.
    ft_state, ft_model_state = _load_checkpoint(finetuned_path)

    # Step 3: Identify LoRA layers from QLoRA model.
    # - differentiate between LoRA applied to the attention matrices and LoRA applied to the MLP layers.
    # - then copy the LoRA layers from the QLoRA model to the new state dict
    updated_model_state = {}

    modified_attn_layers = set()
    modified_mlp_layers = set()

    for k, v in ft_model_state.items():
        # The block index is derived for every key individually, so that it can never be
        # carried over from a previous key or be undefined.
        layer_id = _block_index(k)
        if "qkv.w_a_linear" in k or "qkv.w_b_linear" in k:
            if layer_id is not None:
                modified_attn_layers.add(layer_id)
            updated_model_state[k] = v
        if "mlp.w_a_linear" in k or "mlp.w_b_linear" in k:
            if layer_id is not None:
                modified_mlp_layers.add(layer_id)
            updated_model_state[k] = v

    # Step 4: Next, we get all the remaining parameters from the base SAM model.
    for k, v in sam.state_dict().items():
        layer_id = _block_index(k)
        if "attn.qkv." in k:
            if layer_id in modified_attn_layers:  # We have LoRA in QKV layers, so we need to modify the key
                k = k.replace("qkv", "qkv.qkv_proj")
        elif "mlp" in k and "image_encoder" in k:
            if layer_id in modified_mlp_layers:  # We have LoRA in MLP layers, so we need to modify the key
                k = k.replace("mlp.", "mlp.mlp_layer.")
        updated_model_state[k] = v

    # Step 5: Finally, we replace the old model state with the new one (to retain other relevant stuff)
    ft_state['model_state'] = updated_model_state

    # Step 6: Store the new "state" to "save_path"
    torch.save(ft_state, save_path)

Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format.

The exported model can be used with the LoRA backbone by passing the relevant peft_kwargs to get_sam_model.

Arguments:
  • checkpoint_path: The path to the base foundation model from which the new model has been finetuned.
  • finetuned_path: The path to the new finetuned model, using QLoRA.
  • model_type: The Segment Anything Model type corresponding to the checkpoint.
  • save_path: Where to save the exported model.
def get_model_names() -> Iterable:
def get_model_names() -> Iterable:
    """Get the names of all models in the model registry.

    Returns:
        The keys of the registry that is returned by `models()`.
    """
    return models().registry.keys()
def precompute_image_embeddings( predictor: mobile_sam.predictor.SamPredictor, input_: numpy.ndarray, save_path: Union[str, os.PathLike, NoneType] = None, lazy_loading: bool = False, ndim: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = True, batch_size: int = 1, mask: Optional[ArrayLike] = None, pbar_init: Optional[<built-in function callable>] = None, pbar_update: Optional[<built-in function callable>] = None) -> Dict[str, Any]:
def precompute_image_embeddings(
    predictor: SamPredictor,
    input_: np.ndarray,
    save_path: Optional[Union[str, os.PathLike]] = None,
    lazy_loading: bool = False,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    batch_size: int = 1,
    mask: Optional[np.typing.ArrayLike] = None,
    pbar_init: Optional[callable] = None,
    pbar_update: Optional[callable] = None,
) -> ImageEmbeddings:
    """Compute the image embeddings (output of the encoder) for the input.

    If 'save_path' is given the embeddings will be loaded/saved in a zarr container.

    Args:
        predictor: The Segment Anything predictor.
        input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
        save_path: Path to save the embeddings in a zarr container.
            By default, set to 'None', i.e. the computed embeddings will not be stored locally.
        lazy_loading: Whether to load all embeddings into memory or return an
            object to load them on demand when required. This only has an effect if 'save_path' is given
            and if the input is 3 dimensional. By default, set to 'False'.
        ndim: The dimensionality of the data. If not given will be deduced from the input data.
            By default, set to 'None', i.e. will be computed from the provided `input_`.
        tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
        verbose: Whether to be verbose in the computation. By default, set to 'True'.
        batch_size: The batch size for precomputing image embeddings over tiles (or planes). By default, set to '1'.
        mask: An optional mask to define areas that are ignored in the computation.
            The mask will be used within tiled embedding computation and tiles that don't contain any foreground
            in the mask will be excluded from the computation. It does not have any effect for non-tiled embeddings.
        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
            Can be used together with pbar_update to handle napari progress bar in other thread.
            To enable using this function within a threadworker.
        pbar_update: Callback to update an external progress bar.

    Returns:
        The image embeddings.

    Raises:
        ValueError: If the dimensionality of the data is neither 2 nor 3.
    """
    ndim = input_.ndim if ndim is None else ndim

    # Validate the dimensionality first, so that no zarr container is opened or created for invalid
    # inputs. The message reports the dimensionality that is actually used, which may differ
    # from 'input_.ndim' if 'ndim' was passed explicitly.
    if ndim not in (2, 3):
        raise ValueError(f"Invalid dimesionality {ndim}, expect 2 or 3 dim data.")

    # Handle the embedding save_path.
    # We don't have a save path, open in memory zarr file to hold tiled embeddings.
    if save_path is None:
        f = zarr.group()
    else:
        # The existence check has to happen before the container is opened in append mode.
        embeddings_exist = os.path.exists(save_path)
        f = zarr.open(save_path, mode="a")
        # The save path already existed, so embeddings will be loaded from it. Check that the
        # saved embeddings in there match the parameters of the function call.
        if embeddings_exist:
            _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo)

    _, pbar_init, pbar_update, pbar_close = handle_pbar(verbose, pbar_init, pbar_update)

    if tile_shape is None:
        if ndim == 2:
            embeddings = _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update)
        else:
            embeddings = _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update, batch_size)
    else:
        compute_tiled = _compute_tiled_2d if ndim == 2 else _compute_tiled_3d
        embeddings = compute_tiled(
            input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask=mask
        )

    pbar_close()
    return embeddings

Compute the image embeddings (output of the encoder) for the input.

If 'save_path' is given the embeddings will be loaded/saved in a zarr container.

Arguments:
  • predictor: The Segment Anything predictor.
  • input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
  • save_path: Path to save the embeddings in a zarr container. By default, set to 'None', i.e. the computed embeddings will not be stored locally.
  • lazy_loading: Whether to load all embeddings into memory or return an object to load them on demand when required. This only has an effect if 'save_path' is given and if the input is 3 dimensional. By default, set to 'False'.
  • ndim: The dimensionality of the data. If not given will be deduced from the input data. By default, set to 'None', i.e. will be computed from the provided input_.
  • tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
  • halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
  • verbose: Whether to be verbose in the computation. By default, set to 'True'.
  • batch_size: The batch size for precomputing image embeddings over tiles (or planes). By default, set to '1'.
  • mask: An optional mask to define areas that are ignored in the computation. The mask will be used within tiled embedding computation and tiles that don't contain any foreground in the mask will be excluded from the computation. It does not have any effect for non-tiled embeddings.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle napari progress bar in other thread. To enable using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
Returns:

The image embeddings.

def set_precomputed( predictor: mobile_sam.predictor.SamPredictor, image_embeddings: Dict[str, Any], i: Optional[int] = None, tile_id: Optional[int] = None) -> mobile_sam.predictor.SamPredictor:
def set_precomputed(
    predictor: SamPredictor, image_embeddings: ImageEmbeddings, i: Optional[int] = None, tile_id: Optional[int] = None,
) -> SamPredictor:
    """Set the precomputed image embeddings for a predictor.

    Args:
        predictor: The Segment Anything predictor.
        image_embeddings: The precomputed image embeddings computed by `precompute_image_embeddings`.
        i: Index for the image data. Required if the features have five dimensions
            and must not be given if they have four dimensions.
        tile_id: Index for the tile. This is required if the embeddings are tiled.

    Returns:
        The predictor with set features.

    Raises:
        ValueError: If an index is missing for five dimensional features
            or if an index is passed for four dimensional features.
    """
    # For tiled embeddings we select the entry of the requested tile, read the sizes
    # from its attributes and then continue as in the non-tiled case.
    if tile_id is not None:
        per_tile = image_embeddings["features"][str(tile_id)]
        return set_precomputed(
            predictor,
            {
                "features": per_tile,
                "input_size": per_tile.attrs["input_size"],
                "original_size": per_tile.attrs["original_size"],
            },
            i=i,
        )

    target_device = predictor.device
    features = image_embeddings["features"]
    n_dims = features.ndim
    assert n_dims in (4, 5), f"{n_dims}"

    # Five dimensional features need an index, four dimensional features must not get one.
    have_index = i is not None
    if n_dims == 5 and not have_index:
        raise ValueError("The data is 3D so an index i is needed.")
    if n_dims == 4 and have_index:
        raise ValueError("The data is 2D so an index is not needed.")

    # Tensors are used directly, everything else is read into memory and converted to a tensor.
    if torch.is_tensor(features):
        selected = features[i] if have_index else features
    else:
        selected = torch.from_numpy(features[i] if have_index else features[:])
    predictor.features = selected.to(target_device)

    predictor.original_size = image_embeddings["original_size"]
    predictor.input_size = image_embeddings["input_size"]
    predictor.is_image_set = True

    return predictor

Set the precomputed image embeddings for a predictor.

Arguments:
  • predictor: The Segment Anything predictor.
  • image_embeddings: The precomputed image embeddings computed by precompute_image_embeddings.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • tile_id: Index for the tile. This is required if the embeddings are tiled.
Returns:

The predictor with set features.

def compute_iou(mask1: numpy.ndarray, mask2: numpy.ndarray) -> float:
def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
    """Compute the intersection over union of two masks.

    Only the elements that are equal to 1 are treated as foreground in both masks.

    Args:
        mask1: The first mask.
        mask2: The second mask.

    Returns:
        The intersection over union of the two masks.
    """
    foreground1 = mask1 == 1
    foreground2 = mask2 == 1

    n_intersection = np.logical_and(foreground1, foreground2).sum()
    n_union = np.logical_or(foreground1, foreground2).sum()

    # The small constant avoids a division by zero if both masks are empty.
    return float(n_intersection) / (float(n_union) + 1e-7)

Compute the intersection over union of two masks.

Arguments:
  • mask1: The first mask.
  • mask2: The second mask.
Returns:

The intersection over union of the two masks.

def get_centers_and_bounding_boxes( segmentation: numpy.ndarray, mode: str = 'v') -> Tuple[Dict[int, numpy.ndarray], Dict[int, tuple]]:
def get_centers_and_bounding_boxes(
    segmentation: np.ndarray, mode: str = "v"
) -> Tuple[Dict[int, np.ndarray], Dict[int, tuple]]:
    """Compute a center point and the bounding box for each object in a segmentation.

    Args:
        segmentation: The segmentation.
        mode: Selects how the centers are computed.
            With 'v', the object pixel with the maximal distance to the object boundary is used.
            This point lies inside the object, also for concave shapes.
            With 'p', the centroids computed by skimage's regionprops are used.

    Returns:
        A dictionary that maps object ids to the corresponding center.
        A dictionary that maps object ids to the corresponding bounding box.
    """
    assert mode in ["p", "v"], "Choose either 'p' for regionprops centroids or 'v' for distance-based centers"

    props = regionprops(segmentation)
    bbox_coordinates = {prop.label: prop.bbox for prop in props}

    if mode == "p":
        center_coordinates = {prop.label: prop.centroid for prop in props}
    elif mode == "v":
        # A single distance transform over the whole segmentation is computed,
        # instead of one distance transform per object. This replaces vigra.filters.eccentricityCenters.
        ndim = segmentation.ndim

        # The padding of one pixel ensures that objects touching the image border
        # also get a boundary there, matching a per-object padded distance transform.
        boundaries = find_boundaries(np.pad(segmentation, 1), mode="inner")
        distances = distance_transform(boundaries == 0)

        center_coordinates = {}
        for prop in props:
            starts, stops = prop.bbox[:ndim], prop.bbox[ndim:]
            # The bounding box is shifted by one to account for the padding.
            window = tuple(slice(lo + 1, hi + 1) for lo, hi in zip(starts, stops))
            # Pixels in the bounding box that belong to other objects or background are
            # set to -1, so that the argmax can only select one of the object's own pixels.
            scores = np.where(prop.image, distances[window], -1.0)
            local_center = np.unravel_index(int(np.argmax(scores)), scores.shape)
            # Map the center from bounding box coordinates back to image coordinates.
            center_coordinates[prop.label] = tuple(int(pos + off) for pos, off in zip(local_center, starts))

    assert len(bbox_coordinates) == len(center_coordinates), f"{len(bbox_coordinates)}, {len(center_coordinates)}"
    return center_coordinates, bbox_coordinates

Returns the center coordinates of the foreground instances in the ground-truth.

Arguments:
  • segmentation: The segmentation.
  • mode: Determines the functionality used for computing the centers. If 'v', the point of maximal distance to the object boundary is used as center. This center is guaranteed to lie inside the object, also for concave shapes. If 'p' the object's centroids computed by skimage are used.
Returns:

A dictionary that maps object ids to the corresponding centroid. A dictionary that maps object_ids to the corresponding bounding box.

def load_image_data( path: str, key: Optional[str] = None, lazy_loading: bool = False) -> numpy.ndarray:
def load_image_data(path: str, key: Optional[str] = None, lazy_loading: bool = False) -> np.ndarray:
    """Load image data from a file.

    Without a key the file is read with imageio. With a key the file is opened
    with `elf.io.open_file` and the dataset stored under that key is returned.

    Args:
        path: The filepath to the image data.
        key: The internal filepath for complex data formats like hdf5.
        lazy_loading: Whether to lazily load data. Only supported for n5 and zarr data.

    Returns:
        The image data.
    """
    if key is None:
        return imageio.imread(path)

    with open_file(path, mode="r") as handle:
        dataset = handle[key]
        # Without lazy loading the full dataset is read into memory while the file is still open.
        return dataset if lazy_loading else dataset[:]

Helper function to load image data from file.

Arguments:
  • path: The filepath to the image data.
  • key: The internal filepath for complex data formats like hdf5.
  • lazy_loading: Whether to lazily load data. Only supported for n5 and zarr data.
Returns:

The image data.

def segmentation_to_one_hot( segmentation: numpy.ndarray, segmentation_ids: Optional[numpy.ndarray] = None) -> torch.Tensor:
def segmentation_to_one_hot(segmentation: np.ndarray, segmentation_ids: Optional[np.ndarray] = None) -> torch.Tensor:
    """Convert the segmentation to one-hot encoded masks.

    Args:
        segmentation: The segmentation. Any integer dtype is supported.
        segmentation_ids: Optional subset of ids that will be used to subsample the masks.
            By default, computes the number of ids from the provided `segmentation` masks.

    Returns:
        The one-hot encoded masks, with one channel per object id (the background channel is dropped)
        and an additional singleton dimension after the object dimension.

    Raises:
        RuntimeError: If `segmentation_ids` is empty or contains the background id 0.
    """
    masks = segmentation.copy()
    if segmentation_ids is None:
        n_ids = int(segmentation.max())

    else:
        msg = "No foreground objects were found."
        if len(segmentation_ids) == 0:  # The list should not be completely empty.
            raise RuntimeError(msg)

        if 0 in segmentation_ids:  # The list should not have 'zero' as a value.
            raise RuntimeError(msg)

        # the segmentation ids have to be sorted
        segmentation_ids = np.sort(segmentation_ids)

        # set the non selected objects to zero and relabel sequentially
        masks[~np.isin(masks, segmentation_ids)] = 0
        masks = relabel_sequential(masks)[0]
        n_ids = len(segmentation_ids)

    # `scatter_` only accepts int64 indices, so other integer dtypes (e.g. int32 or uint16)
    # are cast here. For int64 input no additional copy is made.
    masks = torch.from_numpy(np.asarray(masks).astype(np.int64, copy=False))

    one_hot_shape = (n_ids + 1,) + masks.shape
    masks = masks.unsqueeze(0)  # add dimension to scatter
    # Scatter a one into the channel given by the object id and drop the background channel.
    masks = torch.zeros(one_hot_shape).scatter_(0, masks, 1)[1:]

    # add the extra singleton dimension to get shape NUM_OBJECTS x 1 x H x W
    masks = masks.unsqueeze(1)
    return masks

Convert the segmentation to one-hot encoded masks.

Arguments:
  • segmentation: The segmentation.
  • segmentation_ids: Optional subset of ids that will be used to subsample the masks. By default, computes the number of ids from the provided segmentation masks.
Returns:

The one-hot encoded masks.

def get_block_shape(shape: Tuple[int]) -> Tuple[int]:
def get_block_shape(shape: Tuple[int]) -> Tuple[int]:
    """Derive a block shape that is suitable for chunking data of the given shape.

    This is mainly used to determine chunk sizes for zarr arrays
    or block shapes for parallelization.

    Args:
        shape: The image or volume shape.

    Returns:
        The block shape.

    Raises:
        ValueError: If the shape is neither 2 nor 3 dimensional.
    """
    # The maximal block extent per axis, selected by the dimensionality.
    max_block_shapes = {2: (1024, 1024), 3: (32, 256, 256)}

    ndim = len(shape)
    if ndim not in max_block_shapes:
        raise ValueError(f"Only 2 or 3 dimensional shapes are supported, got {ndim}D.")

    # Blocks never exceed the extent of the data itself.
    return tuple(min(bs, sh) for bs, sh in zip(max_block_shapes[ndim], shape))

Get a suitable block shape for chunking a given shape.

The primary use for this is determining chunk sizes for zarr arrays or block shapes for parallelization.

Arguments:
  • shape: The image or volume shape.
Returns:

The block shape.

def micro_sam_info() -> None:
def micro_sam_info() -> None:
    """Display μSAM information using a rich console.

    Prints the installed version, the documentation and publication links, the cache directory
    (which is created if it does not exist), the list of supported models, system information
    and the default compute device.

    The function parses the command line and supports the option `--download`:
    `--download models` downloads all listed models and
    `--download models vit_b_lm vit_b_em_organelles` downloads only the given models.
    Models can be given either by their plain name (e.g. 'vit_b_lm')
    or by the name shown in the list of supported models (e.g. 'vit_b_lm (v4)').
    """
    import psutil
    import platform
    import argparse
    from rich import progress
    from rich.panel import Panel
    from rich.table import Table
    from rich.console import Console

    import torch
    import micro_sam

    parser = argparse.ArgumentParser(description="μSAM Information Booth")
    parser.add_argument(
        "--download", nargs="+", metavar=("WHAT", "KIND"),
        help="Downloads the pretrained SAM models."
        "'--download models' -> downloads all pretrained models; "
        "'--download models vit_b_lm vit_b_em_organelles' -> downloads the listed models; "
        "'--download model/models vit_b_lm' -> downloads a single specified model."
    )
    args = parser.parse_args()

    # Open up a new console.
    console = Console()

    # The header for information CLI.
    console.print("[bold #0072B2]μSAM Information Booth[/bold #0072B2]", justify="center")
    console.print("-" * console.width)

    # μSAM version panel.
    console.print(
        Panel(f"[bold #F0E442]Version:[/bold #F0E442] {micro_sam.__version__}", title="μSAM Version", expand=True)
    )

    # The documentation link panel.
    console.print(
        Panel(
            "[bold #CC79A7]Tools documented at:[/bold #CC79A7]\n"
            "https://computational-cell-analytics.github.io/micro-sam", title="Documentation"
        )
    )

    # The publication panel.
    console.print(
        Panel(
            "[bold #E69F00]Published in Nature Methods:[/bold #E69F00]\n"
            "https://www.nature.com/articles/s41592-024-02580-4", title="Publication"
        )
    )

    # Create the cache directory when users run `micro_sam.info`.
    cache_dir = get_cache_directory()
    os.makedirs(cache_dir, exist_ok=True)

    # The cache directory panel.
    console.print(
        Panel(f"[bold #009E73]Cache Directory:[/bold #009E73]\n{cache_dir}", title="Cache Directory")
    )

    # We have a simple versioning logic here (which is what I'll follow here for mapping model versions).
    # `available_models` holds the names for display, which may carry a version suffix like ' (v2)'.
    # `model_lookup` maps every name a user may pass to '--download' (the plain registry name
    # or the displayed name) to the registry name that is needed for the actual download.
    available_models = []
    model_lookup = {}
    for model_name, model_path in models().urls.items():  # We filter out the decoder models.
        if model_name.endswith("decoder"):
            continue

        n_listed = len(available_models)

        if "https://dl.fbaipublicfiles.com/segment_anything/" in model_path:  # Valid v1 SAM models.
            available_models.append(model_name)

        if "https://owncloud.gwdg.de/" in model_path:  # Our own hosted models (in their v1 mode quite often)
            if model_name == "vit_t":  # MobileSAM model.
                available_models.append(model_name)
            else:
                available_models.append(f"{model_name} (v1)")

        # Now for our models, the BioImageIO ModelZoo upload structure is such that:
        # '/1/files' corresponds to v2 models.
        # '/1.1/files' corresponds to v3 models.
        # '/1.2/files' corresponds to v4 models.
        if "/1/files" in model_path:
            available_models.append(f"{model_name} (v2)")
        if "/1.1/files" in model_path:
            available_models.append(f"{model_name} (v3)")
        if "/1.2/files" in model_path:
            available_models.append(f"{model_name} (v4)")

        # Register the names that were listed for this model, so that the download
        # accepts them and always resolves them to the plain registry name.
        for display_name in available_models[n_listed:]:
            model_lookup[display_name] = model_name
            model_lookup[model_name] = model_name

    model_list = "\n".join(available_models)

    # The available models panel.
    console.print(
        Panel(f"[bold #D55E00]Available Models:[/bold #D55E00]\n{model_list}", title="List of Supported Models")
    )

    # The system information table.
    total_memory = psutil.virtual_memory().total / (1024 ** 3)
    table = Table(title="System Information", show_header=True, header_style="bold #0072B2", expand=True)
    table.add_column("Property")
    table.add_column("Value", style="bold #56B4E9")
    table.add_row("System", platform.system())
    table.add_row("Node Name", platform.node())
    table.add_row("Release", platform.release())
    table.add_row("Version", platform.version())
    table.add_row("Machine", platform.machine())
    table.add_row("Processor", platform.processor())
    table.add_row("Platform", platform.platform())
    table.add_row("Total RAM (GB)", f"{total_memory:.2f}")
    console.print(table)

    # The device information and check for available GPU acceleration.
    default_device = _get_default_device()

    if default_device == "cuda":
        device_index = torch.cuda.current_device()
        device_name = torch.cuda.get_device_name(device_index)
        console.print(Panel(f"[bold #000000]CUDA Device:[/bold #000000] {device_name}", title="GPU Information"))
    elif default_device == "mps":
        console.print(Panel("[bold #000000]MPS Device is available[/bold #000000]", title="GPU Information"))
    else:
        console.print(
            Panel(
                "[bold #000000]No GPU acceleration device detected. Running on CPU.[/bold #000000]",
                title="Device Information"
            )
        )

    # The section allowing to download models.
    # NOTE: In future, can be extended to download sample data.
    if args.download:
        download_provided_args = [t.lower() for t in args.download]
        mode, *model_types = download_provided_args

        if mode not in {"models", "model"}:
            console.print(f"[red]Unknown option for --download: {mode}[/]")
            return

        if not model_types:  # If user did not specify, we will download all models.
            # Use the registry names (without duplicates), not the display names with version suffix.
            download_list = list(dict.fromkeys(model_lookup.values()))
        else:
            incorrect_models = [m for m in model_types if m not in model_lookup]
            if incorrect_models:
                console.print(Panel("[red]Unknown model(s):[/] " + ", ".join(incorrect_models), title="Download Error"))
                return
            download_list = [model_lookup[m] for m in model_types]

        with progress.Progress(
            progress.SpinnerColumn(),
            progress.TextColumn("[progress.description]{task.description}"),
            progress.BarColumn(bar_width=None),
            "[progress.percentage]{task.percentage:>3.0f}%",
            progress.TimeRemainingColumn(),
            console=console,
        ) as prog:
            task = prog.add_task("[green]Downloading μSAM models…", total=len(download_list))
            for model_type in download_list:
                prog.update(task, description=f"Downloading [cyan]{model_type}[/]…")
                _download_sam_model(model_type=model_type)
                prog.advance(task)

        console.print(Panel("[bold green] Downloads complete![/]", title="Finished"))

Display μSAM information using a rich console.

def mask_data_to_segmentation( masks: List[Dict[str, Any]], shape: Optional[Tuple[int, int]] = None, min_object_size: int = 0, max_object_size: Optional[int] = None, label_masks: bool = True, with_background: bool = False, merge_exclusively: bool = True) -> numpy.ndarray:
def mask_data_to_segmentation(
    masks: List[Dict[str, Any]],
    shape: Optional[Tuple[int, int]] = None,
    min_object_size: int = 0,
    max_object_size: Optional[int] = None,
    label_masks: bool = True,
    with_background: bool = False,
    merge_exclusively: bool = True,
) -> np.ndarray:
    """Merge the masks from automatic mask generation into a single instance segmentation.

    The masks are written into the segmentation in order of decreasing area.

    Args:
        masks: The outputs generated by `AutomaticMaskGenerator`, other classes from `micro_sam.instance_segmentation`,
            or from `micro_sam.inference` functions. Only supported for output_mode=binary_mask.
        shape: The shape of the output segmentation. If None, it will be derived from the mask input.
            If the masks were predicted with tiling then the shape must be given.
        min_object_size: The minimal size of an object in pixels. By default, set to '0'.
        max_object_size: The maximal size of an object in pixels.
        label_masks: Whether to apply connected components to the result before removing small objects.
            By default, set to 'True'.
        with_background: Whether to remove the largest object, which often covers the background for AMG.
        merge_exclusively: Whether to exclude previous merged masks from merging.

    Returns:
        The instance segmentation.
    """
    def _as_numpy(mask):
        if torch.is_tensor(mask):
            return mask.cpu().numpy()
        return mask

    def _xywh_to_slices(box):
        # Boxes are stored as (x, y, width, height), the array is indexed as [y, x].
        return np.s_[box[1]:box[1] + box[3], box[0]:box[0] + box[2]]

    ordered_masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
    if shape is None:
        shape = next(iter(ordered_masks))["segmentation"].shape
    segmentation = np.zeros(shape, dtype="uint32")

    next_id = 1
    for entry in ordered_masks:
        area = entry["area"]
        if (area < min_object_size) or (max_object_size is not None and area > max_object_size):
            continue

        object_mask = _as_numpy(entry["segmentation"])
        # An id stored with the mask takes precedence over the running id.
        object_id = entry.get("seg_id", next_id)

        if "global_bbox" not in entry:
            if merge_exclusively:
                # Only write to pixels that are not yet covered by a previous mask.
                object_mask = np.logical_and(object_mask, segmentation == 0)
            segmentation[object_mask] = object_id
        else:
            # The mask is given in tile-local coordinates: crop it to its local bounding box
            # and write it into the corresponding global bounding box.
            local_roi = _xywh_to_slices(entry["bbox"])
            global_roi = _xywh_to_slices(entry["global_bbox"])
            object_mask = object_mask[local_roi]
            if merge_exclusively:
                object_mask = np.logical_and(object_mask, segmentation[global_roi] == 0)
            segmentation[global_roi][object_mask] = object_id

        next_id = object_id + 1

    block_shape = (512, 512)
    if label_masks:
        labeled = np.zeros_like(segmentation, dtype=segmentation.dtype)
        segmentation = parallel_impl.label(segmentation, out=labeled, block_shape=block_shape)

    # Collect the ids that are smaller than the minimal size and, if requested, the largest id.
    ids, counts = parallel_impl.unique(segmentation, return_counts=True, block_shape=block_shape)
    discard_ids = ids[counts < min_object_size]
    if with_background:
        largest_id = ids[np.argmax(counts)]
        discard_ids = np.concatenate([discard_ids, [largest_id]])

    discard_mask = np.zeros(segmentation.shape, dtype="bool")
    discard_mask = parallel_impl.isin(segmentation, discard_ids, out=discard_mask, block_shape=block_shape)
    segmentation[discard_mask] = 0

    # NOTE(review): the return value is discarded, so this relies on the relabeling
    # being applied in place to `segmentation` - confirm against elf.parallel.
    parallel_impl.relabel_consecutive(segmentation, block_shape=block_shape)[0]

    return segmentation

Convert the output of the automatic mask generation to an instance segmentation.

Arguments:
  • masks: The outputs generated by AutomaticMaskGenerator, other classes from micro_sam.instance_segmentation, or from micro_sam.inference functions. Only supported for output_mode=binary_mask.
  • shape: The shape of the output segmentation. If None, it will be derived from the mask input. If the masks were predicted with tiling then the shape must be given.
  • min_object_size: The minimal size of an object in pixels. By default, set to '0'.
  • max_object_size: The maximal size of an object in pixels.
  • label_masks: Whether to apply connected components to the result before removing small objects. By default, set to 'True'.
  • with_background: Whether to remove the largest object, which often covers the background for AMG.
  • merge_exclusively: Whether to exclude previous merged masks from merging.
Returns:

The instance segmentation.

def apply_nms( predictions: List[Dict[str, Any]], min_size: int, shape: Optional[Tuple[int, int]] = None, perform_box_nms: bool = False, nms_thresh: float = 0.9, max_size: Optional[int] = None, intersection_over_min: bool = False) -> numpy.ndarray:
def apply_nms(
    predictions: List[Dict[str, Any]],
    min_size: int,
    shape: Optional[Tuple[int, int]] = None,
    perform_box_nms: bool = False,
    nms_thresh: float = 0.9,
    max_size: Optional[int] = None,
    intersection_over_min: bool = False,
) -> np.ndarray:
    """Apply non-maximum suppression to mask predictions from a segment anything model.

    Args:
        predictions: The mask predictions from SAM. Each prediction needs the entries
            'segmentation', 'predicted_iou', 'stability_score' and 'bbox'.
            Predictions from tiled prediction additionally need the entry 'global_bbox'.
        min_size: The minimum mask size to keep in the output.
        shape: The shape of the output segmentation.
            For tiled predictions this is inferred from the tile-local mask shapes if it is not passed.
        perform_box_nms: Whether to perform NMS on the box coordinates or on the masks.
        nms_thresh: The threshold for filtering out objects in NMS.
        max_size: The maximum mask size to keep in the output.
        intersection_over_min: Whether to perform intersection over the minimum overlap shape
            or to perform intersection over union.

    Returns:
        The segmentation obtained from merging the masks left after NMS.
    """
    # Without predictions there is nothing to suppress. If the output shape is known, we return an empty
    # segmentation; otherwise the shape cannot be derived and the lookup below raises an IndexError.
    if len(predictions) == 0 and shape is not None:
        return np.zeros(shape, dtype="uint32")

    # Check if the input comes with a 'global_bbox' attribute. If it does, then the predictions are from
    # a tiled prediction. In this case, we have to take the coordinates w.r.t. the tiling into account.
    is_tiled = "global_bbox" in predictions[0]
    if is_tiled and shape is None:
        shape = _infer_tiled_shape(predictions)

    masks = [pred["segmentation"] for pred in predictions]
    # The stacked masks are only needed for the non-tiled mask NMS.
    nms_masks = None if is_tiled else torch.cat([mask[None] for mask in masks], dim=0)
    data = amg_utils.MaskData(masks=masks, iou_preds=torch.tensor([pred["predicted_iou"] for pred in predictions]))
    data["boxes"] = torch.tensor(np.array([pred["bbox"] for pred in predictions]))
    data["area"] = [int(mask.sum()) for mask in data["masks"]]
    data["stability_scores"] = torch.tensor([pred["stability_score"] for pred in predictions])
    if is_tiled:
        data["global_boxes"] = torch.tensor(np.array([pred["global_bbox"] for pred in predictions]))

    if min_size > 0:
        keep_by_size = torch.tensor(
            [i for i, area in enumerate(data["area"]) if area > min_size], dtype=torch.long,
        )
        data.filter(keep_by_size)
        if nms_masks is not None:
            nms_masks = nms_masks[keep_by_size]

    if max_size is not None:
        # The dtype has to be given explicitly: an empty list would otherwise result in a
        # float tensor, which cannot be used as an index.
        keep_by_size = torch.tensor(
            [i for i, area in enumerate(data["area"]) if area < max_size], dtype=torch.long,
        )
        data.filter(keep_by_size)
        if nms_masks is not None:
            nms_masks = nms_masks[keep_by_size]

    # All masks have been removed by the size filters.
    if len(data["masks"]) == 0:
        if shape is None:
            shape = predictions[0]["segmentation"].shape
        return np.zeros(shape, dtype="uint32")

    scores = data["iou_preds"] * data["stability_scores"]
    boxes = _xywh_to_xyxy(data["global_boxes"] if is_tiled else data["boxes"])
    if perform_box_nms:
        assert not intersection_over_min  # not implemented
        keep_by_nms = batched_nms(
            boxes,
            scores,
            torch.zeros_like(data["boxes"][:, 0]),  # categories
            iou_threshold=nms_thresh,
        )
    elif is_tiled:
        keep_by_nms = _batched_tiled_mask_nms(
            masks=data["masks"],
            boxes=data["boxes"],
            global_boxes=data["global_boxes"],
            scores=scores,
            nms_thresh=nms_thresh,
            intersection_over_min=intersection_over_min,
        )
    else:
        keep_by_nms = _batched_mask_nms(
            masks=nms_masks,
            boxes=boxes,
            scores=scores,
            nms_thresh=nms_thresh,
            intersection_over_min=intersection_over_min,
        )
    data.filter(keep_by_nms)

    if is_tiled:
        mask_data = [
            {"segmentation": mask, "area": area, "bbox": box, "global_bbox": global_box}
            for mask, area, box, global_box in zip(data["masks"], data["area"], data["boxes"], data["global_boxes"])
        ]
    else:
        mask_data = [
            {"segmentation": mask, "area": area, "bbox": box}
            for mask, area, box in zip(data["masks"], data["area"], data["boxes"])
        ]

    if shape is None:
        shape = predictions[0]["segmentation"].shape
    if mask_data:
        segmentation = mask_data_to_segmentation(mask_data, shape=shape, min_object_size=min_size)
    else:  # In case all objects have been filtered out due to size filtering.
        segmentation = np.zeros(shape, dtype="uint32")

    return segmentation

Apply non-maximum suppression to mask predictions from a segment anything model.

Arguments:
  • predictions: The mask predictions from SAM.
  • min_size: The minimum mask size to keep in the output.
  • shape: The shape of the output segmentation. For tiled predictions this is inferred from the tile-local mask shapes if it is not passed.
  • perform_box_nms: Whether to perform NMS on the box coordinates or on the masks.
  • nms_thresh: The threshold for filtering out objects in NMS.
  • max_size: The maximum mask size to keep in the output.
  • intersection_over_min: Whether to perform intersection over the minimum overlap shape or to perform intersection over union.
Returns:

The segmentation obtained from merging the masks left after NMS.