micro_sam.util

Helper functions for downloading Segment Anything models and predicting image embeddings.

   1"""
   2Helper functions for downloading Segment Anything models and predicting image embeddings.
   3"""
   4
   5import os
   6import pickle
   7import hashlib
   8import warnings
   9from pathlib import Path
  10from collections import OrderedDict
  11from typing import Any, Dict, Iterable, Optional, Tuple, Union, Callable
  12
  13import zarr
  14import vigra
  15import torch
  16import pooch
  17import xxhash
  18import numpy as np
  19import imageio.v3 as imageio
  20from skimage.measure import regionprops
  21from skimage.segmentation import relabel_sequential
  22
  23from elf.io import open_file
  24
  25from nifty.tools import blocking
  26
  27from .__version__ import __version__
  28from . import models as custom_models
  29
  30try:
  31    # Avoid import warnigns from mobile_sam
  32    with warnings.catch_warnings():
  33        warnings.simplefilter("ignore")
  34        from mobile_sam import sam_model_registry, SamPredictor
  35    VIT_T_SUPPORT = True
  36except ImportError:
  37    from segment_anything import sam_model_registry, SamPredictor
  38    VIT_T_SUPPORT = False
  39
  40try:
  41    from napari.utils import progress as tqdm
  42except ImportError:
  43    from tqdm import tqdm
  44
  45# This is the default model used in micro_sam
  46# Currently it is set to vit_b_lm
  47_DEFAULT_MODEL = "vit_b_lm"
  48
  49# The valid model types. Each type corresponds to the architecture of the
  50# vision transformer used within SAM.
  51_MODEL_TYPES = ("vit_l", "vit_b", "vit_h", "vit_t")
  52
  53
  54# TODO define the proper type for image embeddings
  55ImageEmbeddings = Dict[str, Any]
  56"""@private"""
  57
  58
  59def get_cache_directory() -> None:
  60    """Get micro-sam cache directory location.
  61
  62    Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.
  63    """
  64    default_cache_directory = os.path.expanduser(pooch.os_cache("micro_sam"))
  65    cache_directory = Path(os.environ.get("MICROSAM_CACHEDIR", default_cache_directory))
  66    return cache_directory
  67
  68
  69#
  70# Functionality for model download and export
  71#
  72
  73
  74def microsam_cachedir() -> None:
  75    """Return the micro-sam cache directory.
  76
  77    Returns the top level cache directory for micro-sam models and sample data.
  78
  79    Every time this function is called, we check for any user updates made to
  80    the MICROSAM_CACHEDIR os environment variable since the last time.
  81    """
  82    cache_directory = os.environ.get("MICROSAM_CACHEDIR") or pooch.os_cache("micro_sam")
  83    return cache_directory
  84
  85
  86def models():
  87    """Return the segmentation models registry.
  88
  89    We recreate the model registry every time this function is called,
  90    so any user changes to the default micro-sam cache directory location
  91    are respected.
  92    """
  93
  94    # We use xxhash to compute the hash of the models, see
  95    # https://github.com/computational-cell-analytics/micro-sam/issues/283
  96    # (It is now a dependency, so we don't provide the sha256 fallback anymore.)
  97    # To generate the xxh128 hash:
  98    #     xxh128sum filename
  99    encoder_registry = {
 100        # The default segment anything models:
 101        "vit_l": "xxh128:a82beb3c660661e3dd38d999cc860e9a",
 102        "vit_h": "xxh128:97698fac30bd929c2e6d8d8cc15933c2",
 103        "vit_b": "xxh128:6923c33df3637b6a922d7682bfc9a86b",
 104        # The model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM.
 105        "vit_t": "xxh128:8eadbc88aeb9d8c7e0b4b60c3db48bd0",
 106        # The current version of our models in the modelzoo.
 107        # LM generalist models:
 108        "vit_l_lm": "xxh128:017f20677997d628426dec80a8018f9d",
 109        "vit_b_lm": "xxh128:fe9252a29f3f4ea53c15a06de471e186",
 110        "vit_t_lm": "xxh128:72ec5074774761a6e5c05a08942f981e",
 111        # EM models:
 112        "vit_l_em_organelles": "xxh128:810b084b6e51acdbf760a993d8619f2d",
 113        "vit_b_em_organelles": "xxh128:f3bf2ed83d691456bae2c3f9a05fb438",
 114        "vit_t_em_organelles": "xxh128:253474720c497cce605e57c9b1d18fd9",
 115        # Histopathology models:
 116        "vit_b_histopathology": "xxh128:ffd1a2cd84570458b257bd95fdd8f974",
 117        "vit_l_histopathology": "xxh128:b591833c89754271023e901281dee3f2",
 118        "vit_h_histopathology": "xxh128:bd1856dafc156a43fb3aa705f1a6e92e",
 119        # Medical Imaging models:
 120        "vit_b_medical_imaging": "xxh128:5be672f1458263a9edc9fd40d7f56ac1",
 121    }
 122    # Additional decoders for instance segmentation.
 123    decoder_registry = {
 124        # LM generalist models:
 125        "vit_l_lm_decoder": "xxh128:2faeafa03819dfe03e7c46a44aaac64a",
 126        "vit_b_lm_decoder": "xxh128:708b15ac620e235f90bb38612c4929ba",
 127        "vit_t_lm_decoder": "xxh128:3e914a5f397b0312cdd36813031f8823",
 128        # EM models:
 129        "vit_l_em_organelles_decoder": "xxh128:334877640bfdaaabce533e3252a17294",
 130        "vit_b_em_organelles_decoder": "xxh128:bb6398956a6b0132c26b631c14f95ce2",
 131        "vit_t_em_organelles_decoder": "xxh128:8f897c7bb93174a4d1638827c4dd6f44",
 132        # Histopathology models:
 133        "vit_b_histopathology_decoder": "xxh128:6a66194dcb6e36199cbee2214ecf7213",
 134        "vit_l_histopathology_decoder": "xxh128:46aab7765d4400e039772d5a50b55c04",
 135        "vit_h_histopathology_decoder": "xxh128:3ed9f87e46ad5e16935bd8d722c8dc47",
 136    }
 137    registry = {**encoder_registry, **decoder_registry}
 138
 139    encoder_urls = {
 140        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
 141        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
 142        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
 143        "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download",
 144        "vit_l_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.2/files/vit_l.pt",
 145        "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.2/files/vit_b.pt",
 146        "vit_t_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t.pt",
 147        "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1.2/files/vit_l.pt",  # noqa
 148        "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1.2/files/vit_b.pt",  # noqa
 149        "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt",  # noqa
 150        "vit_b_histopathology": "https://owncloud.gwdg.de/index.php/s/sBB4H8CTmIoBZsQ/download",
 151        "vit_l_histopathology": "https://owncloud.gwdg.de/index.php/s/IZgnn1cpBq2PHod/download",
 152        "vit_h_histopathology": "https://owncloud.gwdg.de/index.php/s/L7AcvVz7DoWJ2RZ/download",
 153        "vit_b_medical_imaging": "https://owncloud.gwdg.de/index.php/s/AB69HGhj8wuozXQ/download",
 154    }
 155
 156    decoder_urls = {
 157        "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.2/files/vit_l_decoder.pt",  # noqa
 158        "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.2/files/vit_b_decoder.pt",  # noqa
 159        "vit_t_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t_decoder.pt",  # noqa
 160        "vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1.2/files/vit_l_decoder.pt",  # noqa
 161        "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1.2/files/vit_b_decoder.pt",  # noqa
 162        "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt",  # noqa
 163        "vit_b_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/KO9AWqynI7SFOBj/download",
 164        "vit_l_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/oIs6VSmkOp7XrKF/download",
 165        "vit_h_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/1qAKxy5H0jgwZvM/download",
 166    }
 167    urls = {**encoder_urls, **decoder_urls}
 168
 169    models = pooch.create(
 170        path=os.path.join(microsam_cachedir(), "models"),
 171        base_url="",
 172        registry=registry,
 173        urls=urls,
 174    )
 175    return models
 176
 177
 178def _get_default_device():
 179    # check that we're in CI and use the CPU if we are
 180    # otherwise the tests may run out of memory on MAC if MPS is used.
 181    if os.getenv("GITHUB_ACTIONS") == "true":
 182        return "cpu"
 183    # Use cuda enabled gpu if it's available.
 184    if torch.cuda.is_available():
 185        device = "cuda"
 186    # As second priority use mps.
 187    # See https://pytorch.org/docs/stable/notes/mps.html for details
 188    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
 189        print("Using apple MPS device.")
 190        device = "mps"
 191    # Use the CPU as fallback.
 192    else:
 193        device = "cpu"
 194    return device
 195
 196
 197def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
 198    """Get the torch device.
 199
 200    If no device is passed the default device for your system is used.
 201    Else it will be checked if the device you have passed is supported.
 202
 203    Args:
 204        device: The input device. By default, selects the best available device supports.
 205
 206    Returns:
 207        The device.
 208    """
 209    if device is None or device == "auto":
 210        device = _get_default_device()
 211    else:
 212        device_type = device if isinstance(device, str) else device.type
 213        if device_type.lower() == "cuda":
 214            if not torch.cuda.is_available():
 215                raise RuntimeError("PyTorch CUDA backend is not available.")
 216        elif device_type.lower() == "mps":
 217            if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
 218                raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
 219        elif device_type.lower() == "cpu":
 220            pass  # cpu is always available
 221        else:
 222            raise RuntimeError(f"Unsupported device: '{device}'. Please choose from 'cpu', 'cuda', or 'mps'.")
 223
 224    return device
 225
 226
 227def _available_devices():
 228    available_devices = []
 229    for i in ["cuda", "mps", "cpu"]:
 230        try:
 231            device = get_device(i)
 232        except RuntimeError:
 233            pass
 234        else:
 235            available_devices.append(device)
 236    return available_devices
 237
 238
 239# We write a custom unpickler that skips objects that cannot be found instead of
 240# throwing an AttributeError or ModueNotFoundError.
 241# NOTE: since we just want to unpickle the model to load its weights these errors don't matter.
 242# See also https://stackoverflow.com/questions/27732354/unable-to-load-files-using-pickle-and-multiple-modules
 243class _CustomUnpickler(pickle.Unpickler):
 244    def find_class(self, module, name):
 245        try:
 246            return super().find_class(module, name)
 247        except (AttributeError, ModuleNotFoundError) as e:
 248            warnings.warn(f"Did not find {module}:{name} and will skip it, due to error {e}")
 249            return None
 250
 251
 252def _compute_hash(path, chunk_size=8192):
 253    hash_obj = xxhash.xxh128()
 254    with open(path, "rb") as f:
 255        chunk = f.read(chunk_size)
 256        while chunk:
 257            hash_obj.update(chunk)
 258            chunk = f.read(chunk_size)
 259    hash_val = hash_obj.hexdigest()
 260    return f"xxh128:{hash_val}"
 261
 262
 263# Load the state from a checkpoint.
 264# The checkpoint can either contain a sam encoder state
 265# or it can be a checkpoint for model finetuning.
 266def _load_checkpoint(checkpoint_path):
 267    # Over-ride the unpickler with our custom one.
 268    # This enables imports from torch_em checkpoints even if it cannot be fully unpickled.
 269    custom_pickle = pickle
 270    custom_pickle.Unpickler = _CustomUnpickler
 271
 272    state = torch.load(checkpoint_path, map_location="cpu", pickle_module=custom_pickle)
 273    if "model_state" in state:
 274        # Copy the model weights from torch_em's training format.
 275        model_state = state["model_state"]
 276        sam_prefix = "sam."
 277        model_state = OrderedDict(
 278            [(k[len(sam_prefix):] if k.startswith(sam_prefix) else k, v) for k, v in model_state.items()]
 279        )
 280    else:
 281        model_state = state
 282
 283    return state, model_state
 284
 285
 286def _download_sam_model(model_type, progress_bar_factory=None):
 287    model_registry = models()
 288
 289    progress_bar = True
 290    # Check if we have to download the model.
 291    # If we do and have a progress bar factory, then we over-write the progress bar.
 292    if not os.path.exists(os.path.join(get_cache_directory(), model_type)) and progress_bar_factory is not None:
 293        progress_bar = progress_bar_factory(model_type)
 294
 295    checkpoint_path = model_registry.fetch(model_type, progressbar=progress_bar)
 296    if not isinstance(progress_bar, bool):  # Close the progress bar when the task finishes.
 297        progress_bar.close()
 298
 299    model_hash = model_registry.registry[model_type]
 300
 301    # If we have a custom model then we may also have a decoder checkpoint.
 302    # Download it here, so that we can add it to the state.
 303    decoder_name = f"{model_type}_decoder"
 304    decoder_path = model_registry.fetch(
 305        decoder_name, progressbar=True
 306    ) if decoder_name in model_registry.registry else None
 307
 308    return checkpoint_path, model_hash, decoder_path
 309
 310
 311def get_sam_model(
 312    model_type: str = _DEFAULT_MODEL,
 313    device: Optional[Union[str, torch.device]] = None,
 314    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
 315    return_sam: bool = False,
 316    return_state: bool = False,
 317    peft_kwargs: Optional[Dict] = None,
 318    flexible_load_checkpoint: bool = False,
 319    progress_bar_factory: Optional[Callable] = None,
 320    **model_kwargs,
 321) -> SamPredictor:
 322    r"""Get the Segment Anything Predictor.
 323
 324    This function will download the required model or load it from the cached weight file.
 325    This location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR.
 326    The name of the requested model can be set via `model_type`.
 327    See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models
 328    for an overview of the available models
 329
 330    Alternatively this function can also load a model from weights stored in a local filepath.
 331    The corresponding file path is given via `checkpoint_path`. In this case `model_type`
 332    must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for
 333    a SAM model with vit_b encoder.
 334
 335    By default the models are downloaded to a folder named 'micro_sam/models'
 336    inside your default cache directory, eg:
 337    * Mac: ~/Library/Caches/<AppName>
 338    * Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined.
 339    * Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache
 340    See the pooch.os_cache() documentation for more details:
 341    https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html
 342
 343    Args:
 344        model_type: The Segment Anything model to use. Will use the 'vit_b_lm' model by default.
 345            To get a list of all available model names you can call `micro_sam.util.get_model_names`.
 346        device: The device for the model. If 'None' is provided, will use GPU if available.
 347        checkpoint_path: The path to a file with weights that should be used instead of using the
 348            weights corresponding to `model_type`. If given, `model_type` must match the architecture
 349            corresponding to the weight file. e.g. if you use weights for SAM with `vit_b` encoder
 350            then `model_type` must be given as 'vit_b'.
 351        return_sam: Return the sam model object as well as the predictor. By default, set to 'False'.
 352        return_state: Return the unpickled checkpoint state. By default, set to 'False'.
 353        peft_kwargs: Keyword arguments for th PEFT wrapper class.
 354            If passed 'None', it does not initialize any parameter efficient finetuning.
 355        flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
 356            By default, set to 'False'.
 357        progress_bar_factory: A function to create a progress bar for the model download.
 358        model_kwargs: Additional parameters necessary to initialize the Segment Anything model.
 359
 360    Returns:
 361        The Segment Anything predictor.
 362    """
 363    device = get_device(device)
 364
 365    # We support passing a local filepath to a checkpoint.
 366    # In this case we do not download any weights but just use the local weight file,
 367    # as it is, without copying it over anywhere or checking it's hashes.
 368
 369    # checkpoint_path has not been passed, we download a known model and derive the correct
 370    # URL from the model_type. If the model_type is invalid pooch will raise an error.
 371    _provided_checkpoint_path = checkpoint_path is not None
 372    if checkpoint_path is None:
 373        checkpoint_path, model_hash, decoder_path = _download_sam_model(model_type, progress_bar_factory)
 374
 375    # checkpoint_path has been passed, we use it instead of downloading a model.
 376    else:
 377        # Check if the file exists and raise an error otherwise.
 378        # We can't check any hashes here, and we don't check if the file is actually a valid weight file.
 379        # (If it isn't the model creation will fail below.)
 380        if not os.path.exists(checkpoint_path):
 381            raise ValueError(f"Checkpoint at '{checkpoint_path}' could not be found.")
 382        model_hash = _compute_hash(checkpoint_path)
 383        decoder_path = None
 384
 385    # Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped
 386    # before calling sam_model_registry.
 387    abbreviated_model_type = model_type[:5]
 388    if abbreviated_model_type not in _MODEL_TYPES:
 389        raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}")
 390    if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT:
 391        raise RuntimeError(
 392            "'mobile_sam' is required for the vit-tiny. "
 393            "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'"
 394        )
 395
 396    state, model_state = _load_checkpoint(checkpoint_path)
 397
 398    if _provided_checkpoint_path:
 399        # To get the model weights, we prioritize having the correct 'checkpoint_path' over 'model_type'
 400        # It is done to avoid strange parameter mismatch issues while incompatible model type and weights combination.
 401        from micro_sam.models.build_sam import _validate_model_type
 402        _provided_model_type = _validate_model_type(model_state)
 403
 404        # Verify whether the 'abbreviated_model_type' matches the '_provided_model_type'
 405        # Otherwise replace 'abbreviated_model_type' with the later.
 406        if abbreviated_model_type != _provided_model_type:
 407            # Printing the message below to avoid any filtering of warnings on user's end.
 408            print(
 409                f"CRITICAL WARNING: The chosen 'model_type' is '{abbreviated_model_type}', "
 410                f"however the model checkpoint provided correspond to '{_provided_model_type}', which does not match. "
 411                f"We internally switch the model type to the expected value, i.e. '{_provided_model_type}'. "
 412                "However, please avoid mismatching combination of 'model_type' and 'checkpoint_path' in future."
 413            )
 414
 415        # Replace the extracted 'abbreviated_model_type' subjected to the model weights.
 416        abbreviated_model_type = _provided_model_type
 417
 418    # Whether to update parameters necessary to initialize the model
 419    if model_kwargs:  # Checks whether model_kwargs have been provided or not
 420        if abbreviated_model_type == "vit_t":
 421            raise ValueError("'micro-sam' does not support changing the model parameters for 'mobile-sam'.")
 422        sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs)
 423
 424    else:
 425        sam = sam_model_registry[abbreviated_model_type]()
 426
 427    # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
 428    # Overwrites the SAM model by freezing the backbone and allow PEFT.
 429    if peft_kwargs and isinstance(peft_kwargs, dict):
 430        # NOTE: We bump out 'quantize' parameter, if found, as we do not quantize in inference.
 431        peft_kwargs.pop("quantize", None)
 432
 433        if abbreviated_model_type == "vit_t":
 434            raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")
 435
 436        sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam
 437    # In case the model checkpoints have some issues when it is initialized with different parameters than default.
 438    if flexible_load_checkpoint:
 439        sam = _handle_checkpoint_loading(sam, model_state)
 440    else:
 441        sam.load_state_dict(model_state)
 442    sam.to(device=device)
 443
 444    predictor = SamPredictor(sam)
 445    predictor.model_type = abbreviated_model_type
 446    predictor._hash = model_hash
 447    predictor.model_name = model_type
 448    predictor.checkpoint_path = checkpoint_path
 449
 450    # Add the decoder to the state if we have one and if the state is returned.
 451    if decoder_path is not None and return_state:
 452        state["decoder_state"] = torch.load(decoder_path, map_location=device, weights_only=False)
 453
 454    if return_sam and return_state:
 455        return predictor, sam, state
 456    if return_sam:
 457        return predictor, sam
 458    if return_state:
 459        return predictor, state
 460    return predictor
 461
 462
 463def _handle_checkpoint_loading(sam, model_state):
 464    # Whether to handle the mismatch issues in a bit more elegant way.
 465    # eg. while training for multi-class semantic segmentation in the mask encoder,
 466    # parameters are updated - leading to "size mismatch" errors
 467
 468    new_state_dict = {}  # for loading matching parameters
 469    mismatched_layers = []  # for tracking mismatching parameters
 470
 471    reference_state = sam.state_dict()
 472
 473    for k, v in model_state.items():
 474        if k in reference_state:  # This is done to get rid of unwanted layers from pretrained SAM.
 475            if reference_state[k].size() == v.size():
 476                new_state_dict[k] = v
 477            else:
 478                mismatched_layers.append(k)
 479
 480    reference_state.update(new_state_dict)
 481
 482    if len(mismatched_layers) > 0:
 483        warnings.warn(f"The layers with size mismatch: {mismatched_layers}")
 484
 485    for mlayer in mismatched_layers:
 486        if 'weight' in mlayer:
 487            torch.nn.init.kaiming_uniform_(reference_state[mlayer])
 488        elif 'bias' in mlayer:
 489            reference_state[mlayer].zero_()
 490
 491    sam.load_state_dict(reference_state)
 492
 493    return sam
 494
 495
 496def export_custom_sam_model(
 497    checkpoint_path: Union[str, os.PathLike],
 498    model_type: str,
 499    save_path: Union[str, os.PathLike],
 500    with_segmentation_decoder: bool = False,
 501    prefix: str = "sam.",
 502) -> None:
 503    """Export a finetuned Segment Anything Model to the standard model format.
 504
 505    The exported model can be used by the interactive annotation tools in `micro_sam.annotator`.
 506
 507    Args:
 508        checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
 509        model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
 510        save_path: Where to save the exported model.
 511        with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well.
 512            If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'.
 513        prefix: The prefix to remove from the model parameter keys.
 514    """
 515    state, model_state = _load_checkpoint(checkpoint_path=checkpoint_path)
 516    model_state = OrderedDict([(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()])
 517
 518    # Store the 'decoder_state' as well, if desired.
 519    if with_segmentation_decoder:
 520        if "decoder_state" not in state:
 521            raise RuntimeError(f"'decoder_state' is not found in the model at '{checkpoint_path}'.")
 522        decoder_state = state["decoder_state"]
 523        save_state = {"model_state": model_state, "decoder_state": decoder_state}
 524    else:
 525        save_state = model_state
 526
 527    torch.save(save_state, save_path)
 528
 529
 530def export_custom_qlora_model(
 531    checkpoint_path: Optional[Union[str, os.PathLike]],
 532    finetuned_path: Union[str, os.PathLike],
 533    model_type: str,
 534    save_path: Union[str, os.PathLike],
 535) -> None:
 536    """Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format.
 537
 538    The exported model can be used with the LoRA backbone by passing the relevant `peft_kwargs` to `get_sam_model`.
 539
 540    Args:
 541        checkpoint_path: The path to the base foundation model from which the new model has been finetuned.
 542        finetuned_path: The path to the new finetuned model, using QLoRA.
 543        model_type: The Segment Anything Model type corresponding to the checkpoint.
 544        save_path: Where to save the exported model.
 545    """
 546    # Step 1: Get the base SAM model: used to start finetuning from.
 547    _, sam = get_sam_model(
 548        model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True,
 549    )
 550
 551    # Step 2: Load the QLoRA-style finetuned model.
 552    ft_state, ft_model_state = _load_checkpoint(finetuned_path)
 553
 554    # Step 3: Identify LoRA layers from QLoRA model.
 555    # - differentiate between LoRA applied to the attention matrices and LoRA applied to the MLP layers.
 556    # - then copy the LoRA layers from the QLoRA model to the new state dict
 557    updated_model_state = {}
 558
 559    modified_attn_layers = set()
 560    modified_mlp_layers = set()
 561
 562    for k, v in ft_model_state.items():
 563        if "blocks." in k:
 564            layer_id = int(k.split("blocks.")[1].split(".")[0])
 565        if k.find("qkv.w_a_linear") != -1 or k.find("qkv.w_b_linear") != -1:
 566            modified_attn_layers.add(layer_id)
 567            updated_model_state[k] = v
 568        if k.find("mlp.w_a_linear") != -1 or k.find("mlp.w_b_linear") != -1:
 569            modified_mlp_layers.add(layer_id)
 570            updated_model_state[k] = v
 571
 572    # Step 4: Next, we get all the remaining parameters from the base SAM model.
 573    for k, v in sam.state_dict().items():
 574        if "blocks." in k:
 575            layer_id = int(k.split("blocks.")[1].split(".")[0])
 576        if k.find("attn.qkv.") != -1:
 577            if layer_id in modified_attn_layers:  # We have LoRA in QKV layers, so we need to modify the key
 578                k = k.replace("qkv", "qkv.qkv_proj")
 579        elif k.find("mlp") != -1 and k.find("image_encoder") != -1:
 580            if layer_id in modified_mlp_layers:  # We have LoRA in MLP layers, so we need to modify the key
 581                k = k.replace("mlp.", "mlp.mlp_layer.")
 582        updated_model_state[k] = v
 583
 584    # Step 5: Finally, we replace the old model state with the new one (to retain other relevant stuff)
 585    ft_state['model_state'] = updated_model_state
 586
 587    # Step 6: Store the new "state" to "save_path"
 588    torch.save(ft_state, save_path)
 589
 590
 591def get_model_names() -> Iterable:
 592    model_registry = models()
 593    model_names = model_registry.registry.keys()
 594    return model_names
 595
 596
 597#
 598# Functionality for precomputing image embeddings.
 599#
 600
 601
 602def _to_image(input_):
 603    # we require the input to be uint8
 604    if input_.dtype != np.dtype("uint8"):
 605        # first normalize the input to [0, 1]
 606        input_ = input_.astype("float32") - input_.min()
 607        input_ = input_ / input_.max()
 608        # then bring to [0, 255] and cast to uint8
 609        input_ = (input_ * 255).astype("uint8")
 610
 611    if input_.ndim == 2:
 612        image = np.concatenate([input_[..., None]] * 3, axis=-1)
 613    elif input_.ndim == 3 and input_.shape[-1] == 3:
 614        image = input_
 615    else:
 616        raise ValueError(f"Invalid input image of shape {input_.shape}. Expect either 2D grayscale or 3D RGB image.")
 617
 618    # explicitly return a numpy array for compatibility with torchvision
 619    # because the input_ array could be something like dask array
 620    return np.array(image)
 621
 622
 623@torch.no_grad
 624def _compute_embeddings_batched(predictor, batched_images):
 625    predictor.reset_image()
 626    batched_tensors, original_sizes, input_sizes = [], [], []
 627
 628    # Apply proeprocessing to all images in the batch, and then stack them.
 629    # Note: after the transformation the images are all of the same size,
 630    # so they can be stacked and processed as a batch, even if the input images were of different size.
 631    for image in batched_images:
 632        tensor = predictor.transform.apply_image(image)
 633        tensor = torch.as_tensor(tensor, device=predictor.device)
 634        tensor = tensor.permute(2, 0, 1).contiguous()[None, :, :, :]
 635
 636        original_sizes.append(image.shape[:2])
 637        input_sizes.append(tensor.shape[-2:])
 638
 639        tensor = predictor.model.preprocess(tensor)
 640        batched_tensors.append(tensor)
 641
 642    batched_tensors = torch.cat(batched_tensors)
 643    features = predictor.model.image_encoder(batched_tensors)
 644
 645    predictor.original_size = original_sizes[-1]
 646    predictor.input_size = input_sizes[-1]
 647    predictor.features = features[-1]
 648    predictor.is_image_set = True
 649
 650    return features, original_sizes, input_sizes
 651
 652
 653# Wrapper of zarr.create dataset to support zarr v2 and zarr v3.
 654def _create_dataset_with_data(group, name, data, chunks=None):
 655    zarr_major_version = int(zarr.__version__.split(".")[0])
 656    if chunks is None:
 657        chunks = data.shape
 658    if zarr_major_version == 2:
 659        ds = group.create_dataset(
 660            name, data=data, shape=data.shape, compression="gzip", chunks=chunks
 661        )
 662    elif zarr_major_version == 3:
 663        ds = group.create_array(
 664            name, shape=data.shape, compressors=[zarr.codecs.GzipCodec()], chunks=chunks, dtype=data.dtype,
 665        )
 666        ds[:] = data
 667    else:
 668        raise RuntimeError(f"Unsupported zarr version: {zarr_major_version}")
 669    return ds
 670
 671
 672def _create_dataset_without_data(group, name, shape, dtype, chunks):
 673    zarr_major_version = int(zarr.__version__.split(".")[0])
 674    if zarr_major_version == 2:
 675        ds = group.create_dataset(
 676            name, shape=shape, dtype=dtype, compression="gzip", chunks=chunks
 677        )
 678    elif zarr_major_version == 3:
 679        ds = group.create_array(
 680            name, shape=shape, compressors=[zarr.codecs.GzipCodec()], chunks=chunks, dtype=dtype
 681        )
 682    else:
 683        raise RuntimeError(f"Unsupported zarr version: {zarr_major_version}")
 684    return ds
 685
 686
 687def _compute_tiled_features_2d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size):
 688    tiling = blocking([0, 0], input_.shape[:2], tile_shape)
 689    n_tiles = tiling.numberOfBlocks
 690
 691    features = f.require_group("features")
 692    features.attrs["shape"] = input_.shape[:2]
 693    features.attrs["tile_shape"] = tile_shape
 694    features.attrs["halo"] = halo
 695
 696    pbar_init(n_tiles, "Compute Image Embeddings 2D tiled")
 697
 698    n_batches = int(np.ceil(n_tiles / batch_size))
 699    for batch_id in range(n_batches):
 700        tile_start = batch_id * batch_size
 701        tile_stop = min(tile_start + batch_size, n_tiles)
 702
 703        batched_images = []
 704        for tile_id in range(tile_start, tile_stop):
 705            tile = tiling.getBlockWithHalo(tile_id, list(halo))
 706            outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outerBlock.begin, tile.outerBlock.end))
 707            tile_input = _to_image(input_[outer_tile])
 708            batched_images.append(tile_input)
 709
 710        batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images)
 711        for i, tile_id in enumerate(range(tile_start, tile_stop)):
 712            tile_embeddings, original_size, input_size = batched_embeddings[i], original_sizes[i], input_sizes[i]
 713            # Unsqueeze the channel axis of the tile embeddings.
 714            tile_embeddings = tile_embeddings.unsqueeze(0)
 715            ds = _create_dataset_with_data(features, str(tile_id), data=tile_embeddings.cpu().numpy())
 716            ds.attrs["original_size"] = original_size
 717            ds.attrs["input_size"] = input_size
 718            pbar_update(1)
 719
 720    _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size=None, original_size=None)
 721    return features
 722
 723
 724def _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size):
 725    assert input_.ndim == 3
 726
 727    shape = input_.shape[1:]
 728    tiling = blocking([0, 0], shape, tile_shape)
 729    n_tiles = tiling.numberOfBlocks
 730
 731    features = f.require_group("features")
 732    features.attrs["shape"] = shape
 733    features.attrs["tile_shape"] = tile_shape
 734    features.attrs["halo"] = halo
 735
 736    n_slices = input_.shape[0]
 737    pbar_init(n_tiles * n_slices, "Compute Image Embeddings 3D tiled")
 738
 739    # We batch across the z axis.
 740    n_batches = int(np.ceil(n_slices / batch_size))
 741
 742    for tile_id in range(n_tiles):
 743        tile = tiling.getBlockWithHalo(tile_id, list(halo))
 744        outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outerBlock.begin, tile.outerBlock.end))
 745
 746        ds = None
 747        for batch_id in range(n_batches):
 748            z_start = batch_id * batch_size
 749            z_stop = min(z_start + batch_size, n_slices)
 750
 751            batched_images = []
 752            for z in range(z_start, z_stop):
 753                tile_input = _to_image(input_[z][outer_tile])
 754                batched_images.append(tile_input)
 755
 756            batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images)
 757            for i, z in enumerate(range(z_start, z_stop)):
 758                tile_embeddings = batched_embeddings[i].unsqueeze(0)
 759                if ds is None:
 760                    shape = (n_slices,) + tile_embeddings.shape
 761                    chunks = (1,) + tile_embeddings.shape
 762                    ds = _create_dataset_without_data(
 763                        features, str(tile_id), shape=shape, dtype="float32", chunks=chunks
 764                    )
 765
 766                ds[z] = tile_embeddings.cpu().numpy()
 767                pbar_update(1)
 768
 769        ds.attrs["original_size"] = original_sizes[-1]
 770        ds.attrs["input_size"] = input_sizes[-1]
 771
 772    _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size=None, original_size=None)
 773
 774    return features
 775
 776
 777def _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update):
 778    # Check if the embeddings are already cached.
 779    if save_path is not None and "input_size" in f.attrs:
 780        # In this case we load the embeddings.
 781        features = f["features"][:]
 782        original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
 783        image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size}
 784        # Also set the embeddings.
 785        set_precomputed(predictor, image_embeddings)
 786        return image_embeddings
 787
 788    pbar_init(1, "Compute Image Embeddings 2D")
 789    # Otherwise we have to compute the embeddings.
 790    predictor.reset_image()
 791    predictor.set_image(_to_image(input_))
 792    features = predictor.get_image_embedding().cpu().numpy()
 793    original_size = predictor.original_size
 794    input_size = predictor.input_size
 795    pbar_update(1)
 796
 797    # Save the embeddings if we have a save_path.
 798    if save_path is not None:
 799        _create_dataset_with_data(f, "features", data=features)
 800        _write_embedding_signature(
 801            f, input_, predictor, tile_shape=None, halo=None, input_size=input_size, original_size=original_size,
 802        )
 803
 804    image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size}
 805    return image_embeddings
 806
 807
 808def _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size):
 809    # Check if the features are already computed.
 810    if "input_size" in f.attrs:
 811        features = f["features"]
 812        original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
 813        image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size}
 814        return image_embeddings
 815
 816    # Otherwise compute them. Note: saving happens automatically because we
 817    # always write the features to zarr. If no save path is given we use an in-memory zarr.
 818    features = _compute_tiled_features_2d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size)
 819    image_embeddings = {"features": features, "input_size": None, "original_size": None}
 820    return image_embeddings
 821
 822
 823def _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update, batch_size):
 824    # Check if the embeddings are already fully cached.
 825    if save_path is not None and "input_size" in f.attrs:
 826        # In this case we load the embeddings.
 827        features = f["features"] if lazy_loading else f["features"][:]
 828        original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
 829        image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size}
 830        return image_embeddings
 831
 832    # Otherwise we have to compute the embeddings.
 833
 834    # First check if we have a save path or not and set things up accordingly.
 835    if save_path is None:
 836        features = []
 837        save_features = False
 838        partial_features = False
 839    else:
 840        save_features = True
 841        embed_shape = (1, 256, 64, 64)
 842        shape = (input_.shape[0],) + embed_shape
 843        chunks = (1,) + embed_shape
 844        if "features" in f:
 845            partial_features = True
 846            features = f["features"]
 847            if features.shape != shape or features.chunks != chunks:
 848                raise RuntimeError("Invalid partial features")
 849        else:
 850            partial_features = False
 851            features = _create_dataset_without_data(f, "features", shape=shape, chunks=chunks, dtype="float32")
 852
 853    # Initialize the pbar and batches.
 854    n_slices = input_.shape[0]
 855    pbar_init(n_slices, "Compute Image Embeddings 3D")
 856    n_batches = int(np.ceil(n_slices / batch_size))
 857
 858    for batch_id in range(n_batches):
 859        z_start = batch_id * batch_size
 860        z_stop = min(z_start + batch_size, n_slices)
 861
 862        batched_images, batched_z = [], []
 863        for z in range(z_start, z_stop):
 864            # Skip feature computation in case of partial features in non-zero slice.
 865            if partial_features and np.count_nonzero(features[z]) != 0:
 866                continue
 867            tile_input = _to_image(input_[z])
 868            batched_images.append(tile_input)
 869            batched_z.append(z)
 870
 871        batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images)
 872
 873        for z, embedding in zip(batched_z, batched_embeddings):
 874            embedding = embedding.unsqueeze(0)
 875            if save_features:
 876                features[z] = embedding.cpu().numpy()
 877            else:
 878                features.append(embedding.unsqueeze(0))
 879            pbar_update(1)
 880
 881    if save_features:
 882        _write_embedding_signature(
 883            f, input_, predictor, tile_shape=None, halo=None,
 884            input_size=input_sizes[-1], original_size=original_sizes[-1],
 885        )
 886    else:
 887        # Concatenate across the z axis.
 888        features = torch.cat(features).cpu().numpy()
 889
 890    image_embeddings = {"features": features, "input_size": input_sizes[-1], "original_size": original_sizes[-1]}
 891    return image_embeddings
 892
 893
 894def _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size):
 895    # Check if the features are already computed.
 896    if "input_size" in f.attrs:
 897        features = f["features"]
 898        original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
 899        image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size}
 900        return image_embeddings
 901
 902    # Otherwise compute them. Note: saving happens automatically because we
 903    # always write the features to zarr. If no save path is given we use an in-memory zarr.
 904    features = _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size)
 905    image_embeddings = {"features": features, "input_size": None, "original_size": None}
 906    return image_embeddings
 907
 908
 909def _compute_data_signature(input_):
 910    data_signature = hashlib.sha1(np.asarray(input_).tobytes()).hexdigest()
 911    return data_signature
 912
 913
 914# Create all metadata that is stored along with the embeddings.
 915def _get_embedding_signature(input_, predictor, tile_shape, halo, data_signature=None):
 916    if data_signature is None:
 917        data_signature = _compute_data_signature(input_)
 918
 919    signature = {
 920        "data_signature": data_signature,
 921        "tile_shape": tile_shape if tile_shape is None else list(tile_shape),
 922        "halo": halo if halo is None else list(halo),
 923        "model_type": predictor.model_type,
 924        "model_name": predictor.model_name,
 925        "micro_sam_version": __version__,
 926        "model_hash": getattr(predictor, "_hash", None),
 927    }
 928    return signature
 929
 930
 931# Note: the input size and orginal size are different if embeddings are tiled or not.
 932# That's why we do not include them in the main signature that is being checked
 933# (_get_embedding_signature), but just add it for serialization here.
 934def _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size, original_size):
 935    signature = _get_embedding_signature(input_, predictor, tile_shape, halo)
 936    signature.update({"input_size": input_size, "original_size": original_size})
 937    for key, val in signature.items():
 938        f.attrs[key] = val
 939
 940
 941def _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo):
 942    # We may have an empty zarr file that was already created to save the embeddings in.
 943    # In this case the embeddings will be computed and we don't need to perform any checks.
 944    if "input_size" not in f.attrs:
 945        return
 946
 947    signature = _get_embedding_signature(input_, predictor, tile_shape, halo)
 948    for key, val in signature.items():
 949        # Check whether the key is missing from the attrs or if the value is not matching.
 950        if key not in f.attrs or f.attrs[key] != val:
 951            # These keys were recently added, so we don't want to fail yet if they don't
 952            # match in order to not invalidate previous embedding files.
 953            # Instead we just raise a warning. (For the version we probably also don't want to fail
 954            # i the future since it should not invalidate the embeddings).
 955            if key in ("micro_sam_version", "model_hash", "model_name"):
 956                warnings.warn(
 957                    f"The signature for {key} in embeddings file {save_path} has a mismatch: "
 958                    f"{f.attrs.get(key)} != {val}. This key was recently added, so your embeddings are likely correct. "
 959                    "But please recompute them if model predictions don't look as expected."
 960                )
 961            else:
 962                raise RuntimeError(
 963                    f"Embeddings file {save_path} is invalid due to mismatch in {key}: "
 964                    f"{f.attrs.get(key)} != {val}. Please recompute embeddings in a new file."
 965                )
 966
 967
 968# Helper function for optional external progress bars.
 969def handle_pbar(verbose, pbar_init, pbar_update):
 970    """@private"""
 971
 972    # Noop to provide dummy functions.
 973    def noop(*args):
 974        pass
 975
 976    if verbose and pbar_init is None:  # we are verbose and don't have an external progress bar.
 977        assert pbar_update is None  # avoid inconsistent state of callbacks
 978
 979        # Create our own progress bar and callbacks
 980        pbar = tqdm()
 981
 982        def pbar_init(total, description):
 983            pbar.total = total
 984            pbar.set_description(description)
 985
 986        def pbar_update(update):
 987            pbar.update(update)
 988
 989        def pbar_close():
 990            pbar.close()
 991
 992    elif verbose and pbar_init is not None:  # external pbar -> we don't have to do anything
 993        assert pbar_update is not None
 994        pbar = None
 995        pbar_close = noop
 996
 997    else:  # we are not verbose, do nothing
 998        pbar = None
 999        pbar_init, pbar_update, pbar_close = noop, noop, noop
1000
1001    return pbar, pbar_init, pbar_update, pbar_close
1002
1003
1004def precompute_image_embeddings(
1005    predictor: SamPredictor,
1006    input_: np.ndarray,
1007    save_path: Optional[Union[str, os.PathLike]] = None,
1008    lazy_loading: bool = False,
1009    ndim: Optional[int] = None,
1010    tile_shape: Optional[Tuple[int, int]] = None,
1011    halo: Optional[Tuple[int, int]] = None,
1012    verbose: bool = True,
1013    batch_size: int = 1,
1014    pbar_init: Optional[callable] = None,
1015    pbar_update: Optional[callable] = None,
1016) -> ImageEmbeddings:
1017    """Compute the image embeddings (output of the encoder) for the input.
1018
1019    If 'save_path' is given the embeddings will be loaded/saved in a zarr container.
1020
1021    Args:
1022        predictor: The Segment Anything predictor.
1023        input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
1024        save_path: Path to save the embeddings in a zarr container.
1025            By default, set to 'None', i.e. the computed embeddings will not be stored locally.
1026        lazy_loading: Whether to load all embeddings into memory or return an
1027            object to load them on demand when required. This only has an effect if 'save_path' is given
1028            and if the input is 3 dimensional. By default, set to 'False'.
1029        ndim: The dimensionality of the data. If not given will be deduced from the input data.
1030            By default, set to 'None', i.e. will be computed from the provided `input_`.
1031        tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
1032        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
1033        verbose: Whether to be verbose in the computation. By default, set to 'True'.
1034        batch_size: The batch size for precomputing image embeddings over tiles (or planes). By default, set to '1'.
1035        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
1036            Can be used together with pbar_update to handle napari progress bar in other thread.
1037            To enables using this function within a threadworker.
1038        pbar_update: Callback to update an external progress bar.
1039
1040    Returns:
1041        The image embeddings.
1042    """
1043    ndim = input_.ndim if ndim is None else ndim
1044
1045    # Handle the embedding save_path.
1046    # We don't have a save path, open in memory zarr file to hold tiled embeddings.
1047    if save_path is None:
1048        f = zarr.group()
1049
1050    # We have a save path and it already exists. Embeddings will be loaded from it,
1051    # check that the saved embeddings in there match the parameters of the function call.
1052    elif os.path.exists(save_path):
1053        f = zarr.open(save_path, mode="a")
1054        _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo)
1055
1056    # We have a save path and it does not exist yet. Create the zarr file to which the
1057    # embeddings will then be saved.
1058    else:
1059        f = zarr.open(save_path, mode="a")
1060
1061    _, pbar_init, pbar_update, pbar_close = handle_pbar(verbose, pbar_init, pbar_update)
1062
1063    if ndim == 2 and tile_shape is None:
1064        embeddings = _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update)
1065    elif ndim == 2 and tile_shape is not None:
1066        embeddings = _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size)
1067    elif ndim == 3 and tile_shape is None:
1068        embeddings = _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update, batch_size)
1069    elif ndim == 3 and tile_shape is not None:
1070        embeddings = _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size)
1071    else:
1072        raise ValueError(f"Invalid dimesionality {input_.ndim}, expect 2 or 3 dim data.")
1073
1074    pbar_close()
1075    return embeddings
1076
1077
1078def set_precomputed(
1079    predictor: SamPredictor, image_embeddings: ImageEmbeddings, i: Optional[int] = None, tile_id: Optional[int] = None,
1080) -> SamPredictor:
1081    """Set the precomputed image embeddings for a predictor.
1082
1083    Args:
1084        predictor: The Segment Anything predictor.
1085        image_embeddings: The precomputed image embeddings computed by `precompute_image_embeddings`.
1086        i: Index for the image data. Required if `image` has three spatial dimensions
1087            or a time dimension and two spatial dimensions.
1088        tile_id: Index for the tile. This is required if the embeddings are tiled.
1089
1090    Returns:
1091        The predictor with set features.
1092    """
1093    if tile_id is not None:
1094        tile_features = image_embeddings["features"][str(tile_id)]
1095        tile_image_embeddings = {
1096            "features": tile_features,
1097            "input_size": tile_features.attrs["input_size"],
1098            "original_size": tile_features.attrs["original_size"]
1099        }
1100        return set_precomputed(predictor, tile_image_embeddings, i=i)
1101
1102    device = predictor.device
1103    features = image_embeddings["features"]
1104    assert features.ndim in (4, 5), f"{features.ndim}"
1105    if features.ndim == 5 and i is None:
1106        raise ValueError("The data is 3D so an index i is needed.")
1107    elif features.ndim == 4 and i is not None:
1108        raise ValueError("The data is 2D so an index is not needed.")
1109
1110    if i is None:
1111        predictor.features = features.to(device) if torch.is_tensor(features) else \
1112            torch.from_numpy(features[:]).to(device)
1113    else:
1114        predictor.features = features[i].to(device) if torch.is_tensor(features) else \
1115            torch.from_numpy(features[i]).to(device)
1116
1117    predictor.original_size = image_embeddings["original_size"]
1118    predictor.input_size = image_embeddings["input_size"]
1119    predictor.is_image_set = True
1120
1121    return predictor
1122
1123
1124#
1125# Misc functionality
1126#
1127
1128
1129def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
1130    """Compute the intersection over union of two masks.
1131
1132    Args:
1133        mask1: The first mask.
1134        mask2: The second mask.
1135
1136    Returns:
1137        The intersection over union of the two masks.
1138    """
1139    overlap = np.logical_and(mask1 == 1, mask2 == 1).sum()
1140    union = np.logical_or(mask1 == 1, mask2 == 1).sum()
1141    eps = 1e-7
1142    iou = float(overlap) / (float(union) + eps)
1143    return iou
1144
1145
1146def get_centers_and_bounding_boxes(
1147    segmentation: np.ndarray, mode: str = "v"
1148) -> Tuple[Dict[int, np.ndarray], Dict[int, tuple]]:
1149    """Returns the center coordinates of the foreground instances in the ground-truth.
1150
1151    Args:
1152        segmentation: The segmentation.
1153        mode: Determines the functionality used for computing the centers.
1154            If 'v', the object's eccentricity centers computed by vigra are used.
1155            If 'p' the object's centroids computed by skimage are used.
1156
1157    Returns:
1158        A dictionary that maps object ids to the corresponding centroid.
1159        A dictionary that maps object_ids to the corresponding bounding box.
1160    """
1161    assert mode in ["p", "v"], "Choose either 'p' for regionprops or 'v' for vigra"
1162
1163    properties = regionprops(segmentation)
1164
1165    if mode == "p":
1166        center_coordinates = {prop.label: prop.centroid for prop in properties}
1167    elif mode == "v":
1168        center_coordinates = vigra.filters.eccentricityCenters(segmentation.astype('float32'))
1169        center_coordinates = {i: coord for i, coord in enumerate(center_coordinates) if i > 0}
1170
1171    bbox_coordinates = {prop.label: prop.bbox for prop in properties}
1172
1173    assert len(bbox_coordinates) == len(center_coordinates), f"{len(bbox_coordinates)}, {len(center_coordinates)}"
1174    return center_coordinates, bbox_coordinates
1175
1176
1177def load_image_data(path: str, key: Optional[str] = None, lazy_loading: bool = False) -> np.ndarray:
1178    """Helper function to load image data from file.
1179
1180    Args:
1181        path: The filepath to the image data.
1182        key: The internal filepath for complex data formats like hdf5.
1183        lazy_loading: Whether to lazyly load data. Only supported for n5 and zarr data.
1184
1185    Returns:
1186        The image data.
1187    """
1188    if key is None:
1189        image_data = imageio.imread(path)
1190    else:
1191        with open_file(path, mode="r") as f:
1192            image_data = f[key]
1193            if not lazy_loading:
1194                image_data = image_data[:]
1195
1196    return image_data
1197
1198
1199def segmentation_to_one_hot(segmentation: np.ndarray, segmentation_ids: Optional[np.ndarray] = None) -> torch.Tensor:
1200    """Convert the segmentation to one-hot encoded masks.
1201
1202    Args:
1203        segmentation: The segmentation.
1204        segmentation_ids: Optional subset of ids that will be used to subsample the masks.
1205            By default, computes the number of ids from the provided `segmentation` masks.
1206
1207    Returns:
1208        The one-hot encoded masks.
1209    """
1210    masks = segmentation.copy()
1211    if segmentation_ids is None:
1212        n_ids = int(segmentation.max())
1213
1214    else:
1215        msg = "No foreground objects were found."
1216        if len(segmentation_ids) == 0:  # The list should not be completely empty.
1217            raise RuntimeError(msg)
1218
1219        if 0 in segmentation_ids:  # The list should not have 'zero' as a value.
1220            raise RuntimeError(msg)
1221
1222        # the segmentation ids have to be sorted
1223        segmentation_ids = np.sort(segmentation_ids)
1224
1225        # set the non selected objects to zero and relabel sequentially
1226        masks[~np.isin(masks, segmentation_ids)] = 0
1227        masks = relabel_sequential(masks)[0]
1228        n_ids = len(segmentation_ids)
1229
1230    masks = torch.from_numpy(masks)
1231
1232    one_hot_shape = (n_ids + 1,) + masks.shape
1233    masks = masks.unsqueeze(0)  # add dimension to scatter
1234    masks = torch.zeros(one_hot_shape).scatter_(0, masks, 1)[1:]
1235
1236    # add the extra singleton dimenion to get shape NUM_OBJECTS x 1 x H x W
1237    masks = masks.unsqueeze(1)
1238    return masks
1239
1240
1241def get_block_shape(shape: Tuple[int]) -> Tuple[int]:
1242    """Get a suitable block shape for chunking a given shape.
1243
1244    The primary use for this is determining chunk sizes for
1245    zarr arrays or block shapes for parallelization.
1246
1247    Args:
1248        shape: The image or volume shape.
1249
1250    Returns:
1251        The block shape.
1252    """
1253    ndim = len(shape)
1254    if ndim == 2:
1255        block_shape = tuple(min(bs, sh) for bs, sh in zip((1024, 1024), shape))
1256    elif ndim == 3:
1257        block_shape = tuple(min(bs, sh) for bs, sh in zip((32, 256, 256), shape))
1258    else:
1259        raise ValueError(f"Only 2 or 3 dimensional shapes are supported, got {ndim}D.")
1260
1261    return block_shape
1262
1263
1264def micro_sam_info():
1265    """Display μSAM information using a rich console."""
1266    import psutil
1267    import platform
1268    import argparse
1269    from rich import progress
1270    from rich.panel import Panel
1271    from rich.table import Table
1272    from rich.console import Console
1273
1274    import torch
1275    import micro_sam
1276
1277    parser = argparse.ArgumentParser(description="μSAM Information Booth")
1278    parser.add_argument(
1279        "--download", nargs="+", metavar=("WHAT", "KIND"),
1280        help="Downloads the pretrained SAM models."
1281        "'--download models' -> downloads all pretrained models; "
1282        "'--download models vit_b_lm vit_b_em_organelles' -> downloads the listed models; "
1283        "'--download model/models vit_b_lm' -> downloads a single specified model."
1284    )
1285    args = parser.parse_args()
1286
1287    # Open up a new console.
1288    console = Console()
1289
1290    # The header for information CLI.
1291    console.print("[bold #0072B2]μSAM Information Booth[/bold #0072B2]", justify="center")
1292    console.print("-" * console.width)
1293
1294    # μSAM version panel.
1295    console.print(
1296        Panel(f"[bold #F0E442]Version:[/bold #F0E442] {micro_sam.__version__}", title="μSAM Version", expand=True)
1297    )
1298
1299    # The documentation link panel.
1300    console.print(
1301        Panel(
1302            "[bold #CC79A7]Tools documented at:[/bold #CC79A7]\n"
1303            "https://computational-cell-analytics.github.io/micro-sam", title="Documentation"
1304        )
1305    )
1306
1307    # The publication panel.
1308    console.print(
1309        Panel(
1310            "[bold #E69F00]Published in Nature Methods:[/bold #E69F00]\n"
1311            "https://www.nature.com/articles/s41592-024-02580-4", title="Publication"
1312        )
1313    )
1314
1315    # The cache directory panel.
1316    console.print(
1317        Panel(f"[bold #009E73]Cache Directory:[/bold #009E73]\n{get_cache_directory()}", title="Cache Directory")
1318    )
1319
1320    # The available models panel.
1321    available_models = list(get_model_names())
1322    # We filter out the decoder models.
1323    available_models = [m for m in available_models if not m.endswith("_decoder")]
1324    model_list = "\n".join(available_models)
1325    console.print(
1326        Panel(f"[bold #D55E00]Available Models:[/bold #D55E00]\n{model_list}", title="List of Supported Models")
1327    )
1328
1329    # The system information table.
1330    total_memory = psutil.virtual_memory().total / (1024 ** 3)
1331    table = Table(title="System Information", show_header=True, header_style="bold #0072B2", expand=True)
1332    table.add_column("Property")
1333    table.add_column("Value", style="bold #56B4E9")
1334    table.add_row("System", platform.system())
1335    table.add_row("Node Name", platform.node())
1336    table.add_row("Release", platform.release())
1337    table.add_row("Version", platform.version())
1338    table.add_row("Machine", platform.machine())
1339    table.add_row("Processor", platform.processor())
1340    table.add_row("Platform", platform.platform())
1341    table.add_row("Total RAM (GB)", f"{total_memory:.2f}")
1342    console.print(table)
1343
1344    # The device information and check for available GPU acceleration.
1345    default_device = _get_default_device()
1346
1347    if default_device == "cuda":
1348        device_index = torch.cuda.current_device()
1349        device_name = torch.cuda.get_device_name(device_index)
1350        console.print(Panel(f"[bold #000000]CUDA Device:[/bold #000000] {device_name}", title="GPU Information"))
1351    elif default_device == "mps":
1352        console.print(Panel("[bold #000000]MPS Device is available[/bold #000000]", title="GPU Information"))
1353    else:
1354        console.print(
1355            Panel(
1356                "[bold #000000]No GPU acceleration device detected. Running on CPU.[/bold #000000]",
1357                title="Device Information"
1358            )
1359        )
1360
1361    # The section allowing to download models.
1362    # NOTE: In future, can be extended to download sample data.
1363    if args.download:
1364        download_provided_args = [t.lower() for t in args.download]
1365        mode, *model_types = download_provided_args
1366
1367        if mode not in {"models", "model"}:
1368            console.print(f"[red]Unknown option for --download: {mode}[/]")
1369            return
1370
1371        if mode in ["model", "models"] and not model_types:  # If user did not specify, we will download all models.
1372            download_list = available_models
1373        else:
1374            download_list = model_types
1375            incorrect_models = [m for m in download_list if m not in available_models]
1376            if incorrect_models:
1377                console.print(Panel("[red]Unknown model(s):[/] " + ", ".join(incorrect_models), title="Download Error"))
1378                return
1379
1380        with progress.Progress(
1381            progress.SpinnerColumn(),
1382            progress.TextColumn("[progress.description]{task.description}"),
1383            progress.BarColumn(bar_width=None),
1384            "[progress.percentage]{task.percentage:>3.0f}%",
1385            progress.TimeRemainingColumn(),
1386            console=console,
1387        ) as prog:
1388            task = prog.add_task("[green]Downloading μSAM models…", total=len(download_list))
1389            for model_type in download_list:
1390                prog.update(task, description=f"Downloading [cyan]{model_type}[/]…")
1391                _download_sam_model(model_type=model_type)
1392                prog.advance(task)
1393
1394        console.print(Panel("[bold green] Downloads complete![/]", title="Finished"))
def get_cache_directory() -> None:
60def get_cache_directory() -> None:
61    """Get micro-sam cache directory location.
62
63    Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.
64    """
65    default_cache_directory = os.path.expanduser(pooch.os_cache("micro_sam"))
66    cache_directory = Path(os.environ.get("MICROSAM_CACHEDIR", default_cache_directory))
67    return cache_directory

Get micro-sam cache directory location.

Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.

def microsam_cachedir() -> None:
75def microsam_cachedir() -> None:
76    """Return the micro-sam cache directory.
77
78    Returns the top level cache directory for micro-sam models and sample data.
79
80    Every time this function is called, we check for any user updates made to
81    the MICROSAM_CACHEDIR os environment variable since the last time.
82    """
83    cache_directory = os.environ.get("MICROSAM_CACHEDIR") or pooch.os_cache("micro_sam")
84    return cache_directory

Return the micro-sam cache directory.

Returns the top level cache directory for micro-sam models and sample data.

Every time this function is called, we check for any user updates made to the MICROSAM_CACHEDIR os environment variable since the last time.

def models():
 87def models():
 88    """Return the segmentation models registry.
 89
 90    We recreate the model registry every time this function is called,
 91    so any user changes to the default micro-sam cache directory location
 92    are respected.
 93    """
 94
 95    # We use xxhash to compute the hash of the models, see
 96    # https://github.com/computational-cell-analytics/micro-sam/issues/283
 97    # (It is now a dependency, so we don't provide the sha256 fallback anymore.)
 98    # To generate the xxh128 hash:
 99    #     xxh128sum filename
100    encoder_registry = {
101        # The default segment anything models:
102        "vit_l": "xxh128:a82beb3c660661e3dd38d999cc860e9a",
103        "vit_h": "xxh128:97698fac30bd929c2e6d8d8cc15933c2",
104        "vit_b": "xxh128:6923c33df3637b6a922d7682bfc9a86b",
105        # The model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM.
106        "vit_t": "xxh128:8eadbc88aeb9d8c7e0b4b60c3db48bd0",
107        # The current version of our models in the modelzoo.
108        # LM generalist models:
109        "vit_l_lm": "xxh128:017f20677997d628426dec80a8018f9d",
110        "vit_b_lm": "xxh128:fe9252a29f3f4ea53c15a06de471e186",
111        "vit_t_lm": "xxh128:72ec5074774761a6e5c05a08942f981e",
112        # EM models:
113        "vit_l_em_organelles": "xxh128:810b084b6e51acdbf760a993d8619f2d",
114        "vit_b_em_organelles": "xxh128:f3bf2ed83d691456bae2c3f9a05fb438",
115        "vit_t_em_organelles": "xxh128:253474720c497cce605e57c9b1d18fd9",
116        # Histopathology models:
117        "vit_b_histopathology": "xxh128:ffd1a2cd84570458b257bd95fdd8f974",
118        "vit_l_histopathology": "xxh128:b591833c89754271023e901281dee3f2",
119        "vit_h_histopathology": "xxh128:bd1856dafc156a43fb3aa705f1a6e92e",
120        # Medical Imaging models:
121        "vit_b_medical_imaging": "xxh128:5be672f1458263a9edc9fd40d7f56ac1",
122    }
123    # Additional decoders for instance segmentation.
124    decoder_registry = {
125        # LM generalist models:
126        "vit_l_lm_decoder": "xxh128:2faeafa03819dfe03e7c46a44aaac64a",
127        "vit_b_lm_decoder": "xxh128:708b15ac620e235f90bb38612c4929ba",
128        "vit_t_lm_decoder": "xxh128:3e914a5f397b0312cdd36813031f8823",
129        # EM models:
130        "vit_l_em_organelles_decoder": "xxh128:334877640bfdaaabce533e3252a17294",
131        "vit_b_em_organelles_decoder": "xxh128:bb6398956a6b0132c26b631c14f95ce2",
132        "vit_t_em_organelles_decoder": "xxh128:8f897c7bb93174a4d1638827c4dd6f44",
133        # Histopathology models:
134        "vit_b_histopathology_decoder": "xxh128:6a66194dcb6e36199cbee2214ecf7213",
135        "vit_l_histopathology_decoder": "xxh128:46aab7765d4400e039772d5a50b55c04",
136        "vit_h_histopathology_decoder": "xxh128:3ed9f87e46ad5e16935bd8d722c8dc47",
137    }
138    registry = {**encoder_registry, **decoder_registry}
139
140    encoder_urls = {
141        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
142        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
143        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
144        "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download",
145        "vit_l_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.2/files/vit_l.pt",
146        "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.2/files/vit_b.pt",
147        "vit_t_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t.pt",
148        "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1.2/files/vit_l.pt",  # noqa
149        "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1.2/files/vit_b.pt",  # noqa
150        "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt",  # noqa
151        "vit_b_histopathology": "https://owncloud.gwdg.de/index.php/s/sBB4H8CTmIoBZsQ/download",
152        "vit_l_histopathology": "https://owncloud.gwdg.de/index.php/s/IZgnn1cpBq2PHod/download",
153        "vit_h_histopathology": "https://owncloud.gwdg.de/index.php/s/L7AcvVz7DoWJ2RZ/download",
154        "vit_b_medical_imaging": "https://owncloud.gwdg.de/index.php/s/AB69HGhj8wuozXQ/download",
155    }
156
157    decoder_urls = {
158        "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.2/files/vit_l_decoder.pt",  # noqa
159        "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.2/files/vit_b_decoder.pt",  # noqa
160        "vit_t_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t_decoder.pt",  # noqa
161        "vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1.2/files/vit_l_decoder.pt",  # noqa
162        "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1.2/files/vit_b_decoder.pt",  # noqa
163        "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt",  # noqa
164        "vit_b_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/KO9AWqynI7SFOBj/download",
165        "vit_l_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/oIs6VSmkOp7XrKF/download",
166        "vit_h_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/1qAKxy5H0jgwZvM/download",
167    }
168    urls = {**encoder_urls, **decoder_urls}
169
170    models = pooch.create(
171        path=os.path.join(microsam_cachedir(), "models"),
172        base_url="",
173        registry=registry,
174        urls=urls,
175    )
176    return models

Return the segmentation models registry.

We recreate the model registry every time this function is called, so any user changes to the default micro-sam cache directory location are respected.

def get_device( device: Union[str, torch.device, NoneType] = None) -> Union[str, torch.device]:
198def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
199    """Get the torch device.
200
201    If no device is passed the default device for your system is used.
202    Else it will be checked if the device you have passed is supported.
203
204    Args:
205        device: The input device. By default, selects the best available device supports.
206
207    Returns:
208        The device.
209    """
210    if device is None or device == "auto":
211        device = _get_default_device()
212    else:
213        device_type = device if isinstance(device, str) else device.type
214        if device_type.lower() == "cuda":
215            if not torch.cuda.is_available():
216                raise RuntimeError("PyTorch CUDA backend is not available.")
217        elif device_type.lower() == "mps":
218            if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
219                raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
220        elif device_type.lower() == "cpu":
221            pass  # cpu is always available
222        else:
223            raise RuntimeError(f"Unsupported device: '{device}'. Please choose from 'cpu', 'cuda', or 'mps'.")
224
225    return device

Get the torch device.

If no device is passed the default device for your system is used. Else it will be checked if the device you have passed is supported.

Arguments:
  • device: The input device. By default, selects the best available device supports.
Returns:

The device.

def get_sam_model( model_type: str = 'vit_b_lm', device: Union[str, torch.device, NoneType] = None, checkpoint_path: Union[str, os.PathLike, NoneType] = None, return_sam: bool = False, return_state: bool = False, peft_kwargs: Optional[Dict] = None, flexible_load_checkpoint: bool = False, progress_bar_factory: Optional[Callable] = None, **model_kwargs) -> segment_anything.predictor.SamPredictor:
312def get_sam_model(
313    model_type: str = _DEFAULT_MODEL,
314    device: Optional[Union[str, torch.device]] = None,
315    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
316    return_sam: bool = False,
317    return_state: bool = False,
318    peft_kwargs: Optional[Dict] = None,
319    flexible_load_checkpoint: bool = False,
320    progress_bar_factory: Optional[Callable] = None,
321    **model_kwargs,
322) -> SamPredictor:
323    r"""Get the Segment Anything Predictor.
324
325    This function will download the required model or load it from the cached weight file.
326    This location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR.
327    The name of the requested model can be set via `model_type`.
328    See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models
329    for an overview of the available models
330
331    Alternatively this function can also load a model from weights stored in a local filepath.
332    The corresponding file path is given via `checkpoint_path`. In this case `model_type`
333    must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for
334    a SAM model with vit_b encoder.
335
336    By default the models are downloaded to a folder named 'micro_sam/models'
337    inside your default cache directory, eg:
338    * Mac: ~/Library/Caches/<AppName>
339    * Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined.
340    * Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache
341    See the pooch.os_cache() documentation for more details:
342    https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html
343
344    Args:
345        model_type: The Segment Anything model to use. Will use the 'vit_b_lm' model by default.
346            To get a list of all available model names you can call `micro_sam.util.get_model_names`.
347        device: The device for the model. If 'None' is provided, will use GPU if available.
348        checkpoint_path: The path to a file with weights that should be used instead of using the
349            weights corresponding to `model_type`. If given, `model_type` must match the architecture
350            corresponding to the weight file. e.g. if you use weights for SAM with `vit_b` encoder
351            then `model_type` must be given as 'vit_b'.
352        return_sam: Return the sam model object as well as the predictor. By default, set to 'False'.
353        return_state: Return the unpickled checkpoint state. By default, set to 'False'.
354        peft_kwargs: Keyword arguments for th PEFT wrapper class.
355            If passed 'None', it does not initialize any parameter efficient finetuning.
356        flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
357            By default, set to 'False'.
358        progress_bar_factory: A function to create a progress bar for the model download.
359        model_kwargs: Additional parameters necessary to initialize the Segment Anything model.
360
361    Returns:
362        The Segment Anything predictor.
363    """
364    device = get_device(device)
365
366    # We support passing a local filepath to a checkpoint.
367    # In this case we do not download any weights but just use the local weight file,
368    # as it is, without copying it over anywhere or checking it's hashes.
369
370    # checkpoint_path has not been passed, we download a known model and derive the correct
371    # URL from the model_type. If the model_type is invalid pooch will raise an error.
372    _provided_checkpoint_path = checkpoint_path is not None
373    if checkpoint_path is None:
374        checkpoint_path, model_hash, decoder_path = _download_sam_model(model_type, progress_bar_factory)
375
376    # checkpoint_path has been passed, we use it instead of downloading a model.
377    else:
378        # Check if the file exists and raise an error otherwise.
379        # We can't check any hashes here, and we don't check if the file is actually a valid weight file.
380        # (If it isn't the model creation will fail below.)
381        if not os.path.exists(checkpoint_path):
382            raise ValueError(f"Checkpoint at '{checkpoint_path}' could not be found.")
383        model_hash = _compute_hash(checkpoint_path)
384        decoder_path = None
385
386    # Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped
387    # before calling sam_model_registry.
388    abbreviated_model_type = model_type[:5]
389    if abbreviated_model_type not in _MODEL_TYPES:
390        raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}")
391    if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT:
392        raise RuntimeError(
393            "'mobile_sam' is required for the vit-tiny. "
394            "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'"
395        )
396
397    state, model_state = _load_checkpoint(checkpoint_path)
398
399    if _provided_checkpoint_path:
400        # To get the model weights, we prioritize having the correct 'checkpoint_path' over 'model_type'
401        # It is done to avoid strange parameter mismatch issues while incompatible model type and weights combination.
402        from micro_sam.models.build_sam import _validate_model_type
403        _provided_model_type = _validate_model_type(model_state)
404
405        # Verify whether the 'abbreviated_model_type' matches the '_provided_model_type'
406        # Otherwise replace 'abbreviated_model_type' with the later.
407        if abbreviated_model_type != _provided_model_type:
408            # Printing the message below to avoid any filtering of warnings on user's end.
409            print(
410                f"CRITICAL WARNING: The chosen 'model_type' is '{abbreviated_model_type}', "
411                f"however the model checkpoint provided correspond to '{_provided_model_type}', which does not match. "
412                f"We internally switch the model type to the expected value, i.e. '{_provided_model_type}'. "
413                "However, please avoid mismatching combination of 'model_type' and 'checkpoint_path' in future."
414            )
415
416        # Replace the extracted 'abbreviated_model_type' subjected to the model weights.
417        abbreviated_model_type = _provided_model_type
418
419    # Whether to update parameters necessary to initialize the model
420    if model_kwargs:  # Checks whether model_kwargs have been provided or not
421        if abbreviated_model_type == "vit_t":
422            raise ValueError("'micro-sam' does not support changing the model parameters for 'mobile-sam'.")
423        sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs)
424
425    else:
426        sam = sam_model_registry[abbreviated_model_type]()
427
428    # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
429    # Overwrites the SAM model by freezing the backbone and allow PEFT.
430    if peft_kwargs and isinstance(peft_kwargs, dict):
431        # NOTE: We bump out 'quantize' parameter, if found, as we do not quantize in inference.
432        peft_kwargs.pop("quantize", None)
433
434        if abbreviated_model_type == "vit_t":
435            raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")
436
437        sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam
438    # In case the model checkpoints have some issues when it is initialized with different parameters than default.
439    if flexible_load_checkpoint:
440        sam = _handle_checkpoint_loading(sam, model_state)
441    else:
442        sam.load_state_dict(model_state)
443    sam.to(device=device)
444
445    predictor = SamPredictor(sam)
446    predictor.model_type = abbreviated_model_type
447    predictor._hash = model_hash
448    predictor.model_name = model_type
449    predictor.checkpoint_path = checkpoint_path
450
451    # Add the decoder to the state if we have one and if the state is returned.
452    if decoder_path is not None and return_state:
453        state["decoder_state"] = torch.load(decoder_path, map_location=device, weights_only=False)
454
455    if return_sam and return_state:
456        return predictor, sam, state
457    if return_sam:
458        return predictor, sam
459    if return_state:
460        return predictor, state
461    return predictor

Get the Segment Anything Predictor.

This function will download the required model or load it from the cached weight file. This location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR. The name of the requested model can be set via model_type. See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models for an overview of the available models

Alternatively this function can also load a model from weights stored in a local filepath. The corresponding file path is given via checkpoint_path. In this case model_type must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for a SAM model with vit_b encoder.

By default the models are downloaded to a folder named 'micro_sam/models' inside your default cache directory, eg:

Arguments:
  • model_type: The Segment Anything model to use. Will use the 'vit_b_lm' model by default. To get a list of all available model names you can call get_model_names.
  • device: The device for the model. If 'None' is provided, will use GPU if available.
  • checkpoint_path: The path to a file with weights that should be used instead of using the weights corresponding to model_type. If given, model_type must match the architecture corresponding to the weight file. e.g. if you use weights for SAM with vit_b encoder then model_type must be given as 'vit_b'.
  • return_sam: Return the sam model object as well as the predictor. By default, set to 'False'.
  • return_state: Return the unpickled checkpoint state. By default, set to 'False'.
  • peft_kwargs: Keyword arguments for th PEFT wrapper class. If passed 'None', it does not initialize any parameter efficient finetuning.
  • flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. By default, set to 'False'.
  • progress_bar_factory: A function to create a progress bar for the model download.
  • model_kwargs: Additional parameters necessary to initialize the Segment Anything model.
Returns:

The Segment Anything predictor.

def export_custom_sam_model( checkpoint_path: Union[str, os.PathLike], model_type: str, save_path: Union[str, os.PathLike], with_segmentation_decoder: bool = False, prefix: str = 'sam.') -> None:
497def export_custom_sam_model(
498    checkpoint_path: Union[str, os.PathLike],
499    model_type: str,
500    save_path: Union[str, os.PathLike],
501    with_segmentation_decoder: bool = False,
502    prefix: str = "sam.",
503) -> None:
504    """Export a finetuned Segment Anything Model to the standard model format.
505
506    The exported model can be used by the interactive annotation tools in `micro_sam.annotator`.
507
508    Args:
509        checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
510        model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
511        save_path: Where to save the exported model.
512        with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well.
513            If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'.
514        prefix: The prefix to remove from the model parameter keys.
515    """
516    state, model_state = _load_checkpoint(checkpoint_path=checkpoint_path)
517    model_state = OrderedDict([(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()])
518
519    # Store the 'decoder_state' as well, if desired.
520    if with_segmentation_decoder:
521        if "decoder_state" not in state:
522            raise RuntimeError(f"'decoder_state' is not found in the model at '{checkpoint_path}'.")
523        decoder_state = state["decoder_state"]
524        save_state = {"model_state": model_state, "decoder_state": decoder_state}
525    else:
526        save_state = model_state
527
528    torch.save(save_state, save_path)

Export a finetuned Segment Anything Model to the standard model format.

The exported model can be used by the interactive annotation tools in micro_sam.annotator.

Arguments:
  • checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
  • model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
  • save_path: Where to save the exported model.
  • with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well. If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'.
  • prefix: The prefix to remove from the model parameter keys.
def export_custom_qlora_model( checkpoint_path: Union[str, os.PathLike, NoneType], finetuned_path: Union[str, os.PathLike], model_type: str, save_path: Union[str, os.PathLike]) -> None:
531def export_custom_qlora_model(
532    checkpoint_path: Optional[Union[str, os.PathLike]],
533    finetuned_path: Union[str, os.PathLike],
534    model_type: str,
535    save_path: Union[str, os.PathLike],
536) -> None:
537    """Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format.
538
539    The exported model can be used with the LoRA backbone by passing the relevant `peft_kwargs` to `get_sam_model`.
540
541    Args:
542        checkpoint_path: The path to the base foundation model from which the new model has been finetuned.
543        finetuned_path: The path to the new finetuned model, using QLoRA.
544        model_type: The Segment Anything Model type corresponding to the checkpoint.
545        save_path: Where to save the exported model.
546    """
547    # Step 1: Get the base SAM model: used to start finetuning from.
548    _, sam = get_sam_model(
549        model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True,
550    )
551
552    # Step 2: Load the QLoRA-style finetuned model.
553    ft_state, ft_model_state = _load_checkpoint(finetuned_path)
554
555    # Step 3: Identify LoRA layers from QLoRA model.
556    # - differentiate between LoRA applied to the attention matrices and LoRA applied to the MLP layers.
557    # - then copy the LoRA layers from the QLoRA model to the new state dict
558    updated_model_state = {}
559
560    modified_attn_layers = set()
561    modified_mlp_layers = set()
562
563    for k, v in ft_model_state.items():
564        if "blocks." in k:
565            layer_id = int(k.split("blocks.")[1].split(".")[0])
566        if k.find("qkv.w_a_linear") != -1 or k.find("qkv.w_b_linear") != -1:
567            modified_attn_layers.add(layer_id)
568            updated_model_state[k] = v
569        if k.find("mlp.w_a_linear") != -1 or k.find("mlp.w_b_linear") != -1:
570            modified_mlp_layers.add(layer_id)
571            updated_model_state[k] = v
572
573    # Step 4: Next, we get all the remaining parameters from the base SAM model.
574    for k, v in sam.state_dict().items():
575        if "blocks." in k:
576            layer_id = int(k.split("blocks.")[1].split(".")[0])
577        if k.find("attn.qkv.") != -1:
578            if layer_id in modified_attn_layers:  # We have LoRA in QKV layers, so we need to modify the key
579                k = k.replace("qkv", "qkv.qkv_proj")
580        elif k.find("mlp") != -1 and k.find("image_encoder") != -1:
581            if layer_id in modified_mlp_layers:  # We have LoRA in MLP layers, so we need to modify the key
582                k = k.replace("mlp.", "mlp.mlp_layer.")
583        updated_model_state[k] = v
584
585    # Step 5: Finally, we replace the old model state with the new one (to retain other relevant stuff)
586    ft_state['model_state'] = updated_model_state
587
588    # Step 6: Store the new "state" to "save_path"
589    torch.save(ft_state, save_path)

Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format.

The exported model can be used with the LoRA backbone by passing the relevant peft_kwargs to get_sam_model.

Arguments:
  • checkpoint_path: The path to the base foundation model from which the new model has been finetuned.
  • finetuned_path: The path to the new finetuned model, using QLoRA.
  • model_type: The Segment Anything Model type corresponding to the checkpoint.
  • save_path: Where to save the exported model.
def get_model_names() -> Iterable:
592def get_model_names() -> Iterable:
593    model_registry = models()
594    model_names = model_registry.registry.keys()
595    return model_names
def precompute_image_embeddings( predictor: segment_anything.predictor.SamPredictor, input_: numpy.ndarray, save_path: Union[str, os.PathLike, NoneType] = None, lazy_loading: bool = False, ndim: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = True, batch_size: int = 1, pbar_init: Optional[<built-in function callable>] = None, pbar_update: Optional[<built-in function callable>] = None) -> Dict[str, Any]:
1005def precompute_image_embeddings(
1006    predictor: SamPredictor,
1007    input_: np.ndarray,
1008    save_path: Optional[Union[str, os.PathLike]] = None,
1009    lazy_loading: bool = False,
1010    ndim: Optional[int] = None,
1011    tile_shape: Optional[Tuple[int, int]] = None,
1012    halo: Optional[Tuple[int, int]] = None,
1013    verbose: bool = True,
1014    batch_size: int = 1,
1015    pbar_init: Optional[callable] = None,
1016    pbar_update: Optional[callable] = None,
1017) -> ImageEmbeddings:
1018    """Compute the image embeddings (output of the encoder) for the input.
1019
1020    If 'save_path' is given the embeddings will be loaded/saved in a zarr container.
1021
1022    Args:
1023        predictor: The Segment Anything predictor.
1024        input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
1025        save_path: Path to save the embeddings in a zarr container.
1026            By default, set to 'None', i.e. the computed embeddings will not be stored locally.
1027        lazy_loading: Whether to load all embeddings into memory or return an
1028            object to load them on demand when required. This only has an effect if 'save_path' is given
1029            and if the input is 3 dimensional. By default, set to 'False'.
1030        ndim: The dimensionality of the data. If not given will be deduced from the input data.
1031            By default, set to 'None', i.e. will be computed from the provided `input_`.
1032        tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
1033        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
1034        verbose: Whether to be verbose in the computation. By default, set to 'True'.
1035        batch_size: The batch size for precomputing image embeddings over tiles (or planes). By default, set to '1'.
1036        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
1037            Can be used together with pbar_update to handle napari progress bar in other thread.
1038            To enables using this function within a threadworker.
1039        pbar_update: Callback to update an external progress bar.
1040
1041    Returns:
1042        The image embeddings.
1043    """
1044    ndim = input_.ndim if ndim is None else ndim
1045
1046    # Handle the embedding save_path.
1047    # We don't have a save path, open in memory zarr file to hold tiled embeddings.
1048    if save_path is None:
1049        f = zarr.group()
1050
1051    # We have a save path and it already exists. Embeddings will be loaded from it,
1052    # check that the saved embeddings in there match the parameters of the function call.
1053    elif os.path.exists(save_path):
1054        f = zarr.open(save_path, mode="a")
1055        _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo)
1056
1057    # We have a save path and it does not exist yet. Create the zarr file to which the
1058    # embeddings will then be saved.
1059    else:
1060        f = zarr.open(save_path, mode="a")
1061
1062    _, pbar_init, pbar_update, pbar_close = handle_pbar(verbose, pbar_init, pbar_update)
1063
1064    if ndim == 2 and tile_shape is None:
1065        embeddings = _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update)
1066    elif ndim == 2 and tile_shape is not None:
1067        embeddings = _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size)
1068    elif ndim == 3 and tile_shape is None:
1069        embeddings = _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update, batch_size)
1070    elif ndim == 3 and tile_shape is not None:
1071        embeddings = _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size)
1072    else:
1073        raise ValueError(f"Invalid dimesionality {input_.ndim}, expect 2 or 3 dim data.")
1074
1075    pbar_close()
1076    return embeddings

Compute the image embeddings (output of the encoder) for the input.

If 'save_path' is given the embeddings will be loaded/saved in a zarr container.

Arguments:
  • predictor: The Segment Anything predictor.
  • input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
  • save_path: Path to save the embeddings in a zarr container. By default, set to 'None', i.e. the computed embeddings will not be stored locally.
  • lazy_loading: Whether to load all embeddings into memory or return an object to load them on demand when required. This only has an effect if 'save_path' is given and if the input is 3 dimensional. By default, set to 'False'.
  • ndim: The dimensionality of the data. If not given will be deduced from the input data. By default, set to 'None', i.e. will be computed from the provided input_.
  • tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
  • halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
  • verbose: Whether to be verbose in the computation. By default, set to 'True'.
  • batch_size: The batch size for precomputing image embeddings over tiles (or planes). By default, set to '1'.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle napari progress bar in other thread. To enables using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
Returns:

The image embeddings.

def set_precomputed( predictor: segment_anything.predictor.SamPredictor, image_embeddings: Dict[str, Any], i: Optional[int] = None, tile_id: Optional[int] = None) -> segment_anything.predictor.SamPredictor:
1079def set_precomputed(
1080    predictor: SamPredictor, image_embeddings: ImageEmbeddings, i: Optional[int] = None, tile_id: Optional[int] = None,
1081) -> SamPredictor:
1082    """Set the precomputed image embeddings for a predictor.
1083
1084    Args:
1085        predictor: The Segment Anything predictor.
1086        image_embeddings: The precomputed image embeddings computed by `precompute_image_embeddings`.
1087        i: Index for the image data. Required if `image` has three spatial dimensions
1088            or a time dimension and two spatial dimensions.
1089        tile_id: Index for the tile. This is required if the embeddings are tiled.
1090
1091    Returns:
1092        The predictor with set features.
1093    """
1094    if tile_id is not None:
1095        tile_features = image_embeddings["features"][str(tile_id)]
1096        tile_image_embeddings = {
1097            "features": tile_features,
1098            "input_size": tile_features.attrs["input_size"],
1099            "original_size": tile_features.attrs["original_size"]
1100        }
1101        return set_precomputed(predictor, tile_image_embeddings, i=i)
1102
1103    device = predictor.device
1104    features = image_embeddings["features"]
1105    assert features.ndim in (4, 5), f"{features.ndim}"
1106    if features.ndim == 5 and i is None:
1107        raise ValueError("The data is 3D so an index i is needed.")
1108    elif features.ndim == 4 and i is not None:
1109        raise ValueError("The data is 2D so an index is not needed.")
1110
1111    if i is None:
1112        predictor.features = features.to(device) if torch.is_tensor(features) else \
1113            torch.from_numpy(features[:]).to(device)
1114    else:
1115        predictor.features = features[i].to(device) if torch.is_tensor(features) else \
1116            torch.from_numpy(features[i]).to(device)
1117
1118    predictor.original_size = image_embeddings["original_size"]
1119    predictor.input_size = image_embeddings["input_size"]
1120    predictor.is_image_set = True
1121
1122    return predictor

Set the precomputed image embeddings for a predictor.

Arguments:
  • predictor: The Segment Anything predictor.
  • image_embeddings: The precomputed image embeddings computed by precompute_image_embeddings.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • tile_id: Index for the tile. This is required if the embeddings are tiled.
Returns:

The predictor with set features.

def compute_iou(mask1: numpy.ndarray, mask2: numpy.ndarray) -> float:
1130def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
1131    """Compute the intersection over union of two masks.
1132
1133    Args:
1134        mask1: The first mask.
1135        mask2: The second mask.
1136
1137    Returns:
1138        The intersection over union of the two masks.
1139    """
1140    overlap = np.logical_and(mask1 == 1, mask2 == 1).sum()
1141    union = np.logical_or(mask1 == 1, mask2 == 1).sum()
1142    eps = 1e-7
1143    iou = float(overlap) / (float(union) + eps)
1144    return iou

Compute the intersection over union of two masks.

Arguments:
  • mask1: The first mask.
  • mask2: The second mask.
Returns:

The intersection over union of the two masks.

def get_centers_and_bounding_boxes( segmentation: numpy.ndarray, mode: str = 'v') -> Tuple[Dict[int, numpy.ndarray], Dict[int, tuple]]:
1147def get_centers_and_bounding_boxes(
1148    segmentation: np.ndarray, mode: str = "v"
1149) -> Tuple[Dict[int, np.ndarray], Dict[int, tuple]]:
1150    """Returns the center coordinates of the foreground instances in the ground-truth.
1151
1152    Args:
1153        segmentation: The segmentation.
1154        mode: Determines the functionality used for computing the centers.
1155            If 'v', the object's eccentricity centers computed by vigra are used.
1156            If 'p' the object's centroids computed by skimage are used.
1157
1158    Returns:
1159        A dictionary that maps object ids to the corresponding centroid.
1160        A dictionary that maps object_ids to the corresponding bounding box.
1161    """
1162    assert mode in ["p", "v"], "Choose either 'p' for regionprops or 'v' for vigra"
1163
1164    properties = regionprops(segmentation)
1165
1166    if mode == "p":
1167        center_coordinates = {prop.label: prop.centroid for prop in properties}
1168    elif mode == "v":
1169        center_coordinates = vigra.filters.eccentricityCenters(segmentation.astype('float32'))
1170        center_coordinates = {i: coord for i, coord in enumerate(center_coordinates) if i > 0}
1171
1172    bbox_coordinates = {prop.label: prop.bbox for prop in properties}
1173
1174    assert len(bbox_coordinates) == len(center_coordinates), f"{len(bbox_coordinates)}, {len(center_coordinates)}"
1175    return center_coordinates, bbox_coordinates

Returns the center coordinates of the foreground instances in the ground-truth.

Arguments:
  • segmentation: The segmentation.
  • mode: Determines the functionality used for computing the centers. If 'v', the object's eccentricity centers computed by vigra are used. If 'p' the object's centroids computed by skimage are used.
Returns:

A dictionary that maps object ids to the corresponding centroid. A dictionary that maps object_ids to the corresponding bounding box.

def load_image_data( path: str, key: Optional[str] = None, lazy_loading: bool = False) -> numpy.ndarray:
1178def load_image_data(path: str, key: Optional[str] = None, lazy_loading: bool = False) -> np.ndarray:
1179    """Helper function to load image data from file.
1180
1181    Args:
1182        path: The filepath to the image data.
1183        key: The internal filepath for complex data formats like hdf5.
1184        lazy_loading: Whether to lazyly load data. Only supported for n5 and zarr data.
1185
1186    Returns:
1187        The image data.
1188    """
1189    if key is None:
1190        image_data = imageio.imread(path)
1191    else:
1192        with open_file(path, mode="r") as f:
1193            image_data = f[key]
1194            if not lazy_loading:
1195                image_data = image_data[:]
1196
1197    return image_data

Helper function to load image data from file.

Arguments:
  • path: The filepath to the image data.
  • key: The internal filepath for complex data formats like hdf5.
  • lazy_loading: Whether to lazyly load data. Only supported for n5 and zarr data.
Returns:

The image data.

def segmentation_to_one_hot( segmentation: numpy.ndarray, segmentation_ids: Optional[numpy.ndarray] = None) -> torch.Tensor:
1200def segmentation_to_one_hot(segmentation: np.ndarray, segmentation_ids: Optional[np.ndarray] = None) -> torch.Tensor:
1201    """Convert the segmentation to one-hot encoded masks.
1202
1203    Args:
1204        segmentation: The segmentation.
1205        segmentation_ids: Optional subset of ids that will be used to subsample the masks.
1206            By default, computes the number of ids from the provided `segmentation` masks.
1207
1208    Returns:
1209        The one-hot encoded masks.
1210    """
1211    masks = segmentation.copy()
1212    if segmentation_ids is None:
1213        n_ids = int(segmentation.max())
1214
1215    else:
1216        msg = "No foreground objects were found."
1217        if len(segmentation_ids) == 0:  # The list should not be completely empty.
1218            raise RuntimeError(msg)
1219
1220        if 0 in segmentation_ids:  # The list should not have 'zero' as a value.
1221            raise RuntimeError(msg)
1222
1223        # the segmentation ids have to be sorted
1224        segmentation_ids = np.sort(segmentation_ids)
1225
1226        # set the non selected objects to zero and relabel sequentially
1227        masks[~np.isin(masks, segmentation_ids)] = 0
1228        masks = relabel_sequential(masks)[0]
1229        n_ids = len(segmentation_ids)
1230
1231    masks = torch.from_numpy(masks)
1232
1233    one_hot_shape = (n_ids + 1,) + masks.shape
1234    masks = masks.unsqueeze(0)  # add dimension to scatter
1235    masks = torch.zeros(one_hot_shape).scatter_(0, masks, 1)[1:]
1236
1237    # add the extra singleton dimenion to get shape NUM_OBJECTS x 1 x H x W
1238    masks = masks.unsqueeze(1)
1239    return masks

Convert the segmentation to one-hot encoded masks.

Arguments:
  • segmentation: The segmentation.
  • segmentation_ids: Optional subset of ids that will be used to subsample the masks. By default, computes the number of ids from the provided segmentation masks.
Returns:

The one-hot encoded masks.

def get_block_shape(shape: Tuple[int]) -> Tuple[int]:
1242def get_block_shape(shape: Tuple[int]) -> Tuple[int]:
1243    """Get a suitable block shape for chunking a given shape.
1244
1245    The primary use for this is determining chunk sizes for
1246    zarr arrays or block shapes for parallelization.
1247
1248    Args:
1249        shape: The image or volume shape.
1250
1251    Returns:
1252        The block shape.
1253    """
1254    ndim = len(shape)
1255    if ndim == 2:
1256        block_shape = tuple(min(bs, sh) for bs, sh in zip((1024, 1024), shape))
1257    elif ndim == 3:
1258        block_shape = tuple(min(bs, sh) for bs, sh in zip((32, 256, 256), shape))
1259    else:
1260        raise ValueError(f"Only 2 or 3 dimensional shapes are supported, got {ndim}D.")
1261
1262    return block_shape

Get a suitable block shape for chunking a given shape.

The primary use for this is determining chunk sizes for zarr arrays or block shapes for parallelization.

Arguments:
  • shape: The image or volume shape.
Returns:

The block shape.

def micro_sam_info():
1265def micro_sam_info():
1266    """Display μSAM information using a rich console."""
1267    import psutil
1268    import platform
1269    import argparse
1270    from rich import progress
1271    from rich.panel import Panel
1272    from rich.table import Table
1273    from rich.console import Console
1274
1275    import torch
1276    import micro_sam
1277
1278    parser = argparse.ArgumentParser(description="μSAM Information Booth")
1279    parser.add_argument(
1280        "--download", nargs="+", metavar=("WHAT", "KIND"),
1281        help="Downloads the pretrained SAM models."
1282        "'--download models' -> downloads all pretrained models; "
1283        "'--download models vit_b_lm vit_b_em_organelles' -> downloads the listed models; "
1284        "'--download model/models vit_b_lm' -> downloads a single specified model."
1285    )
1286    args = parser.parse_args()
1287
1288    # Open up a new console.
1289    console = Console()
1290
1291    # The header for information CLI.
1292    console.print("[bold #0072B2]μSAM Information Booth[/bold #0072B2]", justify="center")
1293    console.print("-" * console.width)
1294
1295    # μSAM version panel.
1296    console.print(
1297        Panel(f"[bold #F0E442]Version:[/bold #F0E442] {micro_sam.__version__}", title="μSAM Version", expand=True)
1298    )
1299
1300    # The documentation link panel.
1301    console.print(
1302        Panel(
1303            "[bold #CC79A7]Tools documented at:[/bold #CC79A7]\n"
1304            "https://computational-cell-analytics.github.io/micro-sam", title="Documentation"
1305        )
1306    )
1307
1308    # The publication panel.
1309    console.print(
1310        Panel(
1311            "[bold #E69F00]Published in Nature Methods:[/bold #E69F00]\n"
1312            "https://www.nature.com/articles/s41592-024-02580-4", title="Publication"
1313        )
1314    )
1315
1316    # The cache directory panel.
1317    console.print(
1318        Panel(f"[bold #009E73]Cache Directory:[/bold #009E73]\n{get_cache_directory()}", title="Cache Directory")
1319    )
1320
1321    # The available models panel.
1322    available_models = list(get_model_names())
1323    # We filter out the decoder models.
1324    available_models = [m for m in available_models if not m.endswith("_decoder")]
1325    model_list = "\n".join(available_models)
1326    console.print(
1327        Panel(f"[bold #D55E00]Available Models:[/bold #D55E00]\n{model_list}", title="List of Supported Models")
1328    )
1329
1330    # The system information table.
1331    total_memory = psutil.virtual_memory().total / (1024 ** 3)
1332    table = Table(title="System Information", show_header=True, header_style="bold #0072B2", expand=True)
1333    table.add_column("Property")
1334    table.add_column("Value", style="bold #56B4E9")
1335    table.add_row("System", platform.system())
1336    table.add_row("Node Name", platform.node())
1337    table.add_row("Release", platform.release())
1338    table.add_row("Version", platform.version())
1339    table.add_row("Machine", platform.machine())
1340    table.add_row("Processor", platform.processor())
1341    table.add_row("Platform", platform.platform())
1342    table.add_row("Total RAM (GB)", f"{total_memory:.2f}")
1343    console.print(table)
1344
1345    # The device information and check for available GPU acceleration.
1346    default_device = _get_default_device()
1347
1348    if default_device == "cuda":
1349        device_index = torch.cuda.current_device()
1350        device_name = torch.cuda.get_device_name(device_index)
1351        console.print(Panel(f"[bold #000000]CUDA Device:[/bold #000000] {device_name}", title="GPU Information"))
1352    elif default_device == "mps":
1353        console.print(Panel("[bold #000000]MPS Device is available[/bold #000000]", title="GPU Information"))
1354    else:
1355        console.print(
1356            Panel(
1357                "[bold #000000]No GPU acceleration device detected. Running on CPU.[/bold #000000]",
1358                title="Device Information"
1359            )
1360        )
1361
1362    # The section allowing to download models.
1363    # NOTE: In future, can be extended to download sample data.
1364    if args.download:
1365        download_provided_args = [t.lower() for t in args.download]
1366        mode, *model_types = download_provided_args
1367
1368        if mode not in {"models", "model"}:
1369            console.print(f"[red]Unknown option for --download: {mode}[/]")
1370            return
1371
1372        if mode in ["model", "models"] and not model_types:  # If user did not specify, we will download all models.
1373            download_list = available_models
1374        else:
1375            download_list = model_types
1376            incorrect_models = [m for m in download_list if m not in available_models]
1377            if incorrect_models:
1378                console.print(Panel("[red]Unknown model(s):[/] " + ", ".join(incorrect_models), title="Download Error"))
1379                return
1380
1381        with progress.Progress(
1382            progress.SpinnerColumn(),
1383            progress.TextColumn("[progress.description]{task.description}"),
1384            progress.BarColumn(bar_width=None),
1385            "[progress.percentage]{task.percentage:>3.0f}%",
1386            progress.TimeRemainingColumn(),
1387            console=console,
1388        ) as prog:
1389            task = prog.add_task("[green]Downloading μSAM models…", total=len(download_list))
1390            for model_type in download_list:
1391                prog.update(task, description=f"Downloading [cyan]{model_type}[/]…")
1392                _download_sam_model(model_type=model_type)
1393                prog.advance(task)
1394
1395        console.print(Panel("[bold green] Downloads complete![/]", title="Finished"))

Display μSAM information using a rich console.