micro_sam.util

Helper functions for downloading Segment Anything models and predicting image embeddings.

   1"""
   2Helper functions for downloading Segment Anything models and predicting image embeddings.
   3"""
   4
   5import os
   6import pickle
   7import hashlib
   8import warnings
   9from pathlib import Path
  10from collections import OrderedDict
  11from typing import Any, Dict, Iterable, Optional, Tuple, Union, Callable
  12
  13import zarr
  14import vigra
  15import torch
  16import pooch
  17import xxhash
  18import numpy as np
  19import imageio.v3 as imageio
  20from skimage.measure import regionprops
  21from skimage.segmentation import relabel_sequential
  22
  23from elf.io import open_file
  24
  25from nifty.tools import blocking
  26
  27from .__version__ import __version__
  28from . import models as custom_models
  29
  30try:
  31    # Avoid import warnings from mobile_sam
  32    with warnings.catch_warnings():
  33        warnings.simplefilter("ignore")
  34        from mobile_sam import sam_model_registry, SamPredictor
  35    VIT_T_SUPPORT = True
  36except ImportError:
  37    from segment_anything import sam_model_registry, SamPredictor
  38    VIT_T_SUPPORT = False
  39
  40try:
  41    from napari.utils import progress as tqdm
  42except ImportError:
  43    from tqdm import tqdm
  44
  45# This is the default model used in micro_sam
  46# Currently it is set to vit_b_lm
  47_DEFAULT_MODEL = "vit_b_lm"
  48
  49# The valid model types. Each type corresponds to the architecture of the
  50# vision transformer used within SAM.
  51_MODEL_TYPES = ("vit_l", "vit_b", "vit_h", "vit_t")
  52
  53
  54# TODO define the proper type for image embeddings
  55ImageEmbeddings = Dict[str, Any]
  56"""@private"""
  57
  58
  59def get_cache_directory() -> Path:
  60    """Get micro-sam cache directory location.
  61
  62    Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.
  63    """
  64    default_cache_directory = os.path.expanduser(pooch.os_cache("micro_sam"))
  65    cache_directory = Path(os.environ.get("MICROSAM_CACHEDIR", default_cache_directory))
  66    return cache_directory
  67
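# Minimal usage sketch (illustrative only; the path below is hypothetical):
# redirect the cache via the MICROSAM_CACHEDIR environment variable before any download.
def _example_cache_directory():
    os.environ["MICROSAM_CACHEDIR"] = "/scratch/username/micro_sam_cache"
    # All subsequent model and sample data downloads will be stored under this directory.
    print(get_cache_directory())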
  68
  69#
  70# Functionality for model download and export
  71#
  72
  73
  74def microsam_cachedir() -> Union[str, os.PathLike]:
  75    """Return the micro-sam cache directory.
  76
  77    Returns the top level cache directory for micro-sam models and sample data.
  78
  79    Every time this function is called, we check for any user updates made to
  80    the MICROSAM_CACHEDIR os environment variable since the last time.
  81    """
  82    cache_directory = os.environ.get("MICROSAM_CACHEDIR") or pooch.os_cache("micro_sam")
  83    return cache_directory
  84
  85
  86def models():
  87    """Return the segmentation models registry.
  88
  89    We recreate the model registry every time this function is called,
  90    so any user changes to the default micro-sam cache directory location
  91    are respected.
  92    """
  93
  94    # We use xxhash to compute the hash of the models, see
  95    # https://github.com/computational-cell-analytics/micro-sam/issues/283
  96    # (It is now a dependency, so we don't provide the sha256 fallback anymore.)
  97    # To generate the xxh128 hash:
  98    #     xxh128sum filename
  99    encoder_registry = {
 100        # The default segment anything models:
 101        "vit_l": "xxh128:a82beb3c660661e3dd38d999cc860e9a",
 102        "vit_h": "xxh128:97698fac30bd929c2e6d8d8cc15933c2",
 103        "vit_b": "xxh128:6923c33df3637b6a922d7682bfc9a86b",
 104        # The model with the vit-tiny backbone, from https://github.com/ChaoningZhang/MobileSAM.
 105        "vit_t": "xxh128:8eadbc88aeb9d8c7e0b4b60c3db48bd0",
 106        # The current version of our models in the modelzoo.
 107        # LM generalist models:
 108        "vit_l_lm": "xxh128:fc32ea6f7fcc7eb02737d1304f81f5f2",
 109        "vit_b_lm": "xxh128:8fd5806be3c3ba213e19a709d6d1495f",
 110        "vit_t_lm": "xxh128:72ec5074774761a6e5c05a08942f981e",
 111        # EM models:
 112        "vit_l_em_organelles": "xxh128:096c9695966803ca6fde24f4c1e3c3fb",
 113        "vit_b_em_organelles": "xxh128:f6f6593aeecd0e15a07bdac86360b6cc",
 114        "vit_t_em_organelles": "xxh128:253474720c497cce605e57c9b1d18fd9",
 115        # Histopathology models:
 116        "vit_b_histopathology": "xxh128:ffd1a2cd84570458b257bd95fdd8f974",
 117        "vit_l_histopathology": "xxh128:b591833c89754271023e901281dee3f2",
 118        "vit_h_histopathology": "xxh128:bd1856dafc156a43fb3aa705f1a6e92e",
 119        # Medical Imaging models:
 120        "vit_b_medical_imaging": "xxh128:5be672f1458263a9edc9fd40d7f56ac1",
 121    }
 122    # Additional decoders for instance segmentation.
 123    decoder_registry = {
 124        # LM generalist models:
 125        "vit_l_lm_decoder": "xxh128:779b5a50ecc6d46d495753fba8717f2f",
 126        "vit_b_lm_decoder": "xxh128:9f580a96984b3085389ced5d9a4ae75d",
 127        "vit_t_lm_decoder": "xxh128:3e914a5f397b0312cdd36813031f8823",
 128        # EM models:
 129        "vit_l_em_organelles_decoder": "xxh128:d60fd96bd6060856f6430f29e42568fb",
 130        "vit_b_em_organelles_decoder": "xxh128:b2d4dcffb99f76d83497d39ee500088f",
 131        "vit_t_em_organelles_decoder": "xxh128:8f897c7bb93174a4d1638827c4dd6f44",
 132        # Histopathology models:
 133        "vit_b_histopathology_decoder": "xxh128:6a66194dcb6e36199cbee2214ecf7213",
 134        "vit_l_histopathology_decoder": "xxh128:46aab7765d4400e039772d5a50b55c04",
 135        "vit_h_histopathology_decoder": "xxh128:3ed9f87e46ad5e16935bd8d722c8dc47",
 136    }
 137    registry = {**encoder_registry, **decoder_registry}
 138
 139    encoder_urls = {
 140        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
 141        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
 142        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
 143        "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download",
 144        "vit_l_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.1/files/vit_l.pt",
 145        "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.1/files/vit_b.pt",
 146        "vit_t_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t.pt",
 147        "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l.pt",  # noqa
 148        "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b.pt",
 149        "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt",  # noqa
 150        "vit_b_histopathology": "https://owncloud.gwdg.de/index.php/s/sBB4H8CTmIoBZsQ/download",
 151        "vit_l_histopathology": "https://owncloud.gwdg.de/index.php/s/IZgnn1cpBq2PHod/download",
 152        "vit_h_histopathology": "https://owncloud.gwdg.de/index.php/s/L7AcvVz7DoWJ2RZ/download",
 153        "vit_b_medical_imaging": "https://owncloud.gwdg.de/index.php/s/AB69HGhj8wuozXQ/download",
 154    }
 155
 156    decoder_urls = {
 157        "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.1/files/vit_l_decoder.pt",  # noqa
 158        "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.1/files/vit_b_decoder.pt",  # noqa
 159        "vit_t_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t_decoder.pt",  # noqa
 160        "vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l_decoder.pt",  # noqa
 161        "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b_decoder.pt",  # noqa
 162        "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt",  # noqa
 163        "vit_b_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/KO9AWqynI7SFOBj/download",
 164        "vit_l_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/oIs6VSmkOp7XrKF/download",
 165        "vit_h_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/1qAKxy5H0jgwZvM/download",
 166    }
 167    urls = {**encoder_urls, **decoder_urls}
 168
 169    models = pooch.create(
 170        path=os.path.join(microsam_cachedir(), "models"),
 171        base_url="",
 172        registry=registry,
 173        urls=urls,
 174    )
 175    return models
 176
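# Minimal usage sketch (illustrative only): resolve a model name to a locally cached
# checkpoint file via the pooch registry returned by `models()`.
def _example_fetch_model():
    model_registry = models()
    # Downloads the checkpoint on first use and returns the cached path afterwards.
    checkpoint_path = model_registry.fetch("vit_b_lm", progressbar=True)
    print(checkpoint_path)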
 177
 178def _get_default_device():
 179    # Check if we are running in CI and use the CPU if we are,
 180    # otherwise the tests may run out of memory on Mac if MPS is used.
 181    if os.getenv("GITHUB_ACTIONS") == "true":
 182        return "cpu"
 183    # Use cuda enabled gpu if it's available.
 184    if torch.cuda.is_available():
 185        device = "cuda"
 186    # As second priority use mps.
 187    # See https://pytorch.org/docs/stable/notes/mps.html for details
 188    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
 189        print("Using apple MPS device.")
 190        device = "mps"
 191    # Use the CPU as fallback.
 192    else:
 193        device = "cpu"
 194    return device
 195
 196
 197def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
 198    """Get the torch device.
 199
 200    If no device is passed, the default device for your system is used.
 201    Otherwise, it is checked whether the device you have passed is supported.
 202
 203    Args:
 204        device: The input device.
 205
 206    Returns:
 207        The device.
 208    """
 209    if device is None or device == "auto":
 210        device = _get_default_device()
 211    else:
 212        device_type = device if isinstance(device, str) else device.type
 213        if device_type.lower() == "cuda":
 214            if not torch.cuda.is_available():
 215                raise RuntimeError("PyTorch CUDA backend is not available.")
 216        elif device_type.lower() == "mps":
 217            if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
 218                raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
 219        elif device_type.lower() == "cpu":
 220            pass  # cpu is always available
 221        else:
 222            raise RuntimeError(
 223                f"Unsupported device: {device}\n"
 224                "Please choose from 'cpu', 'cuda', or 'mps'."
 225            )
 226
 227    return device
 228
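# Minimal usage sketch (illustrative only): resolve the computation device.
def _example_device_selection():
    device = get_device()  # Resolves to "cuda", "mps" or "cpu" automatically.
    try:
        device = get_device("cuda")  # Raises a RuntimeError if CUDA is not available.
    except RuntimeError:
        pass
    return device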
 229
 230def _available_devices():
 231    available_devices = []
 232    for i in ["cuda", "mps", "cpu"]:
 233        try:
 234            device = get_device(i)
 235        except RuntimeError:
 236            pass
 237        else:
 238            available_devices.append(device)
 239    return available_devices
 240
 241
 242# We write a custom unpickler that skips objects that cannot be found instead of
 243# throwing an AttributeError or ModuleNotFoundError.
 244# NOTE: since we just want to unpickle the model to load its weights these errors don't matter.
 245# See also https://stackoverflow.com/questions/27732354/unable-to-load-files-using-pickle-and-multiple-modules
 246class _CustomUnpickler(pickle.Unpickler):
 247    def find_class(self, module, name):
 248        try:
 249            return super().find_class(module, name)
 250        except (AttributeError, ModuleNotFoundError) as e:
 251            warnings.warn(f"Did not find {module}:{name} and will skip it, due to error {e}")
 252            return None
 253
 254
 255def _compute_hash(path, chunk_size=8192):
 256    hash_obj = xxhash.xxh128()
 257    with open(path, "rb") as f:
 258        chunk = f.read(chunk_size)
 259        while chunk:
 260            hash_obj.update(chunk)
 261            chunk = f.read(chunk_size)
 262    hash_val = hash_obj.hexdigest()
 263    return f"xxh128:{hash_val}"
 264
 265
 266# Load the state from a checkpoint.
 267# The checkpoint can either contain a sam encoder state
 268# or it can be a checkpoint for model finetuning.
 269def _load_checkpoint(checkpoint_path):
 270    # Override the unpickler with our custom one.
 271    # This enables loading torch_em checkpoints even if they cannot be fully unpickled.
 272    custom_pickle = pickle
 273    custom_pickle.Unpickler = _CustomUnpickler
 274
 275    state = torch.load(checkpoint_path, map_location="cpu", pickle_module=custom_pickle)
 276    if "model_state" in state:
 277        # Copy the model weights from torch_em's training format.
 278        model_state = state["model_state"]
 279        sam_prefix = "sam."
 280        model_state = OrderedDict(
 281            [(k[len(sam_prefix):] if k.startswith(sam_prefix) else k, v) for k, v in model_state.items()]
 282        )
 283    else:
 284        model_state = state
 285
 286    return state, model_state
 287
 288
 289def get_sam_model(
 290    model_type: str = _DEFAULT_MODEL,
 291    device: Optional[Union[str, torch.device]] = None,
 292    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
 293    return_sam: bool = False,
 294    return_state: bool = False,
 295    peft_kwargs: Optional[Dict] = None,
 296    flexible_load_checkpoint: bool = False,
 297    progress_bar_factory: Optional[Callable] = None,
 298    **model_kwargs,
 299) -> SamPredictor:
 300    r"""Get the Segment Anything Predictor.
 301
 302    This function will download the required model or load it from the cached weight file.
 303    The location of the cache can be changed by setting the environment variable MICROSAM_CACHEDIR.
 304    The name of the requested model can be set via `model_type`.
 305    See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models
 306    for an overview of the available models.
 307
 308    Alternatively this function can also load a model from weights stored in a local filepath.
 309    The corresponding file path is given via `checkpoint_path`. In this case `model_type`
 310    must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for
 311    a SAM model with vit_b encoder.
 312
 313    By default the models are downloaded to a folder named 'micro_sam/models'
 314    inside your default cache directory, e.g.:
 315    * Mac: ~/Library/Caches/<AppName>
 316    * Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined.
 317    * Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache
 318    See the pooch.os_cache() documentation for more details:
 319    https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html
 320
 321    Args:
 322        model_type: The Segment Anything model to use. Will use the `vit_b_lm` model by default.
 323            To get a list of all available model names you can call `get_model_names`.
 324        device: The device for the model. If none is given will use GPU if available.
 325        checkpoint_path: The path to a file with weights that should be used instead of using the
 326            weights corresponding to `model_type`. If given, `model_type` must match the architecture
 327            corresponding to the weight file. e.g. if you use weights for SAM with `vit_b` encoder
 328            then `model_type` must be given as "vit_b".
 329        return_sam: Return the sam model object as well as the predictor.
 330        return_state: Return the unpickled checkpoint state.
 331        peft_kwargs: Keyword arguments for the PEFT wrapper class.
 332        flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
 333        progress_bar_factory: A function to create a progress bar for the model download.
 334        model_kwargs: Additional parameters necessary to initialize the Segment Anything model.
 335
 336    Returns:
 337        The segment anything predictor.
 338    """
 339    device = get_device(device)
 340
 341    # We support passing a local filepath to a checkpoint.
 342    # In this case we do not download any weights but just use the local weight file,
 343    # as it is, without copying it anywhere or checking its hash.
 344
 345    # If checkpoint_path has not been passed, we download a known model and derive the correct
 346    # URL from the model_type. If the model_type is invalid pooch will raise an error.
 347    _provided_checkpoint_path = checkpoint_path is not None
 348    if checkpoint_path is None:
 349        model_registry = models()
 350
 351        progress_bar = True
 352        # Check if we have to download the model.
 353        # If we do and have a progress bar factory, then we over-write the progress bar.
 354        if not os.path.exists(os.path.join(get_cache_directory(), model_type)) and progress_bar_factory is not None:
 355            progress_bar = progress_bar_factory(model_type)
 356
 357        checkpoint_path = model_registry.fetch(model_type, progressbar=progress_bar)
 358        if not isinstance(progress_bar, bool):  # Close the progress bar when the task finishes.
 359            progress_bar.close()
 360
 361        model_hash = model_registry.registry[model_type]
 362
 363        # If we have a custom model then we may also have a decoder checkpoint.
 364        # Download it here, so that we can add it to the state.
 365        decoder_name = f"{model_type}_decoder"
 366        decoder_path = model_registry.fetch(
 367            decoder_name, progressbar=True
 368        ) if decoder_name in model_registry.registry else None
 369
 370    # If checkpoint_path has been passed, we use it instead of downloading a model.
 371    else:
 372        # Check if the file exists and raise an error otherwise.
 373        # We can't check any hashes here, and we don't check if the file is actually a valid weight file.
 374        # (If it isn't the model creation will fail below.)
 375        if not os.path.exists(checkpoint_path):
 376            raise ValueError(f"Checkpoint at {checkpoint_path} could not be found.")
 377        model_hash = _compute_hash(checkpoint_path)
 378        decoder_path = None
 379
 380    # Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped
 381    # before calling sam_model_registry.
 382    abbreviated_model_type = model_type[:5]
 383    if abbreviated_model_type not in _MODEL_TYPES:
 384        raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}")
 385    if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT:
 386        raise RuntimeError(
 387            "'mobile_sam' is required for the vit-tiny. "
 388            "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'"
 389        )
 390
 391    state, model_state = _load_checkpoint(checkpoint_path)
 392
 393    if _provided_checkpoint_path:
 394        # To get the model weights, we prioritize the provided 'checkpoint_path' over the 'model_type'.
 395        # This avoids parameter mismatch issues caused by an incompatible combination of model type and weights.
 396        from micro_sam.models.build_sam import _validate_model_type
 397        _provided_model_type = _validate_model_type(model_state)
 398
 399        # Verify whether the 'abbreviated_model_type' matches the '_provided_model_type'
 400        # Otherwise replace 'abbreviated_model_type' with the latter.
 401        if abbreviated_model_type != _provided_model_type:
 402            # Printing the message below to avoid any filtering of warnings on user's end.
 403            print(
 404                f"CRITICAL WARNING: The chosen 'model_type' is '{abbreviated_model_type}', "
 405                f"however the model checkpoint provided correspond to '{_provided_model_type}', which does not match. "
 406                f"We internally switch the model type to the expected value, i.e. '{_provided_model_type}'. "
 407                "However, please avoid mismatching combination of 'model_type' and 'checkpoint_path' in future."
 408            )
 409
 410        # Replace the extracted 'abbreviated_model_type' with the type matching the model weights.
 411        abbreviated_model_type = _provided_model_type
 412
 413    # Whether to update parameters necessary to initialize the model
 414    if model_kwargs:  # Checks whether model_kwargs have been provided or not
 415        if abbreviated_model_type == "vit_t":
 416            raise ValueError("'micro-sam' does not support changing the model parameters for 'mobile-sam'.")
 417        sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs)
 418
 419    else:
 420        sam = sam_model_registry[abbreviated_model_type]()
 421
 422    # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
 423    # Overwrites the SAM model by freezing the backbone and allowing PEFT.
 424    if peft_kwargs and isinstance(peft_kwargs, dict):
 425        # NOTE: We drop the 'quantize' parameter, if present, as we do not quantize during inference.
 426        peft_kwargs.pop("quantize", None)
 427
 428        if abbreviated_model_type == "vit_t":
 429            raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")
 430
 431        sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam
 432    # Handles cases where the checkpoint does not exactly match a model initialized with non-default parameters.
 433    if flexible_load_checkpoint:
 434        sam = _handle_checkpoint_loading(sam, model_state)
 435    else:
 436        sam.load_state_dict(model_state)
 437    sam.to(device=device)
 438
 439    predictor = SamPredictor(sam)
 440    predictor.model_type = abbreviated_model_type
 441    predictor._hash = model_hash
 442    predictor.model_name = model_type
 443    predictor.checkpoint_path = checkpoint_path
 444
 445    # Add the decoder to the state if we have one and if the state is returned.
 446    if decoder_path is not None and return_state:
 447        state["decoder_state"] = torch.load(decoder_path, map_location=device, weights_only=False)
 448
 449    if return_sam and return_state:
 450        return predictor, sam, state
 451    if return_sam:
 452        return predictor, sam
 453    if return_state:
 454        return predictor, state
 455    return predictor
 456
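# Minimal usage sketches (illustrative only; file paths are hypothetical).
def _example_get_sam_model():
    # Use the default model; it is downloaded to the cache directory on first use.
    predictor = get_sam_model()

    # Use custom weights; 'model_type' must match the encoder the weights were trained with.
    predictor = get_sam_model(model_type="vit_b", checkpoint_path="finetuned_vit_b.pth")

    # Also return the checkpoint state, e.g. to access an attached instance segmentation decoder.
    predictor, state = get_sam_model(model_type="vit_b_lm", return_state=True)
    return predictor, state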
 457
 458def _handle_checkpoint_loading(sam, model_state):
 459    # Handle parameter mismatches in a more flexible way.
 460    # E.g. when training for multi-class semantic segmentation in the mask decoder,
 461    # parameters are updated, leading to "size mismatch" errors.
 462
 463    new_state_dict = {}  # for loading matching parameters
 464    mismatched_layers = []  # for tracking mismatching parameters
 465
 466    reference_state = sam.state_dict()
 467
 468    for k, v in model_state.items():
 469        if k in reference_state:  # This is done to get rid of unwanted layers from pretrained SAM.
 470            if reference_state[k].size() == v.size():
 471                new_state_dict[k] = v
 472            else:
 473                mismatched_layers.append(k)
 474
 475    reference_state.update(new_state_dict)
 476
 477    if len(mismatched_layers) > 0:
 478        warnings.warn(f"The layers with size mismatch: {mismatched_layers}")
 479
 480    for mlayer in mismatched_layers:
 481        if 'weight' in mlayer:
 482            torch.nn.init.kaiming_uniform_(reference_state[mlayer])
 483        elif 'bias' in mlayer:
 484            reference_state[mlayer].zero_()
 485
 486    sam.load_state_dict(reference_state)
 487
 488    return sam
 489
 490
 491def export_custom_sam_model(
 492    checkpoint_path: Union[str, os.PathLike],
 493    model_type: str,
 494    save_path: Union[str, os.PathLike],
 495    with_segmentation_decoder: bool = False,
 496) -> None:
 497    """Export a finetuned Segment Anything Model to the standard model format.
 498
 499    The exported model can be used by the interactive annotation tools in `micro_sam.annotator`.
 500
 501    Args:
 502        checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
 503        model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
 504        save_path: Where to save the exported model.
 505        with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well.
 506            If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'.
 507    """
 508    _, state = get_sam_model(
 509        model_type=model_type, checkpoint_path=checkpoint_path, return_state=True, device="cpu",
 510    )
 511    model_state = state["model_state"]
 512    prefix = "sam."
 513    model_state = OrderedDict(
 514        [(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()]
 515    )
 516
 517    # Store the 'decoder_state' as well, if desired.
 518    if with_segmentation_decoder:
 519        if "decoder_state" not in state:
 520            raise RuntimeError(f"'decoder_state' is not found in the model at '{checkpoint_path}'.")
 521        decoder_state = state["decoder_state"]
 522        save_state = {"model_state": model_state, "decoder_state": decoder_state}
 523    else:
 524        save_state = model_state
 525
 526    torch.save(save_state, save_path)
 527
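# Minimal usage sketch (illustrative only; file paths are hypothetical): export a finetuning
# checkpoint to the standard format expected by the annotation tools.
def _example_export_custom_model():
    export_custom_sam_model(
        checkpoint_path="checkpoints/vit_b_custom/best.pt",
        model_type="vit_b",
        save_path="vit_b_custom_export.pth",
        with_segmentation_decoder=True,  # Requires a 'decoder_state' in the checkpoint.
    )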
 528
 529def export_custom_qlora_model(
 530    checkpoint_path: Optional[Union[str, os.PathLike]],
 531    finetuned_path: Union[str, os.PathLike],
 532    model_type: str,
 533    save_path: Union[str, os.PathLike],
 534) -> None:
 535    """Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format.
 536
 537    The exported model can be used with the LoRA backbone by passing the relevant `peft_kwargs` to `get_sam_model`.
 538
 539    Args:
 540        checkpoint_path: The path to the base foundation model from which the new model has been finetuned.
 541        finetuned_path: The path to the new finetuned model, using QLoRA.
 542        model_type: The Segment Anything Model type corresponding to the checkpoint.
 543        save_path: Where to save the exported model.
 544    """
 545    # Step 1: Get the base SAM model: used to start finetuning from.
 546    _, sam = get_sam_model(
 547        model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True,
 548    )
 549
 550    # Step 2: Load the QLoRA-style finetuned model.
 551    ft_state, ft_model_state = _load_checkpoint(finetuned_path)
 552
 553    # Step 3: Get LoRA weights from QLoRA and retain all original parameters from the base SAM model.
 554    updated_model_state = {}
 555
 556    # - At first, we get all LoRA layers from the QLoRA-style finetuned model checkpoint.
 557    for k, v in ft_model_state.items():
 558        if k.find("w_b_linear") != -1 or k.find("w_a_linear") != -1:
 559            updated_model_state[k] = v
 560
 561    # - Next, we get all the remaining parameters from the base SAM model.
 562    for k, v in sam.state_dict().items():
 563        if k.find("attn.qkv.") != -1:
 564            k = k.replace("qkv", "qkv.qkv_proj")
 565            updated_model_state[k] = v
 566        else:
 567            # Keep all other parameters from the base SAM model unchanged.
 568            updated_model_state[k] = v
 569
 570    # - Finally, we replace the old model state with the new one (so that all other entries of the state are retained).
 571    ft_state['model_state'] = updated_model_state
 572
 573    # Step 4: Store the new "state" to "save_path"
 574    torch.save(ft_state, save_path)
 575
 576
 577def get_model_names() -> Iterable:
 578    model_registry = models()
 579    model_names = model_registry.registry.keys()
 580    return model_names
 581
 582
 583#
 584# Functionality for precomputing image embeddings.
 585#
 586
 587
 588def _to_image(input_):
 589    # we require the input to be uint8
 590    if input_.dtype != np.dtype("uint8"):
 591        # first normalize the input to [0, 1]
 592        input_ = input_.astype("float32") - input_.min()
 593        input_ = input_ / input_.max()
 594        # then bring to [0, 255] and cast to uint8
 595        input_ = (input_ * 255).astype("uint8")
 596
 597    if input_.ndim == 2:
 598        image = np.concatenate([input_[..., None]] * 3, axis=-1)
 599    elif input_.ndim == 3 and input_.shape[-1] == 3:
 600        image = input_
 601    else:
 602        raise ValueError(f"Invalid input image of shape {input_.shape}. Expect either 2D grayscale or 3D RGB image.")
 603
 604    # explicitly return a numpy array for compatibility with torchvision
 605    # because the input_ array could be something like dask array
 606    return np.array(image)
 607
 608
 609@torch.no_grad
 610def _compute_embeddings_batched(predictor, batched_images):
 611    predictor.reset_image()
 612    batched_tensors, original_sizes, input_sizes = [], [], []
 613
 614    # Apply preprocessing to all images in the batch, and then stack them.
 615    # Note: after the transformation the images are all of the same size,
 616    # so they can be stacked and processed as a batch, even if the input images were of different size.
 617    for image in batched_images:
 618        tensor = predictor.transform.apply_image(image)
 619        tensor = torch.as_tensor(tensor, device=predictor.device)
 620        tensor = tensor.permute(2, 0, 1).contiguous()[None, :, :, :]
 621
 622        original_sizes.append(image.shape[:2])
 623        input_sizes.append(tensor.shape[-2:])
 624
 625        tensor = predictor.model.preprocess(tensor)
 626        batched_tensors.append(tensor)
 627
 628    batched_tensors = torch.cat(batched_tensors)
 629    features = predictor.model.image_encoder(batched_tensors)
 630
 631    predictor.original_size = original_sizes[-1]
 632    predictor.input_size = input_sizes[-1]
 633    predictor.features = features[-1]
 634    predictor.is_image_set = True
 635
 636    return features, original_sizes, input_sizes
 637
 638
 639def _compute_tiled_features_2d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size):
 640    tiling = blocking([0, 0], input_.shape[:2], tile_shape)
 641    n_tiles = tiling.numberOfBlocks
 642
 643    features = f.require_group("features")
 644    features.attrs["shape"] = input_.shape[:2]
 645    features.attrs["tile_shape"] = tile_shape
 646    features.attrs["halo"] = halo
 647
 648    pbar_init(n_tiles, "Compute Image Embeddings 2D tiled")
 649
 650    n_batches = int(np.ceil(n_tiles / batch_size))
 651    for batch_id in range(n_batches):
 652        tile_start = batch_id * batch_size
 653        tile_stop = min(tile_start + batch_size, n_tiles)
 654
 655        batched_images = []
 656        for tile_id in range(tile_start, tile_stop):
 657            tile = tiling.getBlockWithHalo(tile_id, list(halo))
 658            outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outerBlock.begin, tile.outerBlock.end))
 659            tile_input = _to_image(input_[outer_tile])
 660            batched_images.append(tile_input)
 661
 662        batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images)
 663        for i, tile_id in enumerate(range(tile_start, tile_stop)):
 664            tile_embeddings, original_size, input_size = batched_embeddings[i], original_sizes[i], input_sizes[i]
 665            # Unsqueeze the channel axis of the tile embeddings.
 666            tile_embeddings = tile_embeddings.unsqueeze(0)
 667            ds = features.create_dataset(
 668                str(tile_id), data=tile_embeddings.cpu().numpy(), compression="gzip", chunks=tile_embeddings.shape
 669            )
 670            ds.attrs["original_size"] = original_size
 671            ds.attrs["input_size"] = input_size
 672            pbar_update(1)
 673
 674    _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size=None, original_size=None)
 675    return features
 676
 677
 678def _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size):
 679    assert input_.ndim == 3
 680
 681    shape = input_.shape[1:]
 682    tiling = blocking([0, 0], shape, tile_shape)
 683    n_tiles = tiling.numberOfBlocks
 684
 685    features = f.require_group("features")
 686    features.attrs["shape"] = shape
 687    features.attrs["tile_shape"] = tile_shape
 688    features.attrs["halo"] = halo
 689
 690    n_slices = input_.shape[0]
 691    pbar_init(n_tiles * n_slices, "Compute Image Embeddings 3D tiled")
 692
 693    # We batch across the z axis.
 694    n_batches = int(np.ceil(n_slices / batch_size))
 695
 696    for tile_id in range(n_tiles):
 697        tile = tiling.getBlockWithHalo(tile_id, list(halo))
 698        outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outerBlock.begin, tile.outerBlock.end))
 699
 700        ds = None
 701        for batch_id in range(n_batches):
 702            z_start = batch_id * batch_size
 703            z_stop = min(z_start + batch_size, n_slices)
 704
 705            batched_images = []
 706            for z in range(z_start, z_stop):
 707                tile_input = _to_image(input_[z][outer_tile])
 708                batched_images.append(tile_input)
 709
 710            batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images)
 711            for i, z in enumerate(range(z_start, z_stop)):
 712                tile_embeddings = batched_embeddings[i].unsqueeze(0)
 713                if ds is None:
 714                    shape = (n_slices,) + tile_embeddings.shape
 715                    chunks = (1,) + tile_embeddings.shape
 716                    ds = features.create_dataset(
 717                        str(tile_id), shape=shape, dtype="float32", compression="gzip", chunks=chunks
 718                    )
 719
 720                ds[z] = tile_embeddings.cpu().numpy()
 721                pbar_update(1)
 722
 723        ds.attrs["original_size"] = original_sizes[-1]
 724        ds.attrs["input_size"] = input_sizes[-1]
 725
 726    _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size=None, original_size=None)
 727
 728    return features
 729
 730
 731def _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update):
 732    # Check if the embeddings are already cached.
 733    if save_path is not None and "input_size" in f.attrs:
 734        # In this case we load the embeddings.
 735        features = f["features"][:]
 736        original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
 737        image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size}
 738        # Also set the embeddings.
 739        set_precomputed(predictor, image_embeddings)
 740        return image_embeddings
 741
 742    pbar_init(1, "Compute Image Embeddings 2D")
 743    # Otherwise we have to compute the embeddings.
 744    predictor.reset_image()
 745    predictor.set_image(_to_image(input_))
 746    features = predictor.get_image_embedding().cpu().numpy()
 747    original_size = predictor.original_size
 748    input_size = predictor.input_size
 749    pbar_update(1)
 750
 751    # Save the embeddings if we have a save_path.
 752    if save_path is not None:
 753        f.create_dataset("features", data=features, compression="gzip", chunks=features.shape)
 754        _write_embedding_signature(
 755            f, input_, predictor, tile_shape=None, halo=None, input_size=input_size, original_size=original_size,
 756        )
 757
 758    image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size}
 759    return image_embeddings
 760
 761
 762def _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size):
 763    # Check if the features are already computed.
 764    if "input_size" in f.attrs:
 765        features = f["features"]
 766        original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
 767        image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size}
 768        return image_embeddings
 769
 770    # Otherwise compute them. Note: saving happens automatically because we
 771    # always write the features to zarr. If no save path is given we use an in-memory zarr.
 772    features = _compute_tiled_features_2d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size)
 773    image_embeddings = {"features": features, "input_size": None, "original_size": None}
 774    return image_embeddings
 775
 776
 777def _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update, batch_size):
 778    # Check if the embeddings are already fully cached.
 779    if save_path is not None and "input_size" in f.attrs:
 780        # In this case we load the embeddings.
 781        features = f["features"] if lazy_loading else f["features"][:]
 782        original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
 783        image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size}
 784        return image_embeddings
 785
 786    # Otherwise we have to compute the embeddings.
 787
 788    # First check if we have a save path or not and set things up accordingly.
 789    if save_path is None:
 790        features = []
 791        save_features = False
 792        partial_features = False
 793    else:
 794        save_features = True
 795        embed_shape = (1, 256, 64, 64)
 796        shape = (input_.shape[0],) + embed_shape
 797        chunks = (1,) + embed_shape
 798        if "features" in f:
 799            partial_features = True
 800            features = f["features"]
 801            if features.shape != shape or features.chunks != chunks:
 802                raise RuntimeError("Invalid partial features")
 803        else:
 804            partial_features = False
 805            features = f.create_dataset("features", shape=shape, chunks=chunks, dtype="float32")
 806
 807    # Initialize the pbar and batches.
 808    n_slices = input_.shape[0]
 809    pbar_init(n_slices, "Compute Image Embeddings 3D")
 810    n_batches = int(np.ceil(n_slices / batch_size))
 811
 812    for batch_id in range(n_batches):
 813        z_start = batch_id * batch_size
 814        z_stop = min(z_start + batch_size, n_slices)
 815
 816        batched_images, batched_z = [], []
 817        for z in range(z_start, z_stop):
 818            # Skip the feature computation for slices whose features have already been (partially) computed.
 819            if partial_features and np.count_nonzero(features[z]) != 0:
 820                continue
 821            tile_input = _to_image(input_[z])
 822            batched_images.append(tile_input)
 823            batched_z.append(z)
 824
 825        batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images)
 826
 827        for z, embedding in zip(batched_z, batched_embeddings):
 828            embedding = embedding.unsqueeze(0)
 829            if save_features:
 830                features[z] = embedding.cpu().numpy()
 831            else:
 832                features.append(embedding.unsqueeze(0))
 833            pbar_update(1)
 834
 835    if save_features:
 836        _write_embedding_signature(
 837            f, input_, predictor, tile_shape=None, halo=None,
 838            input_size=input_sizes[-1], original_size=original_sizes[-1],
 839        )
 840    else:
 841        # Concatenate across the z axis.
 842        features = torch.cat(features).cpu().numpy()
 843
 844    image_embeddings = {"features": features, "input_size": input_sizes[-1], "original_size": original_sizes[-1]}
 845    return image_embeddings
 846
 847
 848def _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size):
 849    # Check if the features are already computed.
 850    if "input_size" in f.attrs:
 851        features = f["features"]
 852        original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
 853        image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size}
 854        return image_embeddings
 855
 856    # Otherwise compute them. Note: saving happens automatically because we
 857    # always write the features to zarr. If no save path is given we use an in-memory zarr.
 858    features = _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size)
 859    image_embeddings = {"features": features, "input_size": None, "original_size": None}
 860    return image_embeddings
 861
 862
 863def _compute_data_signature(input_):
 864    data_signature = hashlib.sha1(np.asarray(input_).tobytes()).hexdigest()
 865    return data_signature
 866
 867
 868# Create all metadata that is stored along with the embeddings.
 869def _get_embedding_signature(input_, predictor, tile_shape, halo, data_signature=None):
 870    if data_signature is None:
 871        data_signature = _compute_data_signature(input_)
 872
 873    signature = {
 874        "data_signature": data_signature,
 875        "tile_shape": tile_shape if tile_shape is None else list(tile_shape),
 876        "halo": halo if halo is None else list(halo),
 877        "model_type": predictor.model_type,
 878        "model_name": predictor.model_name,
 879        "micro_sam_version": __version__,
 880        "model_hash": getattr(predictor, "_hash", None),
 881    }
 882    return signature
 883
 884
 885# Note: the input size and original size differ depending on whether the embeddings are tiled or not.
 886# That's why we do not include them in the main signature that is being checked
 887# (_get_embedding_signature), but just add them for serialization here.
 888def _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size, original_size):
 889    signature = _get_embedding_signature(input_, predictor, tile_shape, halo)
 890    signature.update({"input_size": input_size, "original_size": original_size})
 891    for key, val in signature.items():
 892        f.attrs[key] = val
 893
 894
 895def _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo):
 896    # We may have an empty zarr file that was already created to save the embeddings in.
 897    # In this case the embeddings will be computed and we don't need to perform any checks.
 898    if "input_size" not in f.attrs:
 899        return
 900
 901    signature = _get_embedding_signature(input_, predictor, tile_shape, halo)
 902    for key, val in signature.items():
 903        # Check whether the key is missing from the attrs or if the value is not matching.
 904        if key not in f.attrs or f.attrs[key] != val:
 905            # These keys were recently added, so we don't want to fail yet if they don't
 906            # match in order to not invalidate previous embedding files.
 907            # Instead we just raise a warning. (For the version we probably also don't want to fail
 908            # in the future, since it should not invalidate the embeddings.)
 909            if key in ("micro_sam_version", "model_hash", "model_name"):
 910                warnings.warn(
 911                    f"The signature for {key} in embeddings file {save_path} has a mismatch: "
 912                    f"{f.attrs.get(key)} != {val}. This key was recently added, so your embeddings are likely correct. "
 913                    "But please recompute them if model predictions don't look as expected."
 914                )
 915            else:
 916                raise RuntimeError(
 917                    f"Embeddings file {save_path} is invalid due to mismatch in {key}: "
 918                    f"{f.attrs.get(key)} != {val}. Please recompute embeddings in a new file."
 919                )
 920
 921
 922# Helper function for optional external progress bars.
 923def handle_pbar(verbose, pbar_init, pbar_update):
 924    """@private"""
 925
 926    # Noop to provide dummy functions.
 927    def noop(*args):
 928        pass
 929
 930    if verbose and pbar_init is None:  # we are verbose and don't have an external progress bar.
 931        assert pbar_update is None  # avoid inconsistent state of callbacks
 932
 933        # Create our own progress bar and callbacks
 934        pbar = tqdm()
 935
 936        def pbar_init(total, description):
 937            pbar.total = total
 938            pbar.set_description(description)
 939
 940        def pbar_update(update):
 941            pbar.update(update)
 942
 943        def pbar_close():
 944            pbar.close()
 945
 946    elif verbose and pbar_init is not None:  # external pbar -> we don't have to do anything
 947        assert pbar_update is not None
 948        pbar = None
 949        pbar_close = noop
 950
 951    else:  # we are not verbose, do nothing
 952        pbar = None
 953        pbar_init, pbar_update, pbar_close = noop, noop, noop
 954
 955    return pbar, pbar_init, pbar_update, pbar_close
 956
 957
 958def precompute_image_embeddings(
 959    predictor: SamPredictor,
 960    input_: np.ndarray,
 961    save_path: Optional[Union[str, os.PathLike]] = None,
 962    lazy_loading: bool = False,
 963    ndim: Optional[int] = None,
 964    tile_shape: Optional[Tuple[int, int]] = None,
 965    halo: Optional[Tuple[int, int]] = None,
 966    verbose: bool = True,
 967    batch_size: int = 1,
 968    pbar_init: Optional[callable] = None,
 969    pbar_update: Optional[callable] = None,
 970) -> ImageEmbeddings:
 971    """Compute the image embeddings (output of the encoder) for the input.
 972
 973    If 'save_path' is given the embeddings will be loaded/saved in a zarr container.
 974
 975    Args:
 976        predictor: The Segment Anything predictor.
 977        input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
 978        save_path: Path to save the embeddings in a zarr container.
 979        lazy_loading: Whether to load all embeddings into memory or return an
 980            object to load them on demand when required. This only has an effect if 'save_path' is given
 981            and if the input is 3 dimensional.
 982        ndim: The dimensionality of the data. If not given will be deduced from the input data.
 983        tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
 984        halo: Overlap of the tiles for tiled prediction.
 985        verbose: Whether to be verbose in the computation.
 986        batch_size: The batch size for precomputing image embeddings over tiles (or planes).
 987        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 988            Can be used together with pbar_update to handle the napari progress bar in another thread.
 989            This enables using this function within a threadworker.
 990        pbar_update: Callback to update an external progress bar.
 991
 992    Returns:
 993        The image embeddings.
 994    """
 995    ndim = input_.ndim if ndim is None else ndim
 996
 997    # Handle the embedding save_path.
 998    # We don't have a save path, so we open an in-memory zarr file to hold the embeddings.
 999    if save_path is None:
1000        f = zarr.group()
1001
1002    # We have a save path and it already exists. Embeddings will be loaded from it,
1003    # check that the saved embeddings in there match the parameters of the function call.
1004    elif os.path.exists(save_path):
1005        f = zarr.open(save_path, "a")
1006        _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo)
1007
1008    # We have a save path and it does not exist yet. Create the zarr file to which the
1009    # embeddings will then be saved.
1010    else:
1011        f = zarr.open(save_path, "a")
1012
1013    _, pbar_init, pbar_update, pbar_close = handle_pbar(verbose, pbar_init, pbar_update)
1014
1015    if ndim == 2 and tile_shape is None:
1016        embeddings = _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update)
1017    elif ndim == 2 and tile_shape is not None:
1018        embeddings = _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size)
1019    elif ndim == 3 and tile_shape is None:
1020        embeddings = _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update, batch_size)
1021    elif ndim == 3 and tile_shape is not None:
1022        embeddings = _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size)
1023    else:
1024        raise ValueError(f"Invalid dimesionality {input_.ndim}, expect 2 or 3 dim data.")
1025
1026    pbar_close()
1027    return embeddings
1028
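# Minimal usage sketch (illustrative only; shapes and paths are hypothetical): precompute
# tiled embeddings for a large 2d image and cache them in a zarr container.
def _example_precompute_embeddings():
    predictor = get_sam_model(model_type="vit_b_lm")
    image = np.random.randint(0, 255, size=(4096, 4096), dtype="uint8")
    embeddings = precompute_image_embeddings(
        predictor, image, save_path="embeddings.zarr", tile_shape=(1024, 1024), halo=(256, 256),
    )
    return embeddings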
1029
1030def set_precomputed(
1031    predictor: SamPredictor, image_embeddings: ImageEmbeddings, i: Optional[int] = None, tile_id: Optional[int] = None,
1032) -> SamPredictor:
1033    """Set the precomputed image embeddings for a predictor.
1034
1035    Args:
1036        predictor: The Segment Anything predictor.
1037        image_embeddings: The precomputed image embeddings computed by `precompute_image_embeddings`.
1038        i: Index for the image data. Required if `image` has three spatial dimensions
1039            or a time dimension and two spatial dimensions.
1040        tile_id: Index for the tile. This is required if the embeddings are tiled.
1041
1042    Returns:
1043        The predictor with set features.
1044    """
1045    if tile_id is not None:
1046        tile_features = image_embeddings["features"][tile_id]
1047        tile_image_embeddings = {
1048            "features": tile_features,
1049            "input_size": tile_features.attrs["input_size"],
1050            "original_size": tile_features.attrs["original_size"]
1051        }
1052        return set_precomputed(predictor, tile_image_embeddings, i=i)
1053
1054    device = predictor.device
1055    features = image_embeddings["features"]
1056    assert features.ndim in (4, 5), f"{features.ndim}"
1057    if features.ndim == 5 and i is None:
1058        raise ValueError("The data is 3D so an index i is needed.")
1059    elif features.ndim == 4 and i is not None:
1060        raise ValueError("The data is 2D so an index is not needed.")
1061
1062    if i is None:
1063        predictor.features = features.to(device) if torch.is_tensor(features) else \
1064            torch.from_numpy(features[:]).to(device)
1065    else:
1066        predictor.features = features[i].to(device) if torch.is_tensor(features) else \
1067            torch.from_numpy(features[i]).to(device)
1068
1069    predictor.original_size = image_embeddings["original_size"]
1070    predictor.input_size = image_embeddings["input_size"]
1071    predictor.is_image_set = True
1072
1073    return predictor
1074
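# Minimal usage sketch (illustrative only; shapes are hypothetical): reuse embeddings
# that were precomputed for a volume by selecting a single slice.
def _example_set_precomputed():
    predictor = get_sam_model(model_type="vit_b_lm")
    volume = np.random.randint(0, 255, size=(8, 512, 512), dtype="uint8")
    embeddings = precompute_image_embeddings(predictor, volume, ndim=3)
    # Set the embeddings of slice z=4 before running interactive segmentation on that slice.
    set_precomputed(predictor, embeddings, i=4)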
1075
1076#
1077# Misc functionality
1078#
1079
1080
1081def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
1082    """Compute the intersection over union of two masks.
1083
1084    Args:
1085        mask1: The first mask.
1086        mask2: The second mask.
1087
1088    Returns:
1089        The intersection over union of the two masks.
1090    """
1091    overlap = np.logical_and(mask1 == 1, mask2 == 1).sum()
1092    union = np.logical_or(mask1 == 1, mask2 == 1).sum()
1093    eps = 1e-7
1094    iou = float(overlap) / (float(union) + eps)
1095    return iou
1096
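# Worked example (illustrative only): two 2x2 squares that overlap in 2 pixels.
def _example_compute_iou():
    mask1 = np.zeros((4, 4), dtype="uint8")
    mask1[:2, :2] = 1
    mask2 = np.zeros((4, 4), dtype="uint8")
    mask2[:2, 1:3] = 1
    # The intersection has 2 pixels and the union has 6 pixels, so the IoU is ~0.33.
    return compute_iou(mask1, mask2)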
1097
1098def get_centers_and_bounding_boxes(
1099    segmentation: np.ndarray, mode: str = "v"
1100) -> Tuple[Dict[int, np.ndarray], Dict[int, tuple]]:
1101    """Returns the center coordinates of the foreground instances in the ground-truth.
1102
1103    Args:
1104        segmentation: The segmentation.
1105        mode: Determines the functionality used for computing the centers.
1106            If 'v', the object's eccentricity centers computed by vigra are used.
1107            If 'p', the object's centroids computed by skimage are used.
1108
1109    Returns:
1110        A dictionary that maps object ids to the corresponding centroid.
1111        A dictionary that maps object_ids to the corresponding bounding box.
1112    """
1113    assert mode in ["p", "v"], "Choose either 'p' for regionprops or 'v' for vigra"
1114
1115    properties = regionprops(segmentation)
1116
1117    if mode == "p":
1118        center_coordinates = {prop.label: prop.centroid for prop in properties}
1119    elif mode == "v":
1120        center_coordinates = vigra.filters.eccentricityCenters(segmentation.astype('float32'))
1121        center_coordinates = {i: coord for i, coord in enumerate(center_coordinates) if i > 0}
1122
1123    bbox_coordinates = {prop.label: prop.bbox for prop in properties}
1124
1125    assert len(bbox_coordinates) == len(center_coordinates), f"{len(bbox_coordinates)}, {len(center_coordinates)}"
1126    return center_coordinates, bbox_coordinates
1127
1128
1129def load_image_data(path: str, key: Optional[str] = None, lazy_loading: bool = False) -> np.ndarray:
1130    """Helper function to load image data from file.
1131
1132    Args:
1133        path: The filepath to the image data.
1134        key: The internal filepath for complex data formats like hdf5.
1135        lazy_loading: Whether to lazily load data. Only supported for n5 and zarr data.
1136
1137    Returns:
1138        The image data.
1139    """
1140    if key is None:
1141        image_data = imageio.imread(path)
1142    else:
1143        with open_file(path, mode="r") as f:
1144            image_data = f[key]
1145            if not lazy_loading:
1146                image_data = image_data[:]
1147
1148    return image_data
1149
1150
1151def segmentation_to_one_hot(segmentation: np.ndarray, segmentation_ids: Optional[np.ndarray] = None) -> torch.Tensor:
1152    """Convert the segmentation to one-hot encoded masks.
1153
1154    Args:
1155        segmentation: The segmentation.
1156        segmentation_ids: Optional subset of ids that will be used to subsample the masks.
1157
1158    Returns:
1159        The one-hot encoded masks.
1160    """
1161    masks = segmentation.copy()
1162    if segmentation_ids is None:
1163        n_ids = int(segmentation.max())
1164
1165    else:
1166        msg = "No foreground objects were found."
1167        if len(segmentation_ids) == 0:  # The list should not be completely empty.
1168            raise RuntimeError(msg)
1169
1170        if 0 in segmentation_ids:  # The list should not have 'zero' as a value.
1171            raise RuntimeError(msg)
1172
1173        # the segmentation ids have to be sorted
1174        segmentation_ids = np.sort(segmentation_ids)
1175
1176        # set the non selected objects to zero and relabel sequentially
1177        masks[~np.isin(masks, segmentation_ids)] = 0
1178        masks = relabel_sequential(masks)[0]
1179        n_ids = len(segmentation_ids)
1180
1181    masks = torch.from_numpy(masks)
1182
1183    one_hot_shape = (n_ids + 1,) + masks.shape
1184    masks = masks.unsqueeze(0)  # add dimension to scatter
1185    masks = torch.zeros(one_hot_shape).scatter_(0, masks, 1)[1:]
1186
1187    # add the extra singleton dimension to get shape NUM_OBJECTS x 1 x H x W
1188    masks = masks.unsqueeze(1)
1189    return masks
1190
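# Minimal usage sketch (illustrative only): one-hot encode a small label image.
def _example_one_hot():
    segmentation = np.zeros((6, 6), dtype="int64")
    segmentation[1:3, 1:3] = 1
    segmentation[4:6, 4:6] = 2
    masks = segmentation_to_one_hot(segmentation)  # shape: (2, 1, 6, 6)
    # Restrict the encoding to a subset of object ids.
    subset = segmentation_to_one_hot(segmentation, segmentation_ids=np.array([2]))  # shape: (1, 1, 6, 6)
    return masks, subset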
1191
1192def get_block_shape(shape: Tuple[int]) -> Tuple[int]:
1193    """Get a suitable block shape for chunking a given shape.
1194
1195    The primary use for this is determining chunk sizes for
1196    zarr arrays or block shapes for parallelization.
1197
1198    Args:
1199        shape: The image or volume shape.
1200
1201    Returns:
1202        The block shape.
1203    """
1204    ndim = len(shape)
1205    if ndim == 2:
1206        block_shape = tuple(min(bs, sh) for bs, sh in zip((1024, 1024), shape))
1207    elif ndim == 3:
1208        block_shape = tuple(min(bs, sh) for bs, sh in zip((32, 256, 256), shape))
1209    else:
1210        raise ValueError(f"Only 2 or 3 dimensional shapes are supported, got {ndim}D.")
1211
1212    return block_shape
1213
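# Minimal usage sketch (illustrative only): derive chunk sizes for zarr arrays.
def _example_block_shape():
    assert get_block_shape((768, 768)) == (768, 768)          # capped by the image size
    assert get_block_shape((128, 2048, 2048)) == (32, 256, 256)  # default 3d block shape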
1214
1215def micro_sam_info():
1216    """Display μSAM information using a rich console."""
1217    import psutil
1218    import platform
1219    from rich.panel import Panel
1220    from rich.table import Table
1221    from rich.console import Console
1222
1223    import torch
1224    import micro_sam
1225
1226    # Open up a new console.
1227    console = Console()
1228
1229    # The header for information CLI.
1230    console.print("[bold #0072B2]μSAM Information Booth[/bold #0072B2]", justify="center")
1231    console.print("-" * console.width)
1232
1233    # μSAM version panel.
1234    console.print(
1235        Panel(f"[bold #F0E442]Version:[/bold #F0E442] {micro_sam.__version__}", title="μSAM Version", expand=True)
1236    )
1237
1238    # The documentation link panel.
1239    console.print(
1240        Panel(
1241            "[bold #CC79A7]Tools documented at:[/bold #CC79A7]\n"
1242            "https://computational-cell-analytics.github.io/micro-sam", title="Documentation"
1243        )
1244    )
1245
1246    # The publication panel.
1247    console.print(
1248        Panel(
1249            "[bold #E69F00]Published in Nature Methods:[/bold #E69F00]\n"
1250            "https://www.nature.com/articles/s41592-024-02580-4", title="Publication"
1251        )
1252    )
1253
1254    # The cache directory panel.
1255    console.print(
1256        Panel(f"[bold #009E73]Cache Directory:[/bold #009E73]\n{get_cache_directory()}", title="Cache Directory")
1257    )
1258
1259    # The available models panel.
1260    available_models = list(get_model_names())
1261    # We filter out the decoder models.
1262    available_models = [m for m in available_models if not m.endswith("_decoder")]
1263    model_list = "\n".join(available_models)
1264    console.print(
1265        Panel(f"[bold #D55E00]Available Models:[/bold #D55E00]\n{model_list}", title="List of Supported Models")
1266    )
1267
1268    # The system information table.
1269    total_memory = psutil.virtual_memory().total / (1024 ** 3)
1270    table = Table(title="System Information", show_header=True, header_style="bold #0072B2", expand=True)
1271    table.add_column("Property")
1272    table.add_column("Value", style="bold #56B4E9")
1273    table.add_row("System", platform.system())
1274    table.add_row("Node Name", platform.node())
1275    table.add_row("Release", platform.release())
1276    table.add_row("Version", platform.version())
1277    table.add_row("Machine", platform.machine())
1278    table.add_row("Processor", platform.processor())
1279    table.add_row("Platform", platform.platform())
1280    table.add_row("Total RAM (GB)", f"{total_memory:.2f}")
1281    console.print(table)
1282
1283    # The device information and check for available GPU acceleration.
1284    default_device = _get_default_device()
1285
1286    if default_device == "cuda":
1287        device_index = torch.cuda.current_device()
1288        device_name = torch.cuda.get_device_name(device_index)
1289        console.print(Panel(f"[bold #000000]CUDA Device:[/bold #000000] {device_name}", title="GPU Information"))
1290    elif default_device == "mps":
1291        console.print(Panel("[bold #000000]MPS Device is available[/bold #000000]", title="GPU Information"))
1292    else:
1293        console.print(
1294            Panel(
1295                "[bold #000000]No GPU acceleration device detected. Running on CPU.[/bold #000000]",
1296                title="Device Information"
1297            )
1298        )
def get_cache_directory() -> None:
60def get_cache_directory() -> None:
61    """Get micro-sam cache directory location.
62
63    Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.
64    """
65    default_cache_directory = os.path.expanduser(pooch.os_cache("micro_sam"))
66    cache_directory = Path(os.environ.get("MICROSAM_CACHEDIR", default_cache_directory))
67    return cache_directory

Get micro-sam cache directory location.

Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.
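
A minimal sketch of overriding the cache location via the environment variable; the target directory below is a hypothetical example:

    import os
    os.environ["MICROSAM_CACHEDIR"] = "/data/microsam_cache"  # hypothetical custom cache location

    from micro_sam.util import get_cache_directory
    print(get_cache_directory())  # now resolves to the directory set above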

def microsam_cachedir() -> None:
75def microsam_cachedir() -> None:
76    """Return the micro-sam cache directory.
77
78    Returns the top level cache directory for micro-sam models and sample data.
79
80    Every time this function is called, we check for any user updates made to
81    the MICROSAM_CACHEDIR os environment variable since the last time.
82    """
83    cache_directory = os.environ.get("MICROSAM_CACHEDIR") or pooch.os_cache("micro_sam")
84    return cache_directory

Return the micro-sam cache directory.

Returns the top level cache directory for micro-sam models and sample data.

Every time this function is called, we check for any user updates made to the MICROSAM_CACHEDIR os environment variable since the last time.

def models():
 87def models():
 88    """Return the segmentation models registry.
 89
 90    We recreate the model registry every time this function is called,
 91    so any user changes to the default micro-sam cache directory location
 92    are respected.
 93    """
 94
 95    # We use xxhash to compute the hash of the models, see
 96    # https://github.com/computational-cell-analytics/micro-sam/issues/283
 97    # (It is now a dependency, so we don't provide the sha256 fallback anymore.)
 98    # To generate the xxh128 hash:
 99    #     xxh128sum filename
100    encoder_registry = {
101        # The default segment anything models:
102        "vit_l": "xxh128:a82beb3c660661e3dd38d999cc860e9a",
103        "vit_h": "xxh128:97698fac30bd929c2e6d8d8cc15933c2",
104        "vit_b": "xxh128:6923c33df3637b6a922d7682bfc9a86b",
105        # The model with vit tiny backend from https://github.com/ChaoningZhang/MobileSAM.
106        "vit_t": "xxh128:8eadbc88aeb9d8c7e0b4b60c3db48bd0",
107        # The current version of our models in the modelzoo.
108        # LM generalist models:
109        "vit_l_lm": "xxh128:fc32ea6f7fcc7eb02737d1304f81f5f2",
110        "vit_b_lm": "xxh128:8fd5806be3c3ba213e19a709d6d1495f",
111        "vit_t_lm": "xxh128:72ec5074774761a6e5c05a08942f981e",
112        # EM models:
113        "vit_l_em_organelles": "xxh128:096c9695966803ca6fde24f4c1e3c3fb",
114        "vit_b_em_organelles": "xxh128:f6f6593aeecd0e15a07bdac86360b6cc",
115        "vit_t_em_organelles": "xxh128:253474720c497cce605e57c9b1d18fd9",
116        # Histopathology models:
117        "vit_b_histopathology": "xxh128:ffd1a2cd84570458b257bd95fdd8f974",
118        "vit_l_histopathology": "xxh128:b591833c89754271023e901281dee3f2",
119        "vit_h_histopathology": "xxh128:bd1856dafc156a43fb3aa705f1a6e92e",
120        # Medical Imaging models:
121        "vit_b_medical_imaging": "xxh128:5be672f1458263a9edc9fd40d7f56ac1",
122    }
123    # Additional decoders for instance segmentation.
124    decoder_registry = {
125        # LM generalist models:
126        "vit_l_lm_decoder": "xxh128:779b5a50ecc6d46d495753fba8717f2f",
127        "vit_b_lm_decoder": "xxh128:9f580a96984b3085389ced5d9a4ae75d",
128        "vit_t_lm_decoder": "xxh128:3e914a5f397b0312cdd36813031f8823",
129        # EM models:
130        "vit_l_em_organelles_decoder": "xxh128:d60fd96bd6060856f6430f29e42568fb",
131        "vit_b_em_organelles_decoder": "xxh128:b2d4dcffb99f76d83497d39ee500088f",
132        "vit_t_em_organelles_decoder": "xxh128:8f897c7bb93174a4d1638827c4dd6f44",
133        # Histopathology models:
134        "vit_b_histopathology_decoder": "xxh128:6a66194dcb6e36199cbee2214ecf7213",
135        "vit_l_histopathology_decoder": "xxh128:46aab7765d4400e039772d5a50b55c04",
136        "vit_h_histopathology_decoder": "xxh128:3ed9f87e46ad5e16935bd8d722c8dc47",
137    }
138    registry = {**encoder_registry, **decoder_registry}
139
140    encoder_urls = {
141        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
142        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
143        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
144        "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download",
145        "vit_l_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.1/files/vit_l.pt",
146        "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.1/files/vit_b.pt",
147        "vit_t_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t.pt",
148        "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l.pt",  # noqa
149        "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b.pt",
150        "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt",  # noqa
151        "vit_b_histopathology": "https://owncloud.gwdg.de/index.php/s/sBB4H8CTmIoBZsQ/download",
152        "vit_l_histopathology": "https://owncloud.gwdg.de/index.php/s/IZgnn1cpBq2PHod/download",
153        "vit_h_histopathology": "https://owncloud.gwdg.de/index.php/s/L7AcvVz7DoWJ2RZ/download",
154        "vit_b_medical_imaging": "https://owncloud.gwdg.de/index.php/s/AB69HGhj8wuozXQ/download",
155    }
156
157    decoder_urls = {
158        "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.1/files/vit_l_decoder.pt",  # noqa
159        "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.1/files/vit_b_decoder.pt",  # noqa
160        "vit_t_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t_decoder.pt",  # noqa
161        "vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l_decoder.pt",  # noqa
162        "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b_decoder.pt",  # noqa
163        "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt",  # noqa
164        "vit_b_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/KO9AWqynI7SFOBj/download",
165        "vit_l_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/oIs6VSmkOp7XrKF/download",
166        "vit_h_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/1qAKxy5H0jgwZvM/download",
167    }
168    urls = {**encoder_urls, **decoder_urls}
169
170    models = pooch.create(
171        path=os.path.join(microsam_cachedir(), "models"),
172        base_url="",
173        registry=registry,
174        urls=urls,
175    )
176    return models

Return the segmentation models registry.

We recreate the model registry every time this function is called, so any user changes to the default micro-sam cache directory location are respected.
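
As a sketch, the object returned by models() is a regular pooch registry, so a checkpoint can also be fetched directly; "vit_b_lm" below is one of the registry keys listed above:

    from micro_sam.util import models

    registry = models()
    # Downloads the weights on the first call; afterwards the cached copy is re-used.
    checkpoint_path = registry.fetch("vit_b_lm", progressbar=True)
    print(checkpoint_path)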

def get_device( device: Union[str, torch.device, NoneType] = None) -> Union[str, torch.device]:
198def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
199    """Get the torch device.
200
201    If no device is passed, the default device for your system is used.
202    Otherwise, it is checked whether the device you have passed is supported.
203
204    Args:
205        device: The input device.
206
207    Returns:
208        The device.
209    """
210    if device is None or device == "auto":
211        device = _get_default_device()
212    else:
213        device_type = device if isinstance(device, str) else device.type
214        if device_type.lower() == "cuda":
215            if not torch.cuda.is_available():
216                raise RuntimeError("PyTorch CUDA backend is not available.")
217        elif device_type.lower() == "mps":
218            if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
219                raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
220        elif device_type.lower() == "cpu":
221            pass  # cpu is always available
222        else:
223            raise RuntimeError(
224                f"Unsupported device: {device}\n"
225                "Please choose from 'cpu', 'cuda', or 'mps'."
226            )
227
228    return device

Get the torch device.

If no device is passed, the default device for your system is used. Otherwise, it is checked whether the device you have passed is supported.

Arguments:
  • device: The input device.
Returns:
  The device.
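
A brief usage sketch; passing nothing (or "auto") selects the best available backend, while requesting an unavailable backend raises a RuntimeError:

    from micro_sam.util import get_device

    device = get_device()       # e.g. "cuda", "mps" or "cpu", depending on your system
    device = get_device("cpu")  # explicitly request the CPU
    try:
        device = get_device("cuda")
    except RuntimeError:
        print("CUDA was requested but is not available on this machine.")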

def get_sam_model( model_type: str = 'vit_b_lm', device: Union[str, torch.device, NoneType] = None, checkpoint_path: Union[str, os.PathLike, NoneType] = None, return_sam: bool = False, return_state: bool = False, peft_kwargs: Optional[Dict] = None, flexible_load_checkpoint: bool = False, progress_bar_factory: Optional[Callable] = None, **model_kwargs) -> mobile_sam.predictor.SamPredictor:
290def get_sam_model(
291    model_type: str = _DEFAULT_MODEL,
292    device: Optional[Union[str, torch.device]] = None,
293    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
294    return_sam: bool = False,
295    return_state: bool = False,
296    peft_kwargs: Optional[Dict] = None,
297    flexible_load_checkpoint: bool = False,
298    progress_bar_factory: Optional[Callable] = None,
299    **model_kwargs,
300) -> SamPredictor:
301    r"""Get the Segment Anything Predictor.
302
303    This function will download the required model or load it from the cached weight file.
304    The location of the cache can be changed by setting the environment variable MICROSAM_CACHEDIR.
305    The name of the requested model can be set via `model_type`.
306    See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models
307    for an overview of the available models.
308
309    Alternatively this function can also load a model from weights stored in a local filepath.
310    The corresponding file path is given via `checkpoint_path`. In this case `model_type`
311    must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for
312    a SAM model with vit_b encoder.
313
314    By default the models are downloaded to a folder named 'micro_sam/models'
315    inside your default cache directory, e.g.:
316    * Mac: ~/Library/Caches/<AppName>
317    * Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined.
318    * Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache
319    See the pooch.os_cache() documentation for more details:
320    https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html
321
322    Args:
323        model_type: The Segment Anything model to use. Will use the `vit_b_lm` model by default.
324            To get a list of all available model names you can call `get_model_names`.
325        device: The device for the model. If none is given, the GPU will be used if available.
326        checkpoint_path: The path to a file with weights that should be used instead of using the
327            weights corresponding to `model_type`. If given, `model_type` must match the architecture
328            corresponding to the weight file. e.g. if you use weights for SAM with `vit_b` encoder
329            then `model_type` must be given as "vit_b".
330        return_sam: Return the sam model object as well as the predictor.
331        return_state: Return the unpickled checkpoint state.
332        peft_kwargs: Keyword arguments for the PEFT wrapper class.
333        flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
334        model_kwargs: Additional parameters necessary to initialize the Segment Anything model.
335        progress_bar_factory: A function to create a progress bar for the model download.
336
337    Returns:
338        The segment anything predictor.
339    """
340    device = get_device(device)
341
342    # We support passing a local filepath to a checkpoint.
343    # In this case we do not download any weights but just use the local weight file,
344    # as it is, without copying it anywhere or checking its hash.
345
346    # checkpoint_path has not been passed, we download a known model and derive the correct
347    # URL from the model_type. If the model_type is invalid pooch will raise an error.
348    _provided_checkpoint_path = checkpoint_path is not None
349    if checkpoint_path is None:
350        model_registry = models()
351
352        progress_bar = True
353        # Check if we have to download the model.
354        # If we do and have a progress bar factory, then we over-write the progress bar.
355        if not os.path.exists(os.path.join(get_cache_directory(), model_type)) and progress_bar_factory is not None:
356            progress_bar = progress_bar_factory(model_type)
357
358        checkpoint_path = model_registry.fetch(model_type, progressbar=progress_bar)
359        if not isinstance(progress_bar, bool):  # Close the progress bar when the task finishes.
360            progress_bar.close()
361
362        model_hash = model_registry.registry[model_type]
363
364        # If we have a custom model then we may also have a decoder checkpoint.
365        # Download it here, so that we can add it to the state.
366        decoder_name = f"{model_type}_decoder"
367        decoder_path = model_registry.fetch(
368            decoder_name, progressbar=True
369        ) if decoder_name in model_registry.registry else None
370
371    # checkpoint_path has been passed, we use it instead of downloading a model.
372    else:
373        # Check if the file exists and raise an error otherwise.
374        # We can't check any hashes here, and we don't check if the file is actually a valid weight file.
375        # (If it isn't the model creation will fail below.)
376        if not os.path.exists(checkpoint_path):
377            raise ValueError(f"Checkpoint at {checkpoint_path} could not be found.")
378        model_hash = _compute_hash(checkpoint_path)
379        decoder_path = None
380
381    # Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped
382    # before calling sam_model_registry.
383    abbreviated_model_type = model_type[:5]
384    if abbreviated_model_type not in _MODEL_TYPES:
385        raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}")
386    if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT:
387        raise RuntimeError(
388            "'mobile_sam' is required for the vit-tiny. "
389            "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'"
390        )
391
392    state, model_state = _load_checkpoint(checkpoint_path)
393
394    if _provided_checkpoint_path:
395        # To get the model weights, we prioritize having the correct 'checkpoint_path' over 'model_type'
396        # This avoids parameter mismatch issues caused by an incompatible combination of model type and weights.
397        from micro_sam.models.build_sam import _validate_model_type
398        _provided_model_type = _validate_model_type(model_state)
399
400        # Verify whether the 'abbreviated_model_type' matches the '_provided_model_type'
401        # Otherwise we replace 'abbreviated_model_type' with the latter.
402        if abbreviated_model_type != _provided_model_type:
403            # Printing the message below to avoid any filtering of warnings on user's end.
404            print(
405                f"CRITICAL WARNING: The chosen 'model_type' is '{abbreviated_model_type}', "
406                f"however the model checkpoint provided correspond to '{_provided_model_type}', which does not match. "
407                f"We internally switch the model type to the expected value, i.e. '{_provided_model_type}'. "
408                "However, please avoid mismatching combination of 'model_type' and 'checkpoint_path' in future."
409            )
410
411        # Replace the extracted 'abbreviated_model_type' according to the model weights.
412        abbreviated_model_type = _provided_model_type
413
414    # Whether to update parameters necessary to initialize the model
415    if model_kwargs:  # Checks whether model_kwargs have been provided or not
416        if abbreviated_model_type == "vit_t":
417            raise ValueError("'micro-sam' does not support changing the model parameters for 'mobile-sam'.")
418        sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs)
419
420    else:
421        sam = sam_model_registry[abbreviated_model_type]()
422
423    # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
424    # Overwrites the SAM model by freezing the backbone and allowing PEFT.
425    if peft_kwargs and isinstance(peft_kwargs, dict):
426        # NOTE: We drop the 'quantize' parameter, if found, as we do not quantize during inference.
427        peft_kwargs.pop("quantize", None)
428
429        if abbreviated_model_type == "vit_t":
430            raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")
431
432        sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam
433    # In case the model checkpoint has issues when it is initialized with non-default parameters.
434    if flexible_load_checkpoint:
435        sam = _handle_checkpoint_loading(sam, model_state)
436    else:
437        sam.load_state_dict(model_state)
438    sam.to(device=device)
439
440    predictor = SamPredictor(sam)
441    predictor.model_type = abbreviated_model_type
442    predictor._hash = model_hash
443    predictor.model_name = model_type
444    predictor.checkpoint_path = checkpoint_path
445
446    # Add the decoder to the state if we have one and if the state is returned.
447    if decoder_path is not None and return_state:
448        state["decoder_state"] = torch.load(decoder_path, map_location=device, weights_only=False)
449
450    if return_sam and return_state:
451        return predictor, sam, state
452    if return_sam:
453        return predictor, sam
454    if return_state:
455        return predictor, state
456    return predictor

Get the Segment Anything Predictor.

This function will download the required model or load it from the cached weight file. The location of the cache can be changed by setting the environment variable MICROSAM_CACHEDIR. The name of the requested model can be set via model_type. See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models for an overview of the available models.

Alternatively this function can also load a model from weights stored in a local filepath. The corresponding file path is given via checkpoint_path. In this case model_type must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for a SAM model with vit_b encoder.

By default the models are downloaded to a folder named 'micro_sam/models' inside your default cache directory, e.g. ~/Library/Caches/<AppName> on Mac, ~/.cache/<AppName> (or the value of the XDG_CACHE_HOME environment variable, if defined) on Unix, and C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache on Windows. See the pooch.os_cache() documentation for more details: https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html

Arguments:
  • model_type: The Segment Anything model to use. Will use the vit_b_lm model by default. To get a list of all available model names you can call get_model_names.
  • device: The device for the model. If none is given, the GPU will be used if available.
  • checkpoint_path: The path to a file with weights that should be used instead of using the weights corresponding to model_type. If given, model_type must match the architecture corresponding to the weight file. e.g. if you use weights for SAM with vit_b encoder then model_type must be given as "vit_b".
  • return_sam: Return the sam model object as well as the predictor.
  • return_state: Return the unpickled checkpoint state.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class.
  • flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
  • model_kwargs: Additional parameters necessary to initialize the Segment Anything model.
  • progress_bar_factory: A function to create a progress bar for the model download.
Returns:
  The segment anything predictor.
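
A minimal usage sketch; the checkpoint path in the commented-out line is a hypothetical example for loading your own fine-tuned weights:

    from micro_sam.util import get_sam_model

    # Download (or load from cache) the default generalist light microscopy model.
    predictor = get_sam_model(model_type="vit_b_lm")

    # Alternatively, load local weights; model_type must then match the encoder architecture.
    # predictor = get_sam_model(model_type="vit_b", checkpoint_path="finetuned_vit_b.pt")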

def export_custom_sam_model( checkpoint_path: Union[str, os.PathLike], model_type: str, save_path: Union[str, os.PathLike], with_segmentation_decoder: bool = False) -> None:
492def export_custom_sam_model(
493    checkpoint_path: Union[str, os.PathLike],
494    model_type: str,
495    save_path: Union[str, os.PathLike],
496    with_segmentation_decoder: bool = False,
497) -> None:
498    """Export a finetuned Segment Anything Model to the standard model format.
499
500    The exported model can be used by the interactive annotation tools in `micro_sam.annotator`.
501
502    Args:
503        checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
504        model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
505        save_path: Where to save the exported model.
506        with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well.
507            If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'.
508    """
509    _, state = get_sam_model(
510        model_type=model_type, checkpoint_path=checkpoint_path, return_state=True, device="cpu",
511    )
512    model_state = state["model_state"]
513    prefix = "sam."
514    model_state = OrderedDict(
515        [(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()]
516    )
517
518    # Store the 'decoder_state' as well, if desired.
519    if with_segmentation_decoder:
520        if "decoder_state" not in state:
521            raise RuntimeError(f"'decoder_state' is not found in the model at '{checkpoint_path}'.")
522        decoder_state = state["decoder_state"]
523        save_state = {"model_state": model_state, "decoder_state": decoder_state}
524    else:
525        save_state = model_state
526
527    torch.save(save_state, save_path)

Export a finetuned Segment Anything Model to the standard model format.

The exported model can be used by the interactive annotation tools in micro_sam.annotator.

Arguments:
  • checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
  • model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
  • save_path: Where to save the exported model.
  • with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well. If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'.
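
A usage sketch with hypothetical file paths; exporting the decoder only works if the checkpoint actually contains a 'decoder_state':

    from micro_sam.util import export_custom_sam_model

    export_custom_sam_model(
        checkpoint_path="checkpoints/vit_b/best.pt",  # hypothetical training checkpoint
        model_type="vit_b",
        save_path="vit_b_exported.pth",               # hypothetical output path
        with_segmentation_decoder=True,               # also bundle the instance segmentation decoder
    )
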
def export_custom_qlora_model( checkpoint_path: Union[str, os.PathLike, NoneType], finetuned_path: Union[str, os.PathLike], model_type: str, save_path: Union[str, os.PathLike]) -> None:
530def export_custom_qlora_model(
531    checkpoint_path: Optional[Union[str, os.PathLike]],
532    finetuned_path: Union[str, os.PathLike],
533    model_type: str,
534    save_path: Union[str, os.PathLike],
535) -> None:
536    """Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format.
537
538    The exported model can be used with the LoRA backbone by passing the relevant `peft_kwargs` to `get_sam_model`.
539
540    Args:
541        checkpoint_path: The path to the base foundation model from which the new model has been finetuned.
542        finetuned_path: The path to the new finetuned model, using QLoRA.
543        model_type: The Segment Anything Model type corresponding to the checkpoint.
544        save_path: Where to save the exported model.
545    """
546    # Step 1: Get the base SAM model: used to start finetuning from.
547    _, sam = get_sam_model(
548        model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True,
549    )
550
551    # Step 2: Load the QLoRA-style finetuned model.
552    ft_state, ft_model_state = _load_checkpoint(finetuned_path)
553
554    # Step 3: Get LoRA weights from QLoRA and retain all original parameters from the base SAM model.
555    updated_model_state = {}
556
557    # - At first, we get all LoRA layers from the QLoRA-style finetuned model checkpoint.
558    for k, v in ft_model_state.items():
559        if k.find("w_b_linear") != -1 or k.find("w_a_linear") != -1:
560            updated_model_state[k] = v
561
562    # - Next, we get all the remaining parameters from the base SAM model.
563    for k, v in sam.state_dict().items():
564        if k.find("attn.qkv.") != -1:
565            k = k.replace("qkv", "qkv.qkv_proj")
566            updated_model_state[k] = v
567        else:
568
569            updated_model_state[k] = v
570
571    # - Finally, we replace the old model state with the new one (to retain other relevant stuff)
572    ft_state['model_state'] = updated_model_state
573
574    # Step 4: Store the new "state" to "save_path"
575    torch.save(ft_state, save_path)

Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format.

The exported model can be used with the LoRA backbone by passing the relevant peft_kwargs to get_sam_model.

Arguments:
  • checkpoint_path: The path to the base foundation model from which the new model has been finetuned.
  • finetuned_path: The path to the new finetuned model, using QLoRA.
  • model_type: The Segment Anything Model type corresponding to the checkpoint.
  • save_path: Where to save the exported model.
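
A usage sketch with hypothetical paths; the exported file can later be loaded through get_sam_model by passing LoRA-style peft_kwargs:

    from micro_sam.util import export_custom_qlora_model

    export_custom_qlora_model(
        checkpoint_path=None,                        # start from the pretrained base weights
        finetuned_path="checkpoints/qlora/best.pt",  # hypothetical QLoRA checkpoint
        model_type="vit_b",
        save_path="vit_b_lora_style.pt",             # hypothetical output path
    )
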
def get_model_names() -> Iterable:
578def get_model_names() -> Iterable:
579    model_registry = models()
580    model_names = model_registry.registry.keys()
581    return model_names
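
For example, to print all available model identifiers (encoders and decoders):

    from micro_sam.util import get_model_names
    print(sorted(get_model_names()))
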
def precompute_image_embeddings( predictor: mobile_sam.predictor.SamPredictor, input_: numpy.ndarray, save_path: Union[str, os.PathLike, NoneType] = None, lazy_loading: bool = False, ndim: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = True, batch_size: int = 1, pbar_init: Optional[callable] = None, pbar_update: Optional[callable] = None) -> Dict[str, Any]:
 959def precompute_image_embeddings(
 960    predictor: SamPredictor,
 961    input_: np.ndarray,
 962    save_path: Optional[Union[str, os.PathLike]] = None,
 963    lazy_loading: bool = False,
 964    ndim: Optional[int] = None,
 965    tile_shape: Optional[Tuple[int, int]] = None,
 966    halo: Optional[Tuple[int, int]] = None,
 967    verbose: bool = True,
 968    batch_size: int = 1,
 969    pbar_init: Optional[callable] = None,
 970    pbar_update: Optional[callable] = None,
 971) -> ImageEmbeddings:
 972    """Compute the image embeddings (output of the encoder) for the input.
 973
 974    If 'save_path' is given the embeddings will be loaded/saved in a zarr container.
 975
 976    Args:
 977        predictor: The Segment Anything predictor.
 978        input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
 979        save_path: Path to save the embeddings in a zarr container.
 980        lazy_loading: Whether to load all embeddings into memory or return an
 981            object to load them on demand when required. This only has an effect if 'save_path' is given
 982            and if the input is 3 dimensional.
 983        ndim: The dimensionality of the data. If not given will be deduced from the input data.
 984        tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
 985        halo: Overlap of the tiles for tiled prediction.
 986        verbose: Whether to be verbose in the computation.
 987        batch_size: The batch size for precomputing image embeddings over tiles (or planes).
 988        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 989            Can be used together with pbar_update to handle the napari progress bar in another thread.
 990            This enables using this function within a thread worker.
 991        pbar_update: Callback to update an external progress bar.
 992
 993    Returns:
 994        The image embeddings.
 995    """
 996    ndim = input_.ndim if ndim is None else ndim
 997
 998    # Handle the embedding save_path.
 999    # We don't have a save path, open in memory zarr file to hold tiled embeddings.
1000    if save_path is None:
1001        f = zarr.group()
1002
1003    # We have a save path and it already exists. Embeddings will be loaded from it,
1004    # check that the saved embeddings in there match the parameters of the function call.
1005    elif os.path.exists(save_path):
1006        f = zarr.open(save_path, "a")
1007        _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo)
1008
1009    # We have a save path and it does not exist yet. Create the zarr file to which the
1010    # embeddings will then be saved.
1011    else:
1012        f = zarr.open(save_path, "a")
1013
1014    _, pbar_init, pbar_update, pbar_close = handle_pbar(verbose, pbar_init, pbar_update)
1015
1016    if ndim == 2 and tile_shape is None:
1017        embeddings = _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update)
1018    elif ndim == 2 and tile_shape is not None:
1019        embeddings = _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size)
1020    elif ndim == 3 and tile_shape is None:
1021        embeddings = _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update, batch_size)
1022    elif ndim == 3 and tile_shape is not None:
1023        embeddings = _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size)
1024    else:
1025        raise ValueError(f"Invalid dimensionality {input_.ndim}, expect 2 or 3 dim data.")
1026
1027    pbar_close()
1028    return embeddings

Compute the image embeddings (output of the encoder) for the input.

If 'save_path' is given the embeddings will be loaded/saved in a zarr container.

Arguments:
  • predictor: The Segment Anything predictor.
  • input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
  • save_path: Path to save the embeddings in a zarr container.
  • lazy_loading: Whether to load all embeddings into memory or return an object to load them on demand when required. This only has an effect if 'save_path' is given and if the input is 3 dimensional.
  • ndim: The dimensionality of the data. If not given will be deduced from the input data.
  • tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
  • halo: Overlap of the tiles for tiled prediction.
  • verbose: Whether to be verbose in the computation.
  • batch_size: The batch size for precomputing image embeddings over tiles (or planes).
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a thread worker.
  • pbar_update: Callback to update an external progress bar.
Returns:
  The image embeddings.
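
A minimal sketch for a 2D image with tiled embeddings cached on disk; the input array and the zarr path are hypothetical placeholders:

    import numpy as np
    from micro_sam.util import get_sam_model, precompute_image_embeddings

    predictor = get_sam_model(model_type="vit_b_lm")
    image = np.random.randint(0, 256, (2048, 2048)).astype("uint8")  # placeholder image

    image_embeddings = precompute_image_embeddings(
        predictor, image,
        save_path="embeddings.zarr",               # cached here and re-used on the next call
        tile_shape=(1024, 1024), halo=(256, 256),  # enables tiled embedding computation
    )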

def set_precomputed( predictor: mobile_sam.predictor.SamPredictor, image_embeddings: Dict[str, Any], i: Optional[int] = None, tile_id: Optional[int] = None) -> mobile_sam.predictor.SamPredictor:
1031def set_precomputed(
1032    predictor: SamPredictor, image_embeddings: ImageEmbeddings, i: Optional[int] = None, tile_id: Optional[int] = None,
1033) -> SamPredictor:
1034    """Set the precomputed image embeddings for a predictor.
1035
1036    Args:
1037        predictor: The Segment Anything predictor.
1038        image_embeddings: The precomputed image embeddings computed by `precompute_image_embeddings`.
1039        i: Index for the image data. Required if `image` has three spatial dimensions
1040            or a time dimension and two spatial dimensions.
1041        tile_id: Index for the tile. This is required if the embeddings are tiled.
1042
1043    Returns:
1044        The predictor with set features.
1045    """
1046    if tile_id is not None:
1047        tile_features = image_embeddings["features"][tile_id]
1048        tile_image_embeddings = {
1049            "features": tile_features,
1050            "input_size": tile_features.attrs["input_size"],
1051            "original_size": tile_features.attrs["original_size"]
1052        }
1053        return set_precomputed(predictor, tile_image_embeddings, i=i)
1054
1055    device = predictor.device
1056    features = image_embeddings["features"]
1057    assert features.ndim in (4, 5), f"{features.ndim}"
1058    if features.ndim == 5 and i is None:
1059        raise ValueError("The data is 3D so an index i is needed.")
1060    elif features.ndim == 4 and i is not None:
1061        raise ValueError("The data is 2D so an index is not needed.")
1062
1063    if i is None:
1064        predictor.features = features.to(device) if torch.is_tensor(features) else \
1065            torch.from_numpy(features[:]).to(device)
1066    else:
1067        predictor.features = features[i].to(device) if torch.is_tensor(features) else \
1068            torch.from_numpy(features[i]).to(device)
1069
1070    predictor.original_size = image_embeddings["original_size"]
1071    predictor.input_size = image_embeddings["input_size"]
1072    predictor.is_image_set = True
1073
1074    return predictor

Set the precomputed image embeddings for a predictor.

Arguments:
  • predictor: The Segment Anything predictor.
  • image_embeddings: The precomputed image embeddings computed by precompute_image_embeddings.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • tile_id: Index for the tile. This is required if the embeddings are tiled.
Returns:
  The predictor with set features.
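
A sketch of setting precomputed embeddings before prompting the predictor; the image and the point prompt below are arbitrary placeholders:

    import numpy as np
    from micro_sam.util import get_sam_model, precompute_image_embeddings, set_precomputed

    predictor = get_sam_model(model_type="vit_b_lm")
    image = np.random.randint(0, 256, (512, 512)).astype("uint8")  # placeholder image

    image_embeddings = precompute_image_embeddings(predictor, image, ndim=2)
    predictor = set_precomputed(predictor, image_embeddings)

    # The predictor can now be prompted without recomputing the encoder output.
    masks, scores, logits = predictor.predict(
        point_coords=np.array([[256, 256]]), point_labels=np.array([1])
    )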

def compute_iou(mask1: numpy.ndarray, mask2: numpy.ndarray) -> float:
1082def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
1083    """Compute the intersection over union of two masks.
1084
1085    Args:
1086        mask1: The first mask.
1087        mask2: The second mask.
1088
1089    Returns:
1090        The intersection over union of the two masks.
1091    """
1092    overlap = np.logical_and(mask1 == 1, mask2 == 1).sum()
1093    union = np.logical_or(mask1 == 1, mask2 == 1).sum()
1094    eps = 1e-7
1095    iou = float(overlap) / (float(union) + eps)
1096    return iou

Compute the intersection over union of two masks.

Arguments:
  • mask1: The first mask.
  • mask2: The second mask.
Returns:
  The intersection over union of the two masks.
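
A tiny worked example with two overlapping binary masks:

    import numpy as np
    from micro_sam.util import compute_iou

    mask1 = np.zeros((4, 4), dtype="uint8")
    mask1[:2, :2] = 1  # 4 foreground pixels
    mask2 = np.zeros((4, 4), dtype="uint8")
    mask2[:2, :4] = 1  # 8 foreground pixels
    print(compute_iou(mask1, mask2))  # overlap 4 / union 8 -> ~0.5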

def get_centers_and_bounding_boxes( segmentation: numpy.ndarray, mode: str = 'v') -> Tuple[Dict[int, numpy.ndarray], Dict[int, tuple]]:
1099def get_centers_and_bounding_boxes(
1100    segmentation: np.ndarray, mode: str = "v"
1101) -> Tuple[Dict[int, np.ndarray], Dict[int, tuple]]:
1102    """Returns the center coordinates of the foreground instances in the ground-truth.
1103
1104    Args:
1105        segmentation: The segmentation.
1106        mode: Determines the functionality used for computing the centers.
1107            If 'v', the object's eccentricity centers computed by vigra are used.
1108            If 'p', the object's centroids computed by skimage are used.
1109
1110    Returns:
1111        A dictionary that maps object ids to the corresponding centroid.
1112        A dictionary that maps object_ids to the corresponding bounding box.
1113    """
1114    assert mode in ["p", "v"], "Choose either 'p' for regionprops or 'v' for vigra"
1115
1116    properties = regionprops(segmentation)
1117
1118    if mode == "p":
1119        center_coordinates = {prop.label: prop.centroid for prop in properties}
1120    elif mode == "v":
1121        center_coordinates = vigra.filters.eccentricityCenters(segmentation.astype('float32'))
1122        center_coordinates = {i: coord for i, coord in enumerate(center_coordinates) if i > 0}
1123
1124    bbox_coordinates = {prop.label: prop.bbox for prop in properties}
1125
1126    assert len(bbox_coordinates) == len(center_coordinates), f"{len(bbox_coordinates)}, {len(center_coordinates)}"
1127    return center_coordinates, bbox_coordinates

Returns the center coordinates of the foreground instances in the ground-truth.

Arguments:
  • segmentation: The segmentation.
  • mode: Determines the functionality used for computing the centers. If 'v', the object's eccentricity centers computed by vigra are used. If 'p', the object's centroids computed by skimage are used.
Returns:
  A dictionary that maps object ids to the corresponding centroid.
  A dictionary that maps object ids to the corresponding bounding box.
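
A short sketch on a toy label image:

    import numpy as np
    from micro_sam.util import get_centers_and_bounding_boxes

    segmentation = np.zeros((64, 64), dtype="uint32")
    segmentation[5:20, 5:20] = 1
    segmentation[30:60, 30:50] = 2

    centers, boxes = get_centers_and_bounding_boxes(segmentation, mode="v")
    print(centers[1])  # center of object 1
    print(boxes[1])    # bounding box of object 1 as (min_row, min_col, max_row, max_col)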

def load_image_data( path: str, key: Optional[str] = None, lazy_loading: bool = False) -> numpy.ndarray:
1130def load_image_data(path: str, key: Optional[str] = None, lazy_loading: bool = False) -> np.ndarray:
1131    """Helper function to load image data from file.
1132
1133    Args:
1134        path: The filepath to the image data.
1135        key: The internal filepath for complex data formats like hdf5.
1136        lazy_loading: Whether to lazily load data. Only supported for n5 and zarr data.
1137
1138    Returns:
1139        The image data.
1140    """
1141    if key is None:
1142        image_data = imageio.imread(path)
1143    else:
1144        with open_file(path, mode="r") as f:
1145            image_data = f[key]
1146            if not lazy_loading:
1147                image_data = image_data[:]
1148
1149    return image_data

Helper function to load image data from file.

Arguments:
  • path: The filepath to the image data.
  • key: The internal filepath for complex data formats like hdf5.
  • lazy_loading: Whether to lazily load data. Only supported for n5 and zarr data.
Returns:
  The image data.
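
A usage sketch with hypothetical file paths:

    from micro_sam.util import load_image_data

    # Regular image files are read via imageio.
    image = load_image_data("cells.tif")

    # For container formats the internal dataset path is passed as 'key'.
    # With lazy_loading=True a zarr/n5 dataset handle is returned instead of a numpy array.
    volume = load_image_data("volume.zarr", key="raw", lazy_loading=True)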

def segmentation_to_one_hot( segmentation: numpy.ndarray, segmentation_ids: Optional[numpy.ndarray] = None) -> torch.Tensor:
1152def segmentation_to_one_hot(segmentation: np.ndarray, segmentation_ids: Optional[np.ndarray] = None) -> torch.Tensor:
1153    """Convert the segmentation to one-hot encoded masks.
1154
1155    Args:
1156        segmentation: The segmentation.
1157        segmentation_ids: Optional subset of ids that will be used to subsample the masks.
1158
1159    Returns:
1160        The one-hot encoded masks.
1161    """
1162    masks = segmentation.copy()
1163    if segmentation_ids is None:
1164        n_ids = int(segmentation.max())
1165
1166    else:
1167        msg = "No foreground objects were found."
1168        if len(segmentation_ids) == 0:  # The list should not be completely empty.
1169            raise RuntimeError(msg)
1170
1171        if 0 in segmentation_ids:  # The list should not have 'zero' as a value.
1172            raise RuntimeError(msg)
1173
1174        # the segmentation ids have to be sorted
1175        segmentation_ids = np.sort(segmentation_ids)
1176
1177        # set the non selected objects to zero and relabel sequentially
1178        masks[~np.isin(masks, segmentation_ids)] = 0
1179        masks = relabel_sequential(masks)[0]
1180        n_ids = len(segmentation_ids)
1181
1182    masks = torch.from_numpy(masks)
1183
1184    one_hot_shape = (n_ids + 1,) + masks.shape
1185    masks = masks.unsqueeze(0)  # add dimension to scatter
1186    masks = torch.zeros(one_hot_shape).scatter_(0, masks, 1)[1:]
1187
1188    # add the extra singleton dimension to get shape NUM_OBJECTS x 1 x H x W
1189    masks = masks.unsqueeze(1)
1190    return masks

Convert the segmentation to one-hot encoded masks.

Arguments:
  • segmentation: The segmentation.
  • segmentation_ids: Optional subset of ids that will be used to subsample the masks.
Returns:
  The one-hot encoded masks.
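
A small sketch; the ids must be non-zero because zero is treated as background:

    import numpy as np
    from micro_sam.util import segmentation_to_one_hot

    segmentation = np.zeros((8, 8), dtype="int64")
    segmentation[:4, :4] = 1
    segmentation[4:, 4:] = 2

    masks = segmentation_to_one_hot(segmentation, segmentation_ids=np.array([1, 2]))
    print(masks.shape)  # torch.Size([2, 1, 8, 8]), i.e. NUM_OBJECTS x 1 x H x W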

def get_block_shape(shape: Tuple[int]) -> Tuple[int]:
1193def get_block_shape(shape: Tuple[int]) -> Tuple[int]:
1194    """Get a suitable block shape for chunking a given shape.
1195
1196    The primary use for this is determining chunk sizes for
1197    zarr arrays or block shapes for parallelization.
1198
1199    Args:
1200        shape: The image or volume shape.
1201
1202    Returns:
1203        The block shape.
1204    """
1205    ndim = len(shape)
1206    if ndim == 2:
1207        block_shape = tuple(min(bs, sh) for bs, sh in zip((1024, 1024), shape))
1208    elif ndim == 3:
1209        block_shape = tuple(min(bs, sh) for bs, sh in zip((32, 256, 256), shape))
1210    else:
1211        raise ValueError(f"Only 2 or 3 dimensional shapes are supported, got {ndim}D.")
1212
1213    return block_shape

Get a suitable block shape for chunking a given shape.

The primary use for this is determining chunk sizes for zarr arrays or block shapes for parallelization.

Arguments:
  • shape: The image or volume shape.
Returns:
  The block shape.
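
For example:

    from micro_sam.util import get_block_shape

    print(get_block_shape((512, 512)))        # (512, 512), capped at the image shape
    print(get_block_shape((4096, 4096)))      # (1024, 1024)
    print(get_block_shape((16, 1024, 1024)))  # (16, 256, 256)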

def micro_sam_info():
1216def micro_sam_info():
1217    """Display μSAM information using a rich console."""
1218    import psutil
1219    import platform
1220    from rich.panel import Panel
1221    from rich.table import Table
1222    from rich.console import Console
1223
1224    import torch
1225    import micro_sam
1226
1227    # Open up a new console.
1228    console = Console()
1229
1230    # The header for information CLI.
1231    console.print("[bold #0072B2]μSAM Information Booth[/bold #0072B2]", justify="center")
1232    console.print("-" * console.width)
1233
1234    # μSAM version panel.
1235    console.print(
1236        Panel(f"[bold #F0E442]Version:[/bold #F0E442] {micro_sam.__version__}", title="μSAM Version", expand=True)
1237    )
1238
1239    # The documentation link panel.
1240    console.print(
1241        Panel(
1242            "[bold #CC79A7]Tools documented at:[/bold #CC79A7]\n"
1243            "https://computational-cell-analytics.github.io/micro-sam", title="Documentation"
1244        )
1245    )
1246
1247    # The publication panel.
1248    console.print(
1249        Panel(
1250            "[bold #E69F00]Published in Nature Methods:[/bold #E69F00]\n"
1251            "https://www.nature.com/articles/s41592-024-02580-4", title="Publication"
1252        )
1253    )
1254
1255    # The cache directory panel.
1256    console.print(
1257        Panel(f"[bold #009E73]Cache Directory:[/bold #009E73]\n{get_cache_directory()}", title="Cache Directory")
1258    )
1259
1260    # The available models panel.
1261    available_models = list(get_model_names())
1262    # We filter out the decoder models.
1263    available_models = [m for m in available_models if not m.endswith("_decoder")]
1264    model_list = "\n".join(available_models)
1265    console.print(
1266        Panel(f"[bold #D55E00]Available Models:[/bold #D55E00]\n{model_list}", title="List of Supported Models")
1267    )
1268
1269    # The system information table.
1270    total_memory = psutil.virtual_memory().total / (1024 ** 3)
1271    table = Table(title="System Information", show_header=True, header_style="bold #0072B2", expand=True)
1272    table.add_column("Property")
1273    table.add_column("Value", style="bold #56B4E9")
1274    table.add_row("System", platform.system())
1275    table.add_row("Node Name", platform.node())
1276    table.add_row("Release", platform.release())
1277    table.add_row("Version", platform.version())
1278    table.add_row("Machine", platform.machine())
1279    table.add_row("Processor", platform.processor())
1280    table.add_row("Platform", platform.platform())
1281    table.add_row("Total RAM (GB)", f"{total_memory:.2f}")
1282    console.print(table)
1283
1284    # The device information and check for available GPU acceleration.
1285    default_device = _get_default_device()
1286
1287    if default_device == "cuda":
1288        device_index = torch.cuda.current_device()
1289        device_name = torch.cuda.get_device_name(device_index)
1290        console.print(Panel(f"[bold #000000]CUDA Device:[/bold #000000] {device_name}", title="GPU Information"))
1291    elif default_device == "mps":
1292        console.print(Panel("[bold #000000]MPS Device is available[/bold #000000]", title="GPU Information"))
1293    else:
1294        console.print(
1295            Panel(
1296                "[bold #000000]No GPU acceleration device detected. Running on CPU.[/bold #000000]",
1297                title="Device Information"
1298            )
1299        )

Display μSAM information using a rich console.
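
The overview can be printed directly from Python:

    from micro_sam.util import micro_sam_info
    micro_sam_info()  # prints version, cache location, available models and system information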