micro_sam.util

Helper functions for downloading Segment Anything models and predicting image embeddings.

   1"""
   2Helper functions for downloading Segment Anything models and predicting image embeddings.
   3"""
   4
   5import os
   6import multiprocessing as mp
   7import pickle
   8import hashlib
   9import warnings
  10from concurrent import futures
  11from pathlib import Path
  12from collections import OrderedDict
  13from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Callable
  14
  15import elf.parallel as parallel_impl
  16import imageio.v3 as imageio
  17import numpy as np
  18import pooch
  19import segment_anything.utils.amg as amg_utils
  20import torch
  21import vigra
  22import xxhash
  23import zarr
  24
  25from elf.io import open_file
  26from nifty.tools import blocking
  27from skimage.measure import regionprops
  28from skimage.segmentation import relabel_sequential
  29from torchvision.ops.boxes import batched_nms
  30
  31from .__version__ import __version__
  32from . import models as custom_models
  33
  34try:
  35    # Avoid import warnings from mobile_sam
  36    with warnings.catch_warnings():
  37        warnings.simplefilter("ignore")
  38        from mobile_sam import sam_model_registry, SamPredictor
  39    VIT_T_SUPPORT = True
  40except ImportError:
  41    from segment_anything import sam_model_registry, SamPredictor
  42    VIT_T_SUPPORT = False
  43
  44try:
  45    from napari.utils import progress as tqdm
  46except ImportError:
  47    from tqdm import tqdm
  48
  49# This is the default model used in micro_sam
  50# Currently it is set to vit_b_lm
  51_DEFAULT_MODEL = "vit_b_lm"
  52
  53# The valid model types. Each type corresponds to the architecture of the
  54# vision transformer used within SAM.
  55_MODEL_TYPES = ("vit_l", "vit_b", "vit_h", "vit_t")
  56
  57
  58ImageEmbeddings = Dict[str, Any]
  59"""@private"""
  60
  61
  62def get_cache_directory() -> Path:
  63    """Get micro-sam cache directory location.
  64
  65    Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.
  66    """
  67    default_cache_directory = os.path.expanduser(pooch.os_cache("micro_sam"))
  68    cache_directory = Path(os.environ.get("MICROSAM_CACHEDIR", default_cache_directory))
  69    return cache_directory
  70
  71
  72#
  73# Functionality for model download and export
  74#
  75
  76
  77def microsam_cachedir() -> Union[str, Path]:
  78    """Return the micro-sam cache directory.
  79
  80    Returns the top level cache directory for micro-sam models and sample data.
  81
  82    Every time this function is called, we check for any user updates made to
  83    the MICROSAM_CACHEDIR os environment variable since the last time.
  84    """
  85    cache_directory = os.environ.get("MICROSAM_CACHEDIR") or pooch.os_cache("micro_sam")
  86    return cache_directory
  87
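For example, the cache location can be redirected before any model download by setting the environment variable. A minimal sketch; the path below is hypothetical:

    import os
    os.environ["MICROSAM_CACHEDIR"] = "/tmp/micro_sam_cache"  # hypothetical custom location

    from micro_sam.util import get_cache_directory, microsam_cachedir
    print(get_cache_directory())  # the cache directory as a pathlib.Path
    print(microsam_cachedir())    # the top level cache directory for models and sample data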
  88
  89def models():
  90    """Return the segmentation models registry.
  91
  92    We recreate the model registry every time this function is called,
  93    so any user changes to the default micro-sam cache directory location
  94    are respected.
  95    """
  96
  97    # We use xxhash to compute the hash of the models, see
  98    # https://github.com/computational-cell-analytics/micro-sam/issues/283
  99    # (It is now a dependency, so we don't provide the sha256 fallback anymore.)
 100    # To generate the xxh128 hash:
 101    #     xxh128sum filename
 102    encoder_registry = {
 103        # The default segment anything models:
 104        "vit_l": "xxh128:a82beb3c660661e3dd38d999cc860e9a",
 105        "vit_h": "xxh128:97698fac30bd929c2e6d8d8cc15933c2",
 106        "vit_b": "xxh128:6923c33df3637b6a922d7682bfc9a86b",
 107        # The model with vit tiny backend from https://github.com/ChaoningZhang/MobileSAM.
 108        "vit_t": "xxh128:8eadbc88aeb9d8c7e0b4b60c3db48bd0",
 109        # The current version of our models in the modelzoo.
 110        # LM generalist models:
 111        "vit_l_lm": "xxh128:017f20677997d628426dec80a8018f9d",
 112        "vit_b_lm": "xxh128:fe9252a29f3f4ea53c15a06de471e186",
 113        "vit_t_lm": "xxh128:72ec5074774761a6e5c05a08942f981e",
 114        # EM models:
 115        "vit_l_em_organelles": "xxh128:810b084b6e51acdbf760a993d8619f2d",
 116        "vit_b_em_organelles": "xxh128:f3bf2ed83d691456bae2c3f9a05fb438",
 117        "vit_t_em_organelles": "xxh128:253474720c497cce605e57c9b1d18fd9",
 118        # Histopathology models:
 119        "vit_b_histopathology": "xxh128:ffd1a2cd84570458b257bd95fdd8f974",
 120        "vit_l_histopathology": "xxh128:b591833c89754271023e901281dee3f2",
 121        "vit_h_histopathology": "xxh128:bd1856dafc156a43fb3aa705f1a6e92e",
 122        # Medical Imaging models:
 123        "vit_b_medical_imaging": "xxh128:40169f1e3c03a4b67bff58249c176d92",
 124    }
 125    # Additional decoders for instance segmentation.
 126    decoder_registry = {
 127        # LM generalist models:
 128        "vit_l_lm_decoder": "xxh128:2faeafa03819dfe03e7c46a44aaac64a",
 129        "vit_b_lm_decoder": "xxh128:708b15ac620e235f90bb38612c4929ba",
 130        "vit_t_lm_decoder": "xxh128:3e914a5f397b0312cdd36813031f8823",
 131        # EM models:
 132        "vit_l_em_organelles_decoder": "xxh128:334877640bfdaaabce533e3252a17294",
 133        "vit_b_em_organelles_decoder": "xxh128:bb6398956a6b0132c26b631c14f95ce2",
 134        "vit_t_em_organelles_decoder": "xxh128:8f897c7bb93174a4d1638827c4dd6f44",
 135        # Histopathology models:
 136        "vit_b_histopathology_decoder": "xxh128:6a66194dcb6e36199cbee2214ecf7213",
 137        "vit_l_histopathology_decoder": "xxh128:46aab7765d4400e039772d5a50b55c04",
 138        "vit_h_histopathology_decoder": "xxh128:3ed9f87e46ad5e16935bd8d722c8dc47",
 139        # Medical Imaging models:
 140        "vit_b_medical_imaging_decoder": "xxh128:9e498b12f526f119b96c88be76e3b2ed",
 141    }
 142    registry = {**encoder_registry, **decoder_registry}
 143
 144    encoder_urls = {
 145        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
 146        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
 147        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
 148        "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download",
 149        "vit_l_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.2/files/vit_l.pt",
 150        "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.2/files/vit_b.pt",
 151        "vit_t_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t.pt",
 152        "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1.2/files/vit_l.pt",  # noqa
 153        "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1.2/files/vit_b.pt",  # noqa
 154        "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt",  # noqa
 155        "vit_b_histopathology": "https://owncloud.gwdg.de/index.php/s/sBB4H8CTmIoBZsQ/download",
 156        "vit_l_histopathology": "https://owncloud.gwdg.de/index.php/s/IZgnn1cpBq2PHod/download",
 157        "vit_h_histopathology": "https://owncloud.gwdg.de/index.php/s/L7AcvVz7DoWJ2RZ/download",
 158        "vit_b_medical_imaging": "https://owncloud.gwdg.de/index.php/s/f5Ol4FrjPQWfjUF/download",
 159    }
 160
 161    decoder_urls = {
 162        "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.2/files/vit_l_decoder.pt",  # noqa
 163        "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.2/files/vit_b_decoder.pt",  # noqa
 164        "vit_t_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t_decoder.pt",  # noqa
 165        "vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1.2/files/vit_l_decoder.pt",  # noqa
 166        "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1.2/files/vit_b_decoder.pt",  # noqa
 167        "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt",  # noqa
 168        "vit_b_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/KO9AWqynI7SFOBj/download",
 169        "vit_l_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/oIs6VSmkOp7XrKF/download",
 170        "vit_h_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/1qAKxy5H0jgwZvM/download",
 171        "vit_b_medical_imaging_decoder": "https://owncloud.gwdg.de/index.php/s/ahd3ZhZl2e0RIwz/download",
 172    }
 173    urls = {**encoder_urls, **decoder_urls}
 174
 175    models = pooch.create(
 176        path=os.path.join(microsam_cachedir(), "models"),
 177        base_url="",
 178        registry=registry,
 179        urls=urls,
 180    )
 181    return models
 182
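A short sketch of how the registry returned by `models` can be used; the fetch call is commented out because it triggers a download into the cache:

    from micro_sam.util import models

    registry = models()  # pooch registry holding all encoder and decoder files
    print(sorted(registry.registry.keys()))  # the available model names
    # Fetch a checkpoint (downloads it on first use and returns the local path):
    # checkpoint_path = registry.fetch("vit_b_lm", progressbar=True)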
 183
 184def _get_default_device():
 185    # Check if we are running in CI and use the CPU if we are,
 186    # otherwise the tests may run out of memory on a Mac if MPS is used.
 187    if os.getenv("GITHUB_ACTIONS") == "true":
 188        return "cpu"
 189    # Use cuda enabled gpu if it's available.
 190    if torch.cuda.is_available():
 191        device = "cuda"
 192    # As second priority use mps.
 193    # See https://pytorch.org/docs/stable/notes/mps.html for details
 194    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
 195        print("Using Apple MPS device.")
 196        device = "mps"
 197    # Use the CPU as fallback.
 198    else:
 199        device = "cpu"
 200    return device
 201
 202
 203def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
 204    """Get the torch device.
 205
 206    If no device is passed the default device for your system is used.
 207    Else it will be checked if the device you have passed is supported.
 208
 209    Args:
 210        device: The input device. By default, selects the best available device.
 211
 212    Returns:
 213        The device.
 214    """
 215    if device is None or device == "auto":
 216        device = _get_default_device()
 217    else:
 218        device_type = device if isinstance(device, str) else device.type
 219        if device_type.lower() == "cuda":
 220            if not torch.cuda.is_available():
 221                raise RuntimeError("PyTorch CUDA backend is not available.")
 222        elif device_type.lower() == "mps":
 223            if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
 224                raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
 225        elif device_type.lower() == "cpu":
 226            pass  # cpu is always available
 227        else:
 228            raise RuntimeError(f"Unsupported device: '{device}'. Please choose from 'cpu', 'cuda', or 'mps'.")
 229
 230    return device
 231
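Typical usage of `get_device` (a minimal sketch):

    from micro_sam.util import get_device

    device = get_device()       # picks "cuda", then "mps", then "cpu", in that order
    device = get_device("cpu")  # explicitly request a device
    # get_device("cuda") raises a RuntimeError if the CUDA backend is not available.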
 232
 233def _available_devices():
 234    available_devices = []
 235    for i in ["cuda", "mps", "cpu"]:
 236        try:
 237            device = get_device(i)
 238        except RuntimeError:
 239            pass
 240        else:
 241            available_devices.append(device)
 242    return available_devices
 243
 244
 245# We write a custom unpickler that skips objects that cannot be found instead of
 246# throwing an AttributeError or ModuleNotFoundError.
 247# NOTE: since we just want to unpickle the model to load its weights these errors don't matter.
 248# See also https://stackoverflow.com/questions/27732354/unable-to-load-files-using-pickle-and-multiple-modules
 249class _CustomUnpickler(pickle.Unpickler):
 250    def find_class(self, module, name):
 251        try:
 252            return super().find_class(module, name)
 253        except (AttributeError, ModuleNotFoundError) as e:
 254            warnings.warn(f"Did not find {module}:{name} and will skip it, due to error {e}")
 255            return None
 256
 257
 258def _compute_hash(path, chunk_size=8192):
 259    hash_obj = xxhash.xxh128()
 260    with open(path, "rb") as f:
 261        chunk = f.read(chunk_size)
 262        while chunk:
 263            hash_obj.update(chunk)
 264            chunk = f.read(chunk_size)
 265    hash_val = hash_obj.hexdigest()
 266    return f"xxh128:{hash_val}"
 267
 268
 269# Load the state from a checkpoint.
 270# The checkpoint can either contain a sam encoder state
 271# or it can be a checkpoint for model finetuning.
 272def _load_checkpoint(checkpoint_path):
 273    # Override the unpickler with our custom one.
 274    # This enables loading torch_em checkpoints even if they cannot be fully unpickled.
 275    custom_pickle = pickle
 276    custom_pickle.Unpickler = _CustomUnpickler
 277
 278    state = torch.load(checkpoint_path, map_location="cpu", pickle_module=custom_pickle)
 279    if "model_state" in state:
 280        # Copy the model weights from torch_em's training format.
 281        model_state = state["model_state"]
 282        sam_prefix = "sam."
 283        model_state = OrderedDict(
 284            [(k[len(sam_prefix):] if k.startswith(sam_prefix) else k, v) for k, v in model_state.items()]
 285        )
 286    else:
 287        model_state = state
 288
 289    return state, model_state
 290
 291
 292def _download_sam_model(model_type, progress_bar_factory=None):
 293    model_registry = models()
 294
 295    progress_bar = True
 296    # Check if we have to download the model.
 297    # If we do and have a progress bar factory, then we overwrite the progress bar.
 298    if not os.path.exists(os.path.join(get_cache_directory(), model_type)) and progress_bar_factory is not None:
 299        progress_bar = progress_bar_factory(model_type)
 300
 301    checkpoint_path = model_registry.fetch(model_type, progressbar=progress_bar)
 302    if not isinstance(progress_bar, bool):  # Close the progress bar when the task finishes.
 303        progress_bar.close()
 304
 305    model_hash = model_registry.registry[model_type]
 306
 307    # If we have a custom model then we may also have a decoder checkpoint.
 308    # Download it here, so that we can add it to the state.
 309    decoder_name = f"{model_type}_decoder"
 310    decoder_path = model_registry.fetch(
 311        decoder_name, progressbar=True
 312    ) if decoder_name in model_registry.registry else None
 313
 314    return checkpoint_path, model_hash, decoder_path
 315
 316
 317def get_sam_model(
 318    model_type: str = _DEFAULT_MODEL,
 319    device: Optional[Union[str, torch.device]] = None,
 320    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
 321    return_sam: bool = False,
 322    return_state: bool = False,
 323    peft_kwargs: Optional[Dict] = None,
 324    flexible_load_checkpoint: bool = False,
 325    progress_bar_factory: Optional[Callable] = None,
 326    **model_kwargs,
 327) -> SamPredictor:
 328    r"""Get the Segment Anything Predictor.
 329
 330    This function will download the required model or load it from the cached weight file.
 331    The location of the cache can be changed by setting the environment variable MICROSAM_CACHEDIR.
 332    The name of the requested model can be set via `model_type`.
 333    See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models
 334    for an overview of the available models.
 335
 336    Alternatively, this function can load a model from weights stored in a local filepath.
 337    The corresponding file path is given via `checkpoint_path`. In this case `model_type`
 338    must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for
 339    a SAM model with vit_b encoder.
 340
 341    By default the models are downloaded to a folder named 'micro_sam/models'
 342    inside your default cache directory, e.g.:
 343    * Mac: ~/Library/Caches/<AppName>
 344    * Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined.
 345    * Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache
 346    See the pooch.os_cache() documentation for more details:
 347    https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html
 348
 349    Args:
 350        model_type: The Segment Anything model to use. Will use the 'vit_b_lm' model by default.
 351            To get a list of all available model names you can call `micro_sam.util.get_model_names`.
 352        device: The device for the model. If 'None' is provided, will use GPU if available.
 353        checkpoint_path: The path to a file with weights that should be used instead of using the
 354            weights corresponding to `model_type`. If given, `model_type` must match the architecture
 355            corresponding to the weight file. e.g. if you use weights for SAM with `vit_b` encoder
 356            then `model_type` must be given as 'vit_b'.
 357        return_sam: Return the sam model object as well as the predictor. By default, set to 'False'.
 358        return_state: Return the unpickled checkpoint state. By default, set to 'False'.
 359        peft_kwargs: Keyword arguments for the PEFT wrapper class.
 360            If passed 'None', it does not initialize any parameter efficient finetuning.
 361        flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
 362            By default, set to 'False'.
 363        progress_bar_factory: A function to create a progress bar for the model download.
 364        model_kwargs: Additional parameters necessary to initialize the Segment Anything model.
 365
 366    Returns:
 367        The Segment Anything predictor.
 368    """
 369    device = get_device(device)
 370
 371    # We support passing a local filepath to a checkpoint.
 372    # In this case we do not download any weights but just use the local weight file,
 373    # as it is, without copying it anywhere or checking its hash.
 374
 375    # checkpoint_path has not been passed, we download a known model and derive the correct
 376    # URL from the model_type. If the model_type is invalid pooch will raise an error.
 377    _provided_checkpoint_path = checkpoint_path is not None
 378    if checkpoint_path is None:
 379        checkpoint_path, model_hash, decoder_path = _download_sam_model(model_type, progress_bar_factory)
 380
 381    # checkpoint_path has been passed, we use it instead of downloading a model.
 382    else:
 383        # Check if the file exists and raise an error otherwise.
 384        # We can't check any hashes here, and we don't check if the file is actually a valid weight file.
 385        # (If it isn't the model creation will fail below.)
 386        if not os.path.exists(checkpoint_path):
 387            raise ValueError(f"Checkpoint at '{checkpoint_path}' could not be found.")
 388        model_hash = _compute_hash(checkpoint_path)
 389        decoder_path = None
 390
 391    # Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped
 392    # before calling sam_model_registry.
 393    abbreviated_model_type = model_type[:5]
 394    if abbreviated_model_type not in _MODEL_TYPES:
 395        raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}")
 396    if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT:
 397        raise RuntimeError(
 398            "'mobile_sam' is required for the vit-tiny. "
 399            "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'"
 400        )
 401
 402    state, model_state = _load_checkpoint(checkpoint_path)
 403
 404    if _provided_checkpoint_path:
 405        # To get the model weights, we prioritize having the correct 'checkpoint_path' over 'model_type'.
 406        # This avoids parameter mismatch issues caused by an incompatible combination of model type and weights.
 407        from micro_sam.models.build_sam import _validate_model_type
 408        _provided_model_type = _validate_model_type(model_state)
 409
 410        # Verify whether the 'abbreviated_model_type' matches the '_provided_model_type'.
 411        # Otherwise replace 'abbreviated_model_type' with the latter.
 412        if abbreviated_model_type != _provided_model_type:
 413            # Printing the message below to avoid any filtering of warnings on user's end.
 414            print(
 415                f"CRITICAL WARNING: The chosen 'model_type' is '{abbreviated_model_type}', "
 416                f"however the model checkpoint provided correspond to '{_provided_model_type}', which does not match. "
 417                f"We internally switch the model type to the expected value, i.e. '{_provided_model_type}'. "
 418                "However, please avoid mismatching combination of 'model_type' and 'checkpoint_path' in future."
 419            )
 420
 421        # Replace the extracted 'abbreviated_model_type' based on the model weights.
 422        abbreviated_model_type = _provided_model_type
 423
 424    # Whether to update parameters necessary to initialize the model
 425    if model_kwargs:  # Checks whether model_kwargs have been provided or not
 426        if abbreviated_model_type == "vit_t":
 427            raise ValueError("'micro-sam' does not support changing the model parameters for 'mobile-sam'.")
 428        sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs)
 429
 430    else:
 431        sam = sam_model_registry[abbreviated_model_type]()
 432
 433    # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
 434    # Overwrites the SAM model by freezing the backbone and allowing PEFT.
 435    if peft_kwargs and isinstance(peft_kwargs, dict):
 436        # NOTE: We remove the 'quantize' parameter, if found, as we do not quantize during inference.
 437        peft_kwargs.pop("quantize", None)
 438
 439        if abbreviated_model_type == "vit_t":
 440            raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")
 441
 442        sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam
 443    # In case the model checkpoint has issues when it is initialized with non-default parameters.
 444    if flexible_load_checkpoint:
 445        sam = _handle_checkpoint_loading(sam, model_state)
 446    else:
 447        sam.load_state_dict(model_state)
 448    sam.to(device=device)
 449
 450    predictor = SamPredictor(sam)
 451    predictor.model_type = abbreviated_model_type
 452    predictor._hash = model_hash
 453    predictor.model_name = model_type
 454    predictor.checkpoint_path = checkpoint_path
 455
 456    # Add the decoder to the state if we have one and if the state is returned.
 457    if decoder_path is not None and return_state:
 458        state["decoder_state"] = torch.load(decoder_path, map_location=device, weights_only=False)
 459
 460    if return_sam and return_state:
 461        return predictor, sam, state
 462    if return_sam:
 463        return predictor, sam
 464    if return_state:
 465        return predictor, state
 466    return predictor
 467
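Example usage (a sketch; the local checkpoint path in the second call is hypothetical):

    from micro_sam.util import get_sam_model

    # Download (or load from the cache) the default model and create a predictor.
    predictor = get_sam_model(model_type="vit_b_lm")

    # Load a local finetuned checkpoint instead; 'model_type' must match its encoder architecture.
    predictor = get_sam_model(model_type="vit_b", checkpoint_path="finetuned_vit_b.pt")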
 468
 469def _handle_checkpoint_loading(sam, model_state):
 470    # Handle the mismatch issues in a more flexible way,
 471    # e.g. while training for multi-class semantic segmentation in the mask decoder,
 472    # parameters are updated, leading to "size mismatch" errors.
 473
 474    new_state_dict = {}  # for loading matching parameters
 475    mismatched_layers = []  # for tracking mismatching parameters
 476
 477    reference_state = sam.state_dict()
 478
 479    for k, v in model_state.items():
 480        if k in reference_state:  # This is done to get rid of unwanted layers from pretrained SAM.
 481            if reference_state[k].size() == v.size():
 482                new_state_dict[k] = v
 483            else:
 484                mismatched_layers.append(k)
 485
 486    reference_state.update(new_state_dict)
 487
 488    if len(mismatched_layers) > 0:
 489        warnings.warn(f"The layers with size mismatch: {mismatched_layers}")
 490
 491    for mlayer in mismatched_layers:
 492        if 'weight' in mlayer:
 493            torch.nn.init.kaiming_uniform_(reference_state[mlayer])
 494        elif 'bias' in mlayer:
 495            reference_state[mlayer].zero_()
 496
 497    sam.load_state_dict(reference_state)
 498
 499    return sam
 500
 501
 502def export_custom_sam_model(
 503    checkpoint_path: Union[str, os.PathLike],
 504    model_type: str,
 505    save_path: Union[str, os.PathLike],
 506    with_segmentation_decoder: bool = False,
 507    prefix: str = "sam.",
 508) -> None:
 509    """Export a finetuned Segment Anything Model to the standard model format.
 510
 511    The exported model can be used by the interactive annotation tools in `micro_sam.annotator`.
 512
 513    Args:
 514        checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
 515        model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
 516        save_path: Where to save the exported model.
 517        with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well.
 518            If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'.
 519        prefix: The prefix to remove from the model parameter keys.
 520    """
 521    state, model_state = _load_checkpoint(checkpoint_path=checkpoint_path)
 522    model_state = OrderedDict([(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()])
 523
 524    # Store the 'decoder_state' as well, if desired.
 525    if with_segmentation_decoder:
 526        if "decoder_state" not in state:
 527            raise RuntimeError(f"'decoder_state' is not found in the model at '{checkpoint_path}'.")
 528        decoder_state = state["decoder_state"]
 529        save_state = {"model_state": model_state, "decoder_state": decoder_state}
 530    else:
 531        save_state = model_state
 532
 533    torch.save(save_state, save_path)
 534
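Example usage (a sketch with hypothetical paths):

    from micro_sam.util import export_custom_sam_model

    # Export a torch_em training checkpoint to the standard SAM checkpoint format.
    export_custom_sam_model(
        checkpoint_path="checkpoints/vit_b_custom/best.pt",
        model_type="vit_b",
        save_path="vit_b_custom_export.pt",
        with_segmentation_decoder=False,
    )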
 535
 536def export_custom_qlora_model(
 537    checkpoint_path: Optional[Union[str, os.PathLike]],
 538    finetuned_path: Union[str, os.PathLike],
 539    model_type: str,
 540    save_path: Union[str, os.PathLike],
 541) -> None:
 542    """Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format.
 543
 544    The exported model can be used with the LoRA backbone by passing the relevant `peft_kwargs` to `get_sam_model`.
 545
 546    Args:
 547        checkpoint_path: The path to the base foundation model from which the new model has been finetuned.
 548        finetuned_path: The path to the new finetuned model, using QLoRA.
 549        model_type: The Segment Anything Model type corresponding to the checkpoint.
 550        save_path: Where to save the exported model.
 551    """
 552    # Step 1: Get the base SAM model: used to start finetuning from.
 553    _, sam = get_sam_model(
 554        model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True,
 555    )
 556
 557    # Step 2: Load the QLoRA-style finetuned model.
 558    ft_state, ft_model_state = _load_checkpoint(finetuned_path)
 559
 560    # Step 3: Identify LoRA layers from QLoRA model.
 561    # - differentiate between LoRA applied to the attention matrices and LoRA applied to the MLP layers.
 562    # - then copy the LoRA layers from the QLoRA model to the new state dict
 563    updated_model_state = {}
 564
 565    modified_attn_layers = set()
 566    modified_mlp_layers = set()
 567
 568    for k, v in ft_model_state.items():
 569        if "blocks." in k:
 570            layer_id = int(k.split("blocks.")[1].split(".")[0])
 571        if k.find("qkv.w_a_linear") != -1 or k.find("qkv.w_b_linear") != -1:
 572            modified_attn_layers.add(layer_id)
 573            updated_model_state[k] = v
 574        if k.find("mlp.w_a_linear") != -1 or k.find("mlp.w_b_linear") != -1:
 575            modified_mlp_layers.add(layer_id)
 576            updated_model_state[k] = v
 577
 578    # Step 4: Next, we get all the remaining parameters from the base SAM model.
 579    for k, v in sam.state_dict().items():
 580        if "blocks." in k:
 581            layer_id = int(k.split("blocks.")[1].split(".")[0])
 582        if k.find("attn.qkv.") != -1:
 583            if layer_id in modified_attn_layers:  # We have LoRA in QKV layers, so we need to modify the key
 584                k = k.replace("qkv", "qkv.qkv_proj")
 585        elif k.find("mlp") != -1 and k.find("image_encoder") != -1:
 586            if layer_id in modified_mlp_layers:  # We have LoRA in MLP layers, so we need to modify the key
 587                k = k.replace("mlp.", "mlp.mlp_layer.")
 588        updated_model_state[k] = v
 589
 590    # Step 5: Finally, we replace the old model state with the new one (to retain other relevant stuff)
 591    ft_state['model_state'] = updated_model_state
 592
 593    # Step 6: Store the new "state" to "save_path"
 594    torch.save(ft_state, save_path)
 595
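Example usage (a sketch with hypothetical paths):

    from micro_sam.util import export_custom_qlora_model

    # Convert a QLoRA finetuning result to a LoRA-style checkpoint.
    export_custom_qlora_model(
        checkpoint_path=None,                 # start from the downloaded base model for 'model_type'
        finetuned_path="qlora_finetuned.pt",
        model_type="vit_b",
        save_path="lora_export.pt",
    )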
 596
 597def get_model_names() -> Iterable:
 598    model_registry = models()
 599    model_names = model_registry.registry.keys()
 600    return model_names
 601
 602
 603#
 604# Functionality for precomputing image embeddings.
 605#
 606
 607
 608def _to_image(image):
 609    input_ = image
 610    ndim = input_.ndim
 611    n_channels = 1 if ndim == 2 else input_.shape[-1]
 612
 613    # Map the input to three channels.
 614    if ndim == 2:  # Grayscale image -> replicate channels.
 615        input_ = np.concatenate([input_[..., None]] * 3, axis=-1)
 616    elif ndim == 3 and n_channels == 1:  # Grayscale image -> replicate channels.
 617        input_ = np.concatenate([input_] * 3, axis=-1)
 618    elif ndim == 3 and n_channels == 2:  # Two channels -> add a zero channel.
 619        zero_channel = np.zeros(input_.shape[:2] + (1,), dtype=input_.dtype)
 620        input_ = np.concatenate([input_, zero_channel], axis=-1)
 621    elif input_.ndim == 3 and n_channels == 3:  # RGB input -> do nothing.
 622        pass
 623    elif input_.ndim == 3 and n_channels > 3:  # More than three channels -> select first three.
 624        warnings.warn(f"You provided an input with {n_channels} channels. Only the first three will be used.")
 625        input_ = input_[..., :3]
 626    else:
 627        raise ValueError(
 628            f"Invalid input dimensionality {ndim}. Expect either a 2D input (=grayscale image) "
 629            "or a 3D input (= image with channels)."
 630        )
 631    assert input_.ndim == 3 and input_.shape[-1] == 3
 632
 633    # Normalize the input per channel and bring it to uint8.
 634    input_ = input_.astype("float32")
 635    input_ -= input_.min(axis=(0, 1))[None, None]
 636    input_ /= (input_.max(axis=(0, 1))[None, None] + 1e-7)
 637    input_ = (input_ * 255).astype("uint8")
 638
 639    # Explicitly return a numpy array for compatibility with torchvision
 640    # because the input_ array could be something like a dask array.
 641    return np.array(input_)
 642
 643
 644@torch.no_grad
 645def _compute_embeddings_batched(predictor, batched_images):
 646    predictor.reset_image()
 647    batched_tensors, original_sizes, input_sizes = [], [], []
 648
 649    # Apply preprocessing to all images in the batch, and then stack them.
 650    # Note: after the transformation the images are all of the same size,
 651    # so they can be stacked and processed as a batch, even if the input images were of different size.
 652    for image in batched_images:
 653        tensor = predictor.transform.apply_image(image)
 654        tensor = torch.as_tensor(tensor, device=predictor.device)
 655        tensor = tensor.permute(2, 0, 1).contiguous()[None, :, :, :]
 656
 657        original_sizes.append(image.shape[:2])
 658        input_sizes.append(tensor.shape[-2:])
 659
 660        tensor = predictor.model.preprocess(tensor)
 661        batched_tensors.append(tensor)
 662
 663    batched_tensors = torch.cat(batched_tensors)
 664    features = predictor.model.image_encoder(batched_tensors)
 665
 666    predictor.original_size = original_sizes[-1]
 667    predictor.input_size = input_sizes[-1]
 668    predictor.features = features[-1]
 669    predictor.is_image_set = True
 670
 671    return features, original_sizes, input_sizes
 672
 673
 674# Wrapper around zarr dataset creation to support both zarr v2 and zarr v3.
 675def _create_dataset_with_data(group, name, data, chunks=None):
 676    zarr_major_version = int(zarr.__version__.split(".")[0])
 677    if chunks is None:
 678        chunks = data.shape
 679    if zarr_major_version == 2:
 680        ds = group.create_dataset(name, data=data, shape=data.shape, chunks=chunks)
 681    elif zarr_major_version == 3:
 682        ds = group.create_array(name, shape=data.shape, chunks=chunks, dtype=data.dtype)
 683        ds[:] = data
 684    else:
 685        raise RuntimeError(f"Unsupported zarr version: {zarr_major_version}")
 686    return ds
 687
 688
 689def _create_dataset_without_data(group, name, shape, dtype, chunks):
 690    zarr_major_version = int(zarr.__version__.split(".")[0])
 691    if zarr_major_version == 2:
 692        ds = group.create_dataset(name, shape=shape, dtype=dtype, chunks=chunks)
 693    elif zarr_major_version == 3:
 694        ds = group.create_array(name, shape=shape, chunks=chunks, dtype=dtype)
 695    else:
 696        raise RuntimeError(f"Unsupported zarr version: {zarr_major_version}")
 697    return ds
 698
 699
 700def _write_batch(features, tile_ids, batched_embeddings, original_sizes, input_sizes, slices=None, n_slices=None):
 701
 702    # Pre-create / pre-fetch the datasets if we have slices.
 703    # (Dataset creation is not thread-safe)
 704    if slices is not None:
 705        datasets = {}
 706        for tile_id, tile_embeddings, original_size, input_size in zip(
 707            tile_ids, batched_embeddings, original_sizes, input_sizes
 708        ):
 709            ds_name = str(tile_id)
 710            if ds_name in datasets:
 711                continue
 712            if ds_name in features:
 713                datasets[ds_name] = features[ds_name]
 714                continue
 715            shape = (n_slices, 1) + tile_embeddings.shape
 716            chunks = (1, 1) + tile_embeddings.shape
 717            ds = _create_dataset_without_data(features, ds_name, shape=shape, dtype="float32", chunks=chunks)
 718            ds.attrs["original_size"] = original_size
 719            ds.attrs["input_size"] = input_size
 720            datasets[ds_name] = ds
 721
 722    def _write_embed(i):
 723        ds_name = str(tile_ids[i])
 724        tile_embeddings = batched_embeddings[i].unsqueeze(0)
 725        if slices is None:
 726            ds = _create_dataset_with_data(features, ds_name, data=tile_embeddings.cpu().numpy())
 727            ds.attrs["original_size"] = original_sizes[i]
 728            ds.attrs["input_size"] = input_sizes[i]
 729        elif ds_name in features:
 730            ds = datasets[ds_name]
 731            z = slices[i]
 732            ds[z] = tile_embeddings.cpu().numpy()
 733
 734    n_tiles = len(tile_ids)
 735    n_workers = min(mp.cpu_count(), n_tiles)
 736    with futures.ThreadPoolExecutor(n_workers) as tp:
 737        list(tp.map(_write_embed, range(n_tiles)))
 738
 739
 740def _get_tiles_in_mask(mask, tiling, halo, z=None):
 741    def _check_mask(tile_id):
 742        tile = tiling.getBlockWithHalo(tile_id, list(halo))
 743        outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outerBlock.begin, tile.outerBlock.end))
 744        if z is not None:
 745            outer_tile = (z,) + outer_tile
 746        tile_mask = mask[outer_tile].astype("bool")
 747        return None if tile_mask.sum() == 0 else tile_id
 748
 749    n_threads = mp.cpu_count()
 750    with futures.ThreadPoolExecutor(n_threads) as tp:
 751        tiles_in_mask = tp.map(_check_mask, range(tiling.numberOfBlocks))
 752    return sorted([tile_id for tile_id in tiles_in_mask if tile_id is not None])
 753
 754
 755def _compute_tiled_features_2d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask):
 756    tiling = blocking([0, 0], input_.shape[:2], tile_shape)
 757    n_tiles = tiling.numberOfBlocks
 758
 759    features = f.require_group("features")
 760    features.attrs["shape"] = input_.shape[:2]
 761    features.attrs["tile_shape"] = tile_shape
 762    features.attrs["halo"] = halo
 763
 764    n_batches = int(np.ceil(n_tiles / batch_size))
 765    if mask is None:
 766        tile_ids_for_batches = [
 767            range(batch_id * batch_size, min((batch_id + 1) * batch_size, n_tiles))
 768            for batch_id in range(n_batches)
 769        ]
 770        pbar_init(n_tiles, "Compute Image Embeddings 2D tiled")
 771    else:
 772        tiles_in_mask = _get_tiles_in_mask(mask, tiling, halo)
 773        pbar_init(len(tiles_in_mask), "Compute Image Embeddings 2D tiled with mask")
 774        tile_ids_for_batches = np.array_split(tiles_in_mask, n_batches)
 775        assert len(tile_ids_for_batches) == n_batches
 776
 777    for tile_ids in tile_ids_for_batches:
 778        batched_images = []
 779        for tile_id in tile_ids:
 780            tile = tiling.getBlockWithHalo(tile_id, list(halo))
 781            outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outerBlock.begin, tile.outerBlock.end))
 782            tile_input = _to_image(input_[outer_tile])
 783            batched_images.append(tile_input)
 784
 785        batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images)
 786        _write_batch(features, tile_ids, batched_embeddings, original_sizes, input_sizes)
 787        pbar_update(len(tile_ids))
 788
 789    _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size=None, original_size=None)
 790    if mask is not None:
 791        features.attrs["tiles_in_mask"] = tiles_in_mask
 792
 793    return features
 794
 795
 796class _BatchProvider:
 797    def __init__(self, n_slices, n_tiles_per_plane, tiles_in_mask_per_slice, batch_size):
 798        if tiles_in_mask_per_slice is None:
 799            self.n_tiles_total = n_slices * n_tiles_per_plane
 800        else:
 801            self.n_tiles_total = sum(len(val) for val in tiles_in_mask_per_slice.values())
 802
 803        self.n_batches = int(np.ceil(self.n_tiles_total / batch_size))
 804        self.n_slices = n_slices
 805        self.n_tiles_per_plane = n_tiles_per_plane
 806        self.tiles_in_mask_per_slice = tiles_in_mask_per_slice
 807        self.batch_size = batch_size
 808
 809        # Iter variables.
 810        self.batch_id = 0
 811        self.z = 0
 812        self.tile_id = 0
 813
 814    def __iter__(self):
 815        return self
 816
 817    def __next__(self):
 818        if self.batch_id >= self.n_batches:
 819            raise StopIteration
 820
 821        z_list = list(range(self.n_tiles_per_plane))
 822        z_tiles = z_list if self.tiles_in_mask_per_slice is None else self.tiles_in_mask_per_slice[self.z]
 823
 824        slices, tile_ids = [], []
 825        this_batch_size = 0
 826        while this_batch_size < self.batch_size:
 827            if self.tile_id == len(z_tiles):
 828                self.z += 1
 829                self.tile_id = 0
 830                if self.z >= self.n_slices:
 831                    break
 832                z_tiles = z_list if self.tiles_in_mask_per_slice is None else self.tiles_in_mask_per_slice[self.z]
 833                continue
 834
 835            slices.append(self.z), tile_ids.append(z_tiles[self.tile_id])
 836            self.tile_id += 1
 837            this_batch_size += 1
 838
 839        self.batch_id += 1
 840        return slices, tile_ids
 841
 842
 843def _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask):
 844    assert input_.ndim == 3
 845
 846    shape = input_.shape[1:]
 847    tiling = blocking([0, 0], shape, tile_shape)
 848    features = f.require_group("features")
 849    features.attrs["shape"] = shape
 850    features.attrs["tile_shape"] = tile_shape
 851    features.attrs["halo"] = halo
 852
 853    n_tiles_per_plane = tiling.numberOfBlocks
 854    n_slices = input_.shape[0]
 855
 856    msg = "Compute Image Embeddings 3D tiled"
 857    if mask is None:
 858        n_tiles_total = n_slices * n_tiles_per_plane
 859        tiles_in_mask_per_slice = None
 860    else:
 861        tiles_in_mask_per_slice = {}
 862        for z in range(n_slices):
 863            tiles_in_mask_per_slice[z] = _get_tiles_in_mask(mask, tiling, halo, z=z)
 864        n_tiles_total = sum(len(val) for val in tiles_in_mask_per_slice.values())
 865        msg += " masked"
 866    pbar_init(n_tiles_total, msg)
 867
 868    batch_provider = _BatchProvider(n_slices, n_tiles_per_plane, tiles_in_mask_per_slice, batch_size)
 869    for slices, tile_ids in batch_provider:
 870        batched_images = []
 871        for z, tile_id in zip(slices, tile_ids):
 872            tile = tiling.getBlockWithHalo(tile_id, list(halo))
 873            outer_tile = (z,) + tuple(
 874                slice(beg, end) for beg, end in zip(tile.outerBlock.begin, tile.outerBlock.end)
 875            )
 876            tile_input = _to_image(input_[outer_tile])
 877            batched_images.append(tile_input)
 878
 879        batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images)
 880        _write_batch(
 881            features, tile_ids, batched_embeddings, original_sizes, input_sizes, slices=slices, n_slices=n_slices
 882        )
 883        pbar_update(len(tile_ids))
 884
 885    if mask is not None:
 886        features.attrs["tiles_in_mask"] = {str(z): per_slice for z, per_slice in tiles_in_mask_per_slice.items()}
 887
 888    _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size=None, original_size=None)
 889    return features
 890
 891
 892def _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update):
 893    # Check if the embeddings are already cached.
 894    if save_path is not None and "input_size" in f.attrs:
 895        # In this case we load the embeddings.
 896        features = f["features"][:]
 897        original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
 898        image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size}
 899        # Also set the embeddings.
 900        set_precomputed(predictor, image_embeddings)
 901        return image_embeddings
 902
 903    pbar_init(1, "Compute Image Embeddings 2D")
 904    # Otherwise we have to compute the embeddings.
 905    predictor.reset_image()
 906    predictor.set_image(_to_image(input_))
 907    features = predictor.get_image_embedding().cpu().numpy()
 908    original_size = predictor.original_size
 909    input_size = predictor.input_size
 910    pbar_update(1)
 911
 912    # Save the embeddings if we have a save_path.
 913    if save_path is not None:
 914        _create_dataset_with_data(f, "features", data=features)
 915        _write_embedding_signature(
 916            f, input_, predictor, tile_shape=None, halo=None, input_size=input_size, original_size=original_size,
 917        )
 918
 919    image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size}
 920    return image_embeddings
 921
 922
 923def _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask):
 924    # Check if the features are already computed.
 925    if "input_size" in f.attrs:
 926        features = f["features"]
 927        original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
 928        image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size}
 929        return image_embeddings
 930
 931    # Otherwise compute them. Note: saving happens automatically because we
 932    # always write the features to zarr. If no save path is given we use an in-memory zarr.
 933    features = _compute_tiled_features_2d(
 934        predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask=mask
 935    )
 936    image_embeddings = {"features": features, "input_size": None, "original_size": None}
 937    return image_embeddings
 938
 939
 940def _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update, batch_size):
 941    # Check if the embeddings are already fully cached.
 942    if save_path is not None and "input_size" in f.attrs:
 943        # In this case we load the embeddings.
 944        features = f["features"] if lazy_loading else f["features"][:]
 945        original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
 946        image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size}
 947        return image_embeddings
 948
 949    # Otherwise we have to compute the embeddings.
 950
 951    # First check if we have a save path or not and set things up accordingly.
 952    if save_path is None:
 953        features = []
 954        save_features = False
 955        partial_features = False
 956    else:
 957        save_features = True
 958        embed_shape = (1, 256, 64, 64)
 959        shape = (input_.shape[0],) + embed_shape
 960        chunks = (1,) + embed_shape
 961        if "features" in f:
 962            partial_features = True
 963            features = f["features"]
 964            if features.shape != shape or features.chunks != chunks:
 965                raise RuntimeError("Invalid partial features")
 966        else:
 967            partial_features = False
 968            features = _create_dataset_without_data(f, "features", shape=shape, chunks=chunks, dtype="float32")
 969
 970    # Initialize the pbar and batches.
 971    n_slices = input_.shape[0]
 972    pbar_init(n_slices, "Compute Image Embeddings 3D")
 973    n_batches = int(np.ceil(n_slices / batch_size))
 974
 975    for batch_id in range(n_batches):
 976        z_start = batch_id * batch_size
 977        z_stop = min(z_start + batch_size, n_slices)
 978
 979        batched_images, batched_z = [], []
 980        for z in range(z_start, z_stop):
 981            # Skip feature computation if we have partial features and this slice is already non-zero (computed).
 982            if partial_features and np.count_nonzero(features[z]) != 0:
 983                continue
 984            tile_input = _to_image(input_[z])
 985            batched_images.append(tile_input)
 986            batched_z.append(z)
 987
 988        batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images)
 989
 990        for z, embedding in zip(batched_z, batched_embeddings):
 991            embedding = embedding.unsqueeze(0)
 992            if save_features:
 993                features[z] = embedding.cpu().numpy()
 994            else:
 995                features.append(embedding.unsqueeze(0))
 996            pbar_update(1)
 997
 998    if save_features:
 999        _write_embedding_signature(
1000            f, input_, predictor, tile_shape=None, halo=None,
1001            input_size=input_sizes[-1], original_size=original_sizes[-1],
1002        )
1003    else:
1004        # Concatenate across the z axis.
1005        features = torch.cat(features).cpu().numpy()
1006
1007    image_embeddings = {"features": features, "input_size": input_sizes[-1], "original_size": original_sizes[-1]}
1008    return image_embeddings
1009
1010
1011def _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask):
1012    # Check if the features are already computed.
1013    if "input_size" in f.attrs:
1014        features = f["features"]
1015        original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
1016        image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size}
1017        return image_embeddings
1018
1019    # Otherwise compute them. Note: saving happens automatically because we
1020    # always write the features to zarr. If no save path is given we use an in-memory zarr.
1021    features = _compute_tiled_features_3d(
1022        predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask
1023    )
1024    image_embeddings = {"features": features, "input_size": None, "original_size": None}
1025    return image_embeddings
1026
1027
1028def _compute_data_signature(input_):
1029    data_signature = hashlib.sha1(np.asarray(input_).tobytes()).hexdigest()
1030    return data_signature
1031
1032
1033# Create all metadata that is stored along with the embeddings.
1034def _get_embedding_signature(input_, predictor, tile_shape, halo, data_signature=None):
1035    if data_signature is None:
1036        data_signature = _compute_data_signature(input_)
1037
1038    signature = {
1039        "data_signature": data_signature,
1040        "tile_shape": tile_shape if tile_shape is None else list(tile_shape),
1041        "halo": halo if halo is None else list(halo),
1042        "model_type": predictor.model_type,
1043        "model_name": predictor.model_name,
1044        "micro_sam_version": __version__,
1045        "model_hash": getattr(predictor, "_hash", None),
1046    }
1047    return signature
1048
1049
1050# Note: the input size and original size differ depending on whether the embeddings are tiled or not.
1051# That's why we do not include them in the main signature that is being checked
1052# (_get_embedding_signature), but just add it for serialization here.
1053def _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size, original_size):
1054    signature = _get_embedding_signature(input_, predictor, tile_shape, halo)
1055    signature.update({"input_size": input_size, "original_size": original_size})
1056    for key, val in signature.items():
1057        f.attrs[key] = val
1058
1059
1060def _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo):
1061    # We may have an empty zarr file that was already created to save the embeddings in.
1062    # In this case the embeddings will be computed and we don't need to perform any checks.
1063    if "input_size" not in f.attrs:
1064        return
1065
1066    signature = _get_embedding_signature(input_, predictor, tile_shape, halo)
1067    for key, val in signature.items():
1068        # Check whether the key is missing from the attrs or if the value is not matching.
1069        if key not in f.attrs or f.attrs[key] != val:
1070            # These keys were recently added, so we don't want to fail yet if they don't
1071            # match in order to not invalidate previous embedding files.
1072            # Instead we just raise a warning. (For the version we probably also don't want to fail
1073            # in the future since it should not invalidate the embeddings).
1074            if key in ("micro_sam_version", "model_hash", "model_name"):
1075                warnings.warn(
1076                    f"The signature for {key} in embeddings file {save_path} has a mismatch: "
1077                    f"{f.attrs.get(key)} != {val}. This key was recently added, so your embeddings are likely correct. "
1078                    "But please recompute them if model predictions don't look as expected."
1079                )
1080            else:
1081                raise RuntimeError(
1082                    f"Embeddings file {save_path} is invalid due to mismatch in {key}: "
1083                    f"{f.attrs.get(key)} != {val}. Please recompute embeddings in a new file."
1084                )
1085
1086
1087# Helper function for optional external progress bars.
1088def handle_pbar(verbose, pbar_init, pbar_update):
1089    """@private"""
1090
1091    # Noop to provide dummy functions.
1092    def noop(*args):
1093        pass
1094
1095    if verbose and pbar_init is None:  # we are verbose and don't have an external progress bar.
1096        assert pbar_update is None  # avoid inconsistent state of callbacks
1097
1098        # Create our own progress bar and callbacks
1099        pbar = tqdm()
1100
1101        def pbar_init(total, description):
1102            pbar.total = total
1103            pbar.set_description(description)
1104
1105        def pbar_update(update):
1106            pbar.update(update)
1107
1108        def pbar_close():
1109            pbar.close()
1110
1111    elif verbose and pbar_init is not None:  # external pbar -> we don't have to do anything
1112        assert pbar_update is not None
1113        pbar = None
1114        pbar_close = noop
1115
1116    else:  # we are not verbose, do nothing
1117        pbar = None
1118        pbar_init, pbar_update, pbar_close = noop, noop, noop
1119
1120    return pbar, pbar_init, pbar_update, pbar_close
1121
1122
1123def precompute_image_embeddings(
1124    predictor: SamPredictor,
1125    input_: np.ndarray,
1126    save_path: Optional[Union[str, os.PathLike]] = None,
1127    lazy_loading: bool = False,
1128    ndim: Optional[int] = None,
1129    tile_shape: Optional[Tuple[int, int]] = None,
1130    halo: Optional[Tuple[int, int]] = None,
1131    verbose: bool = True,
1132    batch_size: int = 1,
1133    mask: Optional[np.typing.ArrayLike] = None,
1134    pbar_init: Optional[callable] = None,
1135    pbar_update: Optional[callable] = None,
1136) -> ImageEmbeddings:
1137    """Compute the image embeddings (output of the encoder) for the input.
1138
1139    If 'save_path' is given the embeddings will be loaded/saved in a zarr container.
1140
1141    Args:
1142        predictor: The Segment Anything predictor.
1143        input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
1144        save_path: Path to save the embeddings in a zarr container.
1145            By default, set to 'None', i.e. the computed embeddings will not be stored locally.
1146        lazy_loading: Whether to load all embeddings into memory or return an
1147            object to load them on demand when required. This only has an effect if 'save_path' is given
1148            and if the input is 3 dimensional. By default, set to 'False'.
1149        ndim: The dimensionality of the data. If not given will be deduced from the input data.
1150            By default, set to 'None', i.e. will be computed from the provided `input_`.
1151        tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
1152        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
1153        verbose: Whether to be verbose in the computation. By default, set to 'True'.
1154        batch_size: The batch size for precomputing image embeddings over tiles (or planes). By default, set to '1'.
1155        mask: An optional mask to define areas that are ignored in the computation.
1156            The mask will be used within tiled embedding computation and tiles that don't contain any foreground
1157            in the mask will be excluded from the computation. It does not have any effect for non-tiled embeddings.
1158        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
1159            Can be used together with pbar_update to handle a napari progress bar in another thread.
1160            This enables using this function within a thread worker.
1161        pbar_update: Callback to update an external progress bar.
1162
1163    Returns:
1164        The image embeddings.
1165    """
1166    ndim = input_.ndim if ndim is None else ndim
1167
1168    # Handle the embedding save_path.
1169    # We don't have a save path, so we open an in-memory zarr file to hold tiled embeddings.
1170    if save_path is None:
1171        f = zarr.group()
1172
1173    # We have a save path and it already exists. Embeddings will be loaded from it,
1174    # check that the saved embeddings in there match the parameters of the function call.
1175    elif os.path.exists(save_path):
1176        f = zarr.open(save_path, mode="a")
1177        _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo)
1178
1179    # We have a save path and it does not exist yet. Create the zarr file to which the
1180    # embeddings will then be saved.
1181    else:
1182        f = zarr.open(save_path, mode="a")
1183
1184    _, pbar_init, pbar_update, pbar_close = handle_pbar(verbose, pbar_init, pbar_update)
1185
1186    if ndim == 2 and tile_shape is None:
1187        embeddings = _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update)
1188    elif ndim == 2 and tile_shape is not None:
1189        embeddings = _compute_tiled_2d(
1190            input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask=mask
1191        )
1192    elif ndim == 3 and tile_shape is None:
1193        embeddings = _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update, batch_size)
1194    elif ndim == 3 and tile_shape is not None:
1195        embeddings = _compute_tiled_3d(
1196            input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask=mask
1197        )
1198    else:
1199        raise ValueError(f"Invalid dimensionality {ndim}, expect 2 or 3 dimensional data.")
1200
1201    pbar_close()
1202    return embeddings
1203
1204
1205def set_precomputed(
1206    predictor: SamPredictor, image_embeddings: ImageEmbeddings, i: Optional[int] = None, tile_id: Optional[int] = None,
1207) -> SamPredictor:
1208    """Set the precomputed image embeddings for a predictor.
1209
1210    Args:
1211        predictor: The Segment Anything predictor.
1212        image_embeddings: The precomputed image embeddings computed by `precompute_image_embeddings`.
1213        i: Index for the image data. Required if `image` has three spatial dimensions
1214            or a time dimension and two spatial dimensions.
1215        tile_id: Index for the tile. This is required if the embeddings are tiled.
1216
1217    Returns:
1218        The predictor with set features.
1219    """
1220    if tile_id is not None:
1221        tile_features = image_embeddings["features"][str(tile_id)]
1222        tile_image_embeddings = {
1223            "features": tile_features,
1224            "input_size": tile_features.attrs["input_size"],
1225            "original_size": tile_features.attrs["original_size"]
1226        }
1227        return set_precomputed(predictor, tile_image_embeddings, i=i)
1228
1229    device = predictor.device
1230    features = image_embeddings["features"]
1231    assert features.ndim in (4, 5), f"{features.ndim}"
1232    if features.ndim == 5 and i is None:
1233        raise ValueError("The data is 3D so an index i is needed.")
1234    elif features.ndim == 4 and i is not None:
1235        raise ValueError("The data is 2D so an index is not needed.")
1236
1237    if i is None:
1238        predictor.features = features.to(device) if torch.is_tensor(features) else \
1239            torch.from_numpy(features[:]).to(device)
1240    else:
1241        predictor.features = features[i].to(device) if torch.is_tensor(features) else \
1242            torch.from_numpy(features[i]).to(device)
1243
1244    predictor.original_size = image_embeddings["original_size"]
1245    predictor.input_size = image_embeddings["input_size"]
1246    predictor.is_image_set = True
1247
1248    return predictor
1249
1250
1251#
1252# Misc functionality
1253#
1254
1255
1256def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
1257    """Compute the intersection over union of two masks.
1258
1259    Args:
1260        mask1: The first mask.
1261        mask2: The second mask.
1262
1263    Returns:
1264        The intersection over union of the two masks.
1265    """
1266    overlap = np.logical_and(mask1 == 1, mask2 == 1).sum()
1267    union = np.logical_or(mask1 == 1, mask2 == 1).sum()
1268    eps = 1e-7
1269    iou = float(overlap) / (float(union) + eps)
1270    return iou
1271
1272
1273def get_centers_and_bounding_boxes(
1274    segmentation: np.ndarray, mode: str = "v"
1275) -> Tuple[Dict[int, np.ndarray], Dict[int, tuple]]:
1276    """Returns the center coordinates of the foreground instances in the ground-truth.
1277
1278    Args:
1279        segmentation: The segmentation.
1280        mode: Determines the functionality used for computing the centers.
1281            If 'v', the object's eccentricity centers computed by vigra are used.
1282            If 'p' the object's centroids computed by skimage are used.
1283
1284    Returns:
1285        A dictionary that maps object ids to the corresponding centroid.
1286        A dictionary that maps object_ids to the corresponding bounding box.
1287    """
1288    assert mode in ["p", "v"], "Choose either 'p' for regionprops or 'v' for vigra"
1289
1290    properties = regionprops(segmentation)
1291
1292    if mode == "p":
1293        center_coordinates = {prop.label: prop.centroid for prop in properties}
1294    elif mode == "v":
1295        center_coordinates = vigra.filters.eccentricityCenters(segmentation.astype('float32'))
1296        center_coordinates = {i: coord for i, coord in enumerate(center_coordinates) if i > 0}
1297
1298    bbox_coordinates = {prop.label: prop.bbox for prop in properties}
1299
1300    assert len(bbox_coordinates) == len(center_coordinates), f"{len(bbox_coordinates)}, {len(center_coordinates)}"
1301    return center_coordinates, bbox_coordinates
1302
1303
1304def load_image_data(path: str, key: Optional[str] = None, lazy_loading: bool = False) -> np.ndarray:
1305    """Helper function to load image data from file.
1306
1307    Args:
1308        path: The filepath to the image data.
1309        key: The internal filepath for complex data formats like hdf5.
1310        lazy_loading: Whether to lazily load the data. Only supported for n5 and zarr data.
1311
1312    Returns:
1313        The image data.
1314    """
1315    if key is None:
1316        image_data = imageio.imread(path)
1317    else:
1318        with open_file(path, mode="r") as f:
1319            image_data = f[key]
1320            if not lazy_loading:
1321                image_data = image_data[:]
1322
1323    return image_data
1324
1325
1326def segmentation_to_one_hot(segmentation: np.ndarray, segmentation_ids: Optional[np.ndarray] = None) -> torch.Tensor:
1327    """Convert the segmentation to one-hot encoded masks.
1328
1329    Args:
1330        segmentation: The segmentation.
1331        segmentation_ids: Optional subset of ids that will be used to subsample the masks.
1332            By default, computes the number of ids from the provided `segmentation` masks.
1333
1334    Returns:
1335        The one-hot encoded masks.
1336    """
1337    masks = segmentation.copy()
1338    if segmentation_ids is None:
1339        n_ids = int(segmentation.max())
1340
1341    else:
1342        msg = "No foreground objects were found."
1343        if len(segmentation_ids) == 0:  # The list should not be completely empty.
1344            raise RuntimeError(msg)
1345
1346        if 0 in segmentation_ids:  # The list should not have 'zero' as a value.
1347            raise RuntimeError(msg)
1348
1349        # the segmentation ids have to be sorted
1350        segmentation_ids = np.sort(segmentation_ids)
1351
1352        # set the non selected objects to zero and relabel sequentially
1353        masks[~np.isin(masks, segmentation_ids)] = 0
1354        masks = relabel_sequential(masks)[0]
1355        n_ids = len(segmentation_ids)
1356
1357    masks = torch.from_numpy(masks)
1358
1359    one_hot_shape = (n_ids + 1,) + masks.shape
1360    masks = masks.unsqueeze(0)  # add dimension to scatter
1361    masks = torch.zeros(one_hot_shape).scatter_(0, masks, 1)[1:]
1362
1363    # add the extra singleton dimension to get shape NUM_OBJECTS x 1 x H x W
1364    masks = masks.unsqueeze(1)
1365    return masks
1366
1367
1368def get_block_shape(shape: Tuple[int]) -> Tuple[int]:
1369    """Get a suitable block shape for chunking a given shape.
1370
1371    The primary use for this is determining chunk sizes for
1372    zarr arrays or block shapes for parallelization.
1373
1374    Args:
1375        shape: The image or volume shape.
1376
1377    Returns:
1378        The block shape.
1379    """
1380    ndim = len(shape)
1381    if ndim == 2:
1382        block_shape = tuple(min(bs, sh) for bs, sh in zip((1024, 1024), shape))
1383    elif ndim == 3:
1384        block_shape = tuple(min(bs, sh) for bs, sh in zip((32, 256, 256), shape))
1385    else:
1386        raise ValueError(f"Only 2 or 3 dimensional shapes are supported, got {ndim}D.")
1387
1388    return block_shape
1389
1390
1391def micro_sam_info() -> None:
1392    """Display μSAM information using a rich console."""
1393    import psutil
1394    import platform
1395    import argparse
1396    from rich import progress
1397    from rich.panel import Panel
1398    from rich.table import Table
1399    from rich.console import Console
1400
1401    import torch
1402    import micro_sam
1403
1404    parser = argparse.ArgumentParser(description="μSAM Information Booth")
1405    parser.add_argument(
1406        "--download", nargs="+", metavar=("WHAT", "KIND"),
1407        help="Downloads the pretrained SAM models. "
1408        "'--download models' -> downloads all pretrained models; "
1409        "'--download models vit_b_lm vit_b_em_organelles' -> downloads the listed models; "
1410        "'--download model/models vit_b_lm' -> downloads a single specified model."
1411    )
1412    args = parser.parse_args()
1413
1414    # Open up a new console.
1415    console = Console()
1416
1417    # The header for information CLI.
1418    console.print("[bold #0072B2]μSAM Information Booth[/bold #0072B2]", justify="center")
1419    console.print("-" * console.width)
1420
1421    # μSAM version panel.
1422    console.print(
1423        Panel(f"[bold #F0E442]Version:[/bold #F0E442] {micro_sam.__version__}", title="μSAM Version", expand=True)
1424    )
1425
1426    # The documentation link panel.
1427    console.print(
1428        Panel(
1429            "[bold #CC79A7]Tools documented at:[/bold #CC79A7]\n"
1430            "https://computational-cell-analytics.github.io/micro-sam", title="Documentation"
1431        )
1432    )
1433
1434    # The publication panel.
1435    console.print(
1436        Panel(
1437            "[bold #E69F00]Published in Nature Methods:[/bold #E69F00]\n"
1438            "https://www.nature.com/articles/s41592-024-02580-4", title="Publication"
1439        )
1440    )
1441
1442    # Create the cache directory when users run `micro_sam.info`.
1443    cache_dir = get_cache_directory()
1444    os.makedirs(cache_dir, exist_ok=True)
1445
1446    # The cache directory panel.
1447    console.print(
1448        Panel(f"[bold #009E73]Cache Directory:[/bold #009E73]\n{cache_dir}", title="Cache Directory")
1449    )
1450
1451    # We follow a simple versioning logic here for mapping model names to model versions.
1452    available_models = []
1453    for model_name, model_path in models().urls.items():  # We filter out the decoder models.
1454        if model_name.endswith("decoder"):
1455            continue
1456
1457        if "https://dl.fbaipublicfiles.com/segment_anything/" in model_path:  # Valid v1 SAM models.
1458            available_models.append(model_name)
1459
1460        if "https://owncloud.gwdg.de/" in model_path:  # Our own hosted models (often in their v1 version).
1461            if model_name == "vit_t":  # MobileSAM model.
1462                available_models.append(model_name)
1463            else:
1464                available_models.append(f"{model_name} (v1)")
1465
1466        # Now for our models, the BioImageIO ModelZoo upload structure is such that:
1467        # '/1/files' corresponds to v2 models.
1468        # '/1.1/files' corresponds to v3 models.
1469        # '/1.2/files' corresponds to v4 models.
1470        if "/1/files" in model_path:
1471            available_models.append(f"{model_name} (v2)")
1472        if "/1.1/files" in model_path:
1473            available_models.append(f"{model_name} (v3)")
1474        if "/1.2/files" in model_path:
1475            available_models.append(f"{model_name} (v4)")
1476
1477    model_list = "\n".join(available_models)
1478
1479    # The available models panel.
1480    console.print(
1481        Panel(f"[bold #D55E00]Available Models:[/bold #D55E00]\n{model_list}", title="List of Supported Models")
1482    )
1483
1484    # The system information table.
1485    total_memory = psutil.virtual_memory().total / (1024 ** 3)
1486    table = Table(title="System Information", show_header=True, header_style="bold #0072B2", expand=True)
1487    table.add_column("Property")
1488    table.add_column("Value", style="bold #56B4E9")
1489    table.add_row("System", platform.system())
1490    table.add_row("Node Name", platform.node())
1491    table.add_row("Release", platform.release())
1492    table.add_row("Version", platform.version())
1493    table.add_row("Machine", platform.machine())
1494    table.add_row("Processor", platform.processor())
1495    table.add_row("Platform", platform.platform())
1496    table.add_row("Total RAM (GB)", f"{total_memory:.2f}")
1497    console.print(table)
1498
1499    # The device information and check for available GPU acceleration.
1500    default_device = _get_default_device()
1501
1502    if default_device == "cuda":
1503        device_index = torch.cuda.current_device()
1504        device_name = torch.cuda.get_device_name(device_index)
1505        console.print(Panel(f"[bold #000000]CUDA Device:[/bold #000000] {device_name}", title="GPU Information"))
1506    elif default_device == "mps":
1507        console.print(Panel("[bold #000000]MPS Device is available[/bold #000000]", title="GPU Information"))
1508    else:
1509        console.print(
1510            Panel(
1511                "[bold #000000]No GPU acceleration device detected. Running on CPU.[/bold #000000]",
1512                title="Device Information"
1513            )
1514        )
1515
1516    # The section allowing to download models.
1517    # NOTE: In future, can be extended to download sample data.
1518    if args.download:
1519        download_provided_args = [t.lower() for t in args.download]
1520        mode, *model_types = download_provided_args
1521
1522        if mode not in {"models", "model"}:
1523            console.print(f"[red]Unknown option for --download: {mode}[/]")
1524            return
1525
1526        if mode in ["model", "models"] and not model_types:  # If user did not specify, we will download all models.
1527            download_list = available_models
1528        else:
1529            download_list = model_types
1530            incorrect_models = [m for m in download_list if m not in available_models]
1531            if incorrect_models:
1532                console.print(Panel("[red]Unknown model(s):[/] " + ", ".join(incorrect_models), title="Download Error"))
1533                return
1534
1535        with progress.Progress(
1536            progress.SpinnerColumn(),
1537            progress.TextColumn("[progress.description]{task.description}"),
1538            progress.BarColumn(bar_width=None),
1539            "[progress.percentage]{task.percentage:>3.0f}%",
1540            progress.TimeRemainingColumn(),
1541            console=console,
1542        ) as prog:
1543            task = prog.add_task("[green]Downloading μSAM models…", total=len(download_list))
1544            for model_type in download_list:
1545                prog.update(task, description=f"Downloading [cyan]{model_type}[/]…")
1546                _download_sam_model(model_type=model_type)
1547                prog.advance(task)
1548
1549        console.print(Panel("[bold green] Downloads complete![/]", title="Finished"))
1550
1551
1552#
1553# Functionality to convert mask predictions to an instance segmentation via non-maximum suppression.
1554# The functionality for computing NMS for masks is taken from CellSeg1:
1555# https://github.com/Nuisal/cellseg1/blob/1c027c2568b83494d2662d1fbecec9aafb478ee0/mask_nms.py
1556#
1557
1558
1559def _overlap_matrix(boxes):
1560    x1 = torch.max(boxes[:, None, 0], boxes[:, 0])
1561    y1 = torch.max(boxes[:, None, 1], boxes[:, 1])
1562    x2 = torch.min(boxes[:, None, 2], boxes[:, 2])
1563    y2 = torch.min(boxes[:, None, 3], boxes[:, 3])
1564
1565    w = torch.clamp(x2 - x1, min=0)
1566    h = torch.clamp(y2 - y1, min=0)
1567
1568    return (w * h) > 0
1569
1570
1571def _calculate_ious_between_pred_masks(masks, boxes, diagonal_value=1):
1572    n_points = masks.shape[0]
1573    m = torch.zeros((n_points, n_points))
1574
1575    overlap_m = _overlap_matrix(boxes)
1576
1577    for i in range(n_points):
1578        js = torch.where(overlap_m[i])[0]
1579        js_half = js[js > i].to(masks.device)
1580
1581        if len(js_half) > 0:
1582            intersection = torch.logical_and(masks[i], masks[js_half]).sum(dim=(1, 2))
1583            union = torch.logical_or(masks[i], masks[js_half]).sum(dim=(1, 2))
1584            iou = intersection / union
1585            m[i, js_half] = iou
1586
1587    m = m + m.T
1588    m.fill_diagonal_(diagonal_value)
1589    return m
1590
1591
1592def _calculate_iomin_between_pred_masks(masks, boxes, eps=1e-6):
1593    overlap_m = _overlap_matrix(boxes)
1594
1595    # Flatten spatial dimensions: (N, H*W) or (N, D*H*W)
1596    N = masks.shape[0]
1597    masks_flat = masks.reshape(N, -1).float()
1598
1599    # Per-mask area
1600    areas = masks_flat.sum(dim=1)  # (N,)
1601
1602    # Pairwise intersections via matrix multiplication
1603    # inter[i, j] = sum_k masks_flat[i, k] * masks_flat[j, k]
1604    inter = masks_flat @ masks_flat.t()  # (N, N)
1605
1606    # Denominator: min area of the two masks
1607    min_areas = torch.minimum(areas[:, None], areas[None, :])  # (N, N)
1608
1609    # IoMin = intersection / min(area_i, area_j)
1610    iomin = inter / (min_areas + eps)
1611
1612    # Set elements without any overlap explicitly to zero.
1613    iomin[~overlap_m] = 0
1614    return iomin
1615
1616
1617def _batched_mask_nms(masks, boxes, scores, nms_thresh, intersection_over_min):
1618    boxes = (
1619        boxes.detach() if isinstance(boxes, torch.Tensor) else torch.tensor(boxes)
1620    ).cpu()
1621    scores = (
1622        scores.detach() if isinstance(scores, torch.Tensor) else torch.tensor(scores)
1623    ).cpu()
1624    masks = (
1625        masks.detach() if isinstance(masks, torch.Tensor) else torch.tensor(masks)
1626    ).cpu()
1627
1628    if intersection_over_min:
1629        iou_matrix = _calculate_iomin_between_pred_masks(masks, boxes)
1630    else:
1631        iou_matrix = _calculate_ious_between_pred_masks(masks, boxes)
1632    sorted_indices = torch.argsort(scores, descending=True)
1633
1634    keep = []
1635    while len(sorted_indices) > 0:
1636        i = sorted_indices[0]
1637        keep.append(i)
1638
1639        if len(sorted_indices) == 1:
1640            break
1641
1642        iou_values = iou_matrix[i, sorted_indices[1:]]
1643        mask = iou_values <= nms_thresh
1644        sorted_indices = sorted_indices[1:][mask]
1645
1646    return torch.tensor(keep)
1647
1648
1649def mask_data_to_segmentation(
1650    masks: List[Dict[str, Any]],
1651    shape: Optional[Tuple[int, int]] = None,
1652    min_object_size: int = 0,
1653    max_object_size: Optional[int] = None,
1654    label_masks: bool = True,
1655    with_background: bool = False,
1656    merge_exclusively: bool = True,
1657) -> np.ndarray:
1658    """Convert the output of the automatic mask generation to an instance segmentation.
1659
1660    Args:
1661        masks: The outputs generated by `AutomaticMaskGenerator`, other classes from `micro_sam.instance_segmentation`,
1662            or from `micro_sam.inference` functions. Only supported for output_mode=binary_mask.
1663        shape: The shape of the output segmentation. If None, it will be derived from the mask input.
1664            If the masks were predicted with tiling, then the shape must be given.
1665        min_object_size: The minimal size of an object in pixels. By default, set to '0'.
1666        max_object_size: The maximal size of an object in pixels.
1667        label_masks: Whether to apply connected components to the result before removing small objects.
1668            By default, set to 'True'.
1669        with_background: Whether to remove the largest object, which often covers the background for AMG.
1670        merge_exclusively: Whether to exclude previously merged masks from merging.
1671
1672    Returns:
1673        The instance segmentation.
1674    """
1675    masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
1676    if shape is None:
1677        shape = next(iter(masks))["segmentation"].shape
1678    segmentation = np.zeros(shape, dtype="uint32")
1679
1680    def require_numpy(mask):
1681        return mask.cpu().numpy() if torch.is_tensor(mask) else mask
1682
1683    seg_id = 1
1684    for mask_data in masks:
1685        area = mask_data["area"]
1686        if (area < min_object_size) or (max_object_size is not None and area > max_object_size):
1687            continue
1688
1689        this_mask = require_numpy(mask_data["segmentation"])
1690        this_seg_id = mask_data.get("seg_id", seg_id)
1691        if "global_bbox" in mask_data:
1692            bb = mask_data["bbox"]
1693            bb = np.s_[bb[1]:bb[1] + bb[3], bb[0]:bb[0] + bb[2]]
1694            global_bb = mask_data["global_bbox"]
1695            global_bb = np.s_[global_bb[1]:global_bb[1] + global_bb[3], global_bb[0]:global_bb[0] + global_bb[2]]
1696            if merge_exclusively:
1697                this_mask = np.logical_and(this_mask[bb], segmentation[global_bb] == 0)
1698            else:
1699                this_mask = this_mask[bb]
1700            segmentation[global_bb][this_mask] = this_seg_id
1701        else:
1702            if merge_exclusively:
1703                this_mask = np.logical_and(this_mask, segmentation == 0)
1704            segmentation[this_mask] = this_seg_id
1705        seg_id = this_seg_id + 1
1706
1707    block_shape = (512, 512)
1708    if label_masks:
1709        segmentation_cc = np.zeros_like(segmentation, dtype=segmentation.dtype)
1710        segmentation_cc = parallel_impl.label(segmentation, out=segmentation_cc, block_shape=block_shape)
1711        segmentation = segmentation_cc
1712
1713    seg_ids, sizes = parallel_impl.unique(segmentation, return_counts=True, block_shape=block_shape)
1714    filter_ids = seg_ids[sizes < min_object_size]
1715    if with_background:
1716        bg_id = seg_ids[np.argmax(sizes)]
1717        filter_ids = np.concatenate([filter_ids, [bg_id]])
1718
1719    filter_mask = np.zeros(segmentation.shape, dtype="bool")
1720    filter_mask = parallel_impl.isin(segmentation, filter_ids, out=filter_mask, block_shape=block_shape)
1721    segmentation[filter_mask] = 0
1722    segmentation = parallel_impl.relabel_consecutive(segmentation, block_shape=block_shape)[0]
1723
1724    return segmentation
1725
1726
1727def apply_nms(
1728    predictions: List[Dict[str, Any]],
1729    min_size: int,
1730    shape: Optional[Tuple[int, int]] = None,
1731    perform_box_nms: bool = False,
1732    nms_thresh: float = 0.9,
1733    max_size: Optional[int] = None,
1734    intersection_over_min: bool = False,
1735) -> np.ndarray:
1736    """Apply non-maximum suppression to mask predictions from a segment anything model.
1737
1738    Args:
1739        predictions: The mask predictions from SAM.
1740        min_size: The minimum mask size to keep in the output.
1741        shape: The shape of the output segmentation.
1742            Has to be passed for predictions obtained from tiling.
1743        perform_box_nms: Whether to perform NMS on the box coordinates or on the masks.
1744        nms_thresh: The threshold for filtering out objects in NMS.
1745        max_size: The maximum mask size to keep in the output.
1746        intersection_over_min: Whether to compute the overlap as intersection over the minimum mask area
1747            instead of intersection over union.
1748
1749    Returns:
1750        The segmentation obtained from merging the masks left after NMS.
1751    """
1752    data = amg_utils.MaskData(
1753        masks=torch.cat([pred["segmentation"][None] for pred in predictions], dim=0),
1754        iou_preds=torch.tensor([pred["predicted_iou"] for pred in predictions]),
1755    )
1756    data["boxes"] = torch.tensor(np.array([pred["bbox"] for pred in predictions]))
1757    data["area"] = [mask.sum() for mask in data["masks"]]
1758    data["stability_scores"] = torch.tensor([pred["stability_score"] for pred in predictions])
1759
1760    # Check if the input comes with a 'global_bbox' attribute. If it does, then the predictions are from
1761    # a tiled prediction. In this case, we have to take the coordinates w.r.t. the tiling into account.
1762    if "global_bbox" in predictions[0]:
1763        if shape is None:
1764            raise ValueError("The output shape 'shape' has to be passed for tiled predictions.")
1765        data["global_boxes"] = torch.tensor(np.array([pred["global_bbox"] for pred in predictions]))
1766        is_tiled = True
1767    else:
1768        is_tiled = False
1769
1770    if min_size > 0:
1771        keep_by_size = torch.tensor(
1772            [i for i, area in enumerate(data["area"]) if area > min_size], dtype=torch.long,
1773        )
1774        data.filter(keep_by_size)
1775
1776    if max_size is not None:
1777        keep_by_size = torch.tensor([i for i, area in enumerate(data["area"]) if area < max_size])
1778        data.filter(keep_by_size)
1779
1780    scores = data["iou_preds"] * data["stability_scores"]
1781    if perform_box_nms:
1782        assert not intersection_over_min  # not implemented
1783        keep_by_nms = batched_nms(
1784            data["global_boxes"].float() if is_tiled else data["boxes"].float(),
1785            scores,
1786            torch.zeros_like(data["boxes"][:, 0]),  # categories
1787            iou_threshold=nms_thresh,
1788        )
1789    else:
1790        keep_by_nms = _batched_mask_nms(
1791            masks=data["masks"],
1792            boxes=data["global_boxes"].float() if is_tiled else data["boxes"].float(),
1793            scores=scores,
1794            nms_thresh=nms_thresh,
1795            intersection_over_min=intersection_over_min,
1796        )
1797    data.filter(keep_by_nms)
1798
1799    if is_tiled:
1800        mask_data = [
1801            {"segmentation": mask, "area": area, "bbox": box, "global_bbox": global_box}
1802            for mask, area, box, global_box in zip(data["masks"], data["area"], data["boxes"], data["global_boxes"])
1803        ]
1804    else:
1805        mask_data = [
1806            {"segmentation": mask, "area": area, "bbox": box}
1807            for mask, area, box in zip(data["masks"], data["area"], data["boxes"])
1808        ]
1809
1810    if shape is None:
1811        shape = predictions[0]["segmentation"].shape
1812    if mask_data:
1813        segmentation = mask_data_to_segmentation(mask_data, shape=shape, min_object_size=min_size)
1814    else:  # In case all objects have been filtered out due to size filtering.
1815        segmentation = np.zeros(shape, dtype="uint32")
1816
1817    return segmentation
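
The sketch below shows one way to drive `apply_nms` on toy data. It is a minimal example, not part of the module: the `_make_pred` helper and all concrete values are made up, and it only assumes that each prediction dict carries a boolean "segmentation" mask, a "predicted_iou", a "stability_score", and a "bbox" in XYWH order, which are the keys accessed by `apply_nms` above. The surviving masks are merged into an instance segmentation via `mask_data_to_segmentation`.

    import numpy as np
    import torch
    from micro_sam.util import apply_nms

    def _make_pred(y, x, size, iou):
        # Build one square toy mask with the keys expected by apply_nms.
        mask = torch.zeros(64, 64, dtype=torch.bool)
        mask[y:y + size, x:x + size] = True
        return {
            "segmentation": mask,
            "predicted_iou": iou,
            "stability_score": 0.95,
            "bbox": [x, y, size, size],  # XYWH
        }

    # Two strongly overlapping masks with different predicted IoUs.
    predictions = [_make_pred(10, 10, 20, 0.9), _make_pred(12, 12, 20, 0.7)]
    segmentation = apply_nms(predictions, min_size=0, nms_thresh=0.5)
    print(segmentation.shape, np.unique(segmentation))  # the lower-scoring mask should be suppressed
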
def get_cache_directory() -> None:
63def get_cache_directory() -> None:
64    """Get micro-sam cache directory location.
65
66    Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.
67    """
68    default_cache_directory = os.path.expanduser(pooch.os_cache("micro_sam"))
69    cache_directory = Path(os.environ.get("MICROSAM_CACHEDIR", default_cache_directory))
70    return cache_directory

Get micro-sam cache directory location.

Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.

def microsam_cachedir() -> None:
78def microsam_cachedir() -> None:
79    """Return the micro-sam cache directory.
80
81    Returns the top level cache directory for micro-sam models and sample data.
82
83    Every time this function is called, we check for any user updates made to
84    the MICROSAM_CACHEDIR os environment variable since the last time.
85    """
86    cache_directory = os.environ.get("MICROSAM_CACHEDIR") or pooch.os_cache("micro_sam")
87    return cache_directory

Return the micro-sam cache directory.

Returns the top level cache directory for micro-sam models and sample data.

Every time this function is called, we check for any user updates made to the MICROSAM_CACHEDIR os environment variable since the last time.
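
A small usage sketch for the two cache helpers (the custom path below is just an example):

    import os
    from micro_sam.util import get_cache_directory, microsam_cachedir

    # Without MICROSAM_CACHEDIR set, both fall back to pooch's per-OS cache location.
    print(microsam_cachedir())    # e.g. ~/.cache/micro_sam on Linux
    print(get_cache_directory())  # the same location, wrapped in a pathlib.Path

    # Redirect the cache to a custom folder; subsequent calls pick up the change.
    os.environ["MICROSAM_CACHEDIR"] = "/tmp/my_micro_sam_cache"
    print(get_cache_directory())  # /tmp/my_micro_sam_cache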

def models():
 90def models():
 91    """Return the segmentation models registry.
 92
 93    We recreate the model registry every time this function is called,
 94    so any user changes to the default micro-sam cache directory location
 95    are respected.
 96    """
 97
 98    # We use xxhash to compute the hash of the models, see
 99    # https://github.com/computational-cell-analytics/micro-sam/issues/283
100    # (It is now a dependency, so we don't provide the sha256 fallback anymore.)
101    # To generate the xxh128 hash:
102    #     xxh128sum filename
103    encoder_registry = {
104        # The default segment anything models:
105        "vit_l": "xxh128:a82beb3c660661e3dd38d999cc860e9a",
106        "vit_h": "xxh128:97698fac30bd929c2e6d8d8cc15933c2",
107        "vit_b": "xxh128:6923c33df3637b6a922d7682bfc9a86b",
108        # The model with vit tiny backbone from https://github.com/ChaoningZhang/MobileSAM.
109        "vit_t": "xxh128:8eadbc88aeb9d8c7e0b4b60c3db48bd0",
110        # The current version of our models in the modelzoo.
111        # LM generalist models:
112        "vit_l_lm": "xxh128:017f20677997d628426dec80a8018f9d",
113        "vit_b_lm": "xxh128:fe9252a29f3f4ea53c15a06de471e186",
114        "vit_t_lm": "xxh128:72ec5074774761a6e5c05a08942f981e",
115        # EM models:
116        "vit_l_em_organelles": "xxh128:810b084b6e51acdbf760a993d8619f2d",
117        "vit_b_em_organelles": "xxh128:f3bf2ed83d691456bae2c3f9a05fb438",
118        "vit_t_em_organelles": "xxh128:253474720c497cce605e57c9b1d18fd9",
119        # Histopathology models:
120        "vit_b_histopathology": "xxh128:ffd1a2cd84570458b257bd95fdd8f974",
121        "vit_l_histopathology": "xxh128:b591833c89754271023e901281dee3f2",
122        "vit_h_histopathology": "xxh128:bd1856dafc156a43fb3aa705f1a6e92e",
123        # Medical Imaging models:
124        "vit_b_medical_imaging": "xxh128:40169f1e3c03a4b67bff58249c176d92",
125    }
126    # Additional decoders for instance segmentation.
127    decoder_registry = {
128        # LM generalist models:
129        "vit_l_lm_decoder": "xxh128:2faeafa03819dfe03e7c46a44aaac64a",
130        "vit_b_lm_decoder": "xxh128:708b15ac620e235f90bb38612c4929ba",
131        "vit_t_lm_decoder": "xxh128:3e914a5f397b0312cdd36813031f8823",
132        # EM models:
133        "vit_l_em_organelles_decoder": "xxh128:334877640bfdaaabce533e3252a17294",
134        "vit_b_em_organelles_decoder": "xxh128:bb6398956a6b0132c26b631c14f95ce2",
135        "vit_t_em_organelles_decoder": "xxh128:8f897c7bb93174a4d1638827c4dd6f44",
136        # Histopathology models:
137        "vit_b_histopathology_decoder": "xxh128:6a66194dcb6e36199cbee2214ecf7213",
138        "vit_l_histopathology_decoder": "xxh128:46aab7765d4400e039772d5a50b55c04",
139        "vit_h_histopathology_decoder": "xxh128:3ed9f87e46ad5e16935bd8d722c8dc47",
140        # Medical Imaging models:
141        "vit_b_medical_imaging_decoder": "xxh128:9e498b12f526f119b96c88be76e3b2ed",
142    }
143    registry = {**encoder_registry, **decoder_registry}
144
145    encoder_urls = {
146        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
147        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
148        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
149        "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download",
150        "vit_l_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.2/files/vit_l.pt",
151        "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.2/files/vit_b.pt",
152        "vit_t_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t.pt",
153        "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1.2/files/vit_l.pt",  # noqa
154        "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1.2/files/vit_b.pt",  # noqa
155        "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt",  # noqa
156        "vit_b_histopathology": "https://owncloud.gwdg.de/index.php/s/sBB4H8CTmIoBZsQ/download",
157        "vit_l_histopathology": "https://owncloud.gwdg.de/index.php/s/IZgnn1cpBq2PHod/download",
158        "vit_h_histopathology": "https://owncloud.gwdg.de/index.php/s/L7AcvVz7DoWJ2RZ/download",
159        "vit_b_medical_imaging": "https://owncloud.gwdg.de/index.php/s/f5Ol4FrjPQWfjUF/download",
160    }
161
162    decoder_urls = {
163        "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.2/files/vit_l_decoder.pt",  # noqa
164        "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.2/files/vit_b_decoder.pt",  # noqa
165        "vit_t_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t_decoder.pt",  # noqa
166        "vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1.2/files/vit_l_decoder.pt",  # noqa
167        "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1.2/files/vit_b_decoder.pt",  # noqa
168        "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt",  # noqa
169        "vit_b_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/KO9AWqynI7SFOBj/download",
170        "vit_l_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/oIs6VSmkOp7XrKF/download",
171        "vit_h_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/1qAKxy5H0jgwZvM/download",
172        "vit_b_medical_imaging_decoder": "https://owncloud.gwdg.de/index.php/s/ahd3ZhZl2e0RIwz/download",
173    }
174    urls = {**encoder_urls, **decoder_urls}
175
176    models = pooch.create(
177        path=os.path.join(microsam_cachedir(), "models"),
178        base_url="",
179        registry=registry,
180        urls=urls,
181    )
182    return models

Return the segmentation models registry.

We recreate the model registry every time this function is called, so any user changes to the default micro-sam cache directory location are respected.
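
Since the returned registry is a pooch.Pooch instance, the standard pooch attributes can be used to inspect it. A short sketch (note that fetch downloads the checkpoint on first use):

    from micro_sam.util import models

    registry = models()

    # Registered model names (encoders and the corresponding '*_decoder' entries).
    print(sorted(registry.registry.keys())[:5])

    # The download URL registered for a given model.
    print(registry.urls["vit_b_lm"])

    # Download the checkpoint to the cache (if needed) and return its local path.
    path = registry.fetch("vit_b_lm")
    print(path)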

def get_device( device: Union[str, torch.device, NoneType] = None) -> Union[str, torch.device]:
204def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
205    """Get the torch device.
206
207    If no device is passed, the default device for your system is used.
208    Otherwise, it is checked whether the device you have passed is supported.
209
210    Args:
211        device: The input device. By default, selects the best available supported device.
212
213    Returns:
214        The device.
215    """
216    if device is None or device == "auto":
217        device = _get_default_device()
218    else:
219        device_type = device if isinstance(device, str) else device.type
220        if device_type.lower() == "cuda":
221            if not torch.cuda.is_available():
222                raise RuntimeError("PyTorch CUDA backend is not available.")
223        elif device_type.lower() == "mps":
224            if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
225                raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
226        elif device_type.lower() == "cpu":
227            pass  # cpu is always available
228        else:
229            raise RuntimeError(f"Unsupported device: '{device}'. Please choose from 'cpu', 'cuda', or 'mps'.")
230
231    return device

Get the torch device.

If no device is passed, the default device for your system is used. Otherwise, it is checked whether the device you have passed is supported.

Arguments:
  • device: The input device. By default, selects the best available supported device.
Returns:

The device.
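
A brief sketch of the accepted inputs:

    from micro_sam.util import get_device

    device = get_device()       # picks "cuda", "mps" or "cpu", depending on what is available
    device = get_device("cpu")  # the CPU backend is always accepted

    try:
        device = get_device("cuda")  # raises a RuntimeError if the CUDA backend is not available
    except RuntimeError as e:
        print(e)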

def get_sam_model( model_type: str = 'vit_b_lm', device: Union[str, torch.device, NoneType] = None, checkpoint_path: Union[str, os.PathLike, NoneType] = None, return_sam: bool = False, return_state: bool = False, peft_kwargs: Optional[Dict] = None, flexible_load_checkpoint: bool = False, progress_bar_factory: Optional[Callable] = None, **model_kwargs) -> mobile_sam.predictor.SamPredictor:
318def get_sam_model(
319    model_type: str = _DEFAULT_MODEL,
320    device: Optional[Union[str, torch.device]] = None,
321    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
322    return_sam: bool = False,
323    return_state: bool = False,
324    peft_kwargs: Optional[Dict] = None,
325    flexible_load_checkpoint: bool = False,
326    progress_bar_factory: Optional[Callable] = None,
327    **model_kwargs,
328) -> SamPredictor:
329    r"""Get the Segment Anything Predictor.
330
331    This function will download the required model or load it from the cached weight file.
332    This location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR.
333    The name of the requested model can be set via `model_type`.
334    See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models
335    for an overview of the available models
336
337    Alternatively this function can also load a model from weights stored in a local filepath.
338    The corresponding file path is given via `checkpoint_path`. In this case `model_type`
339    must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for
340    a SAM model with vit_b encoder.
341
342    By default the models are downloaded to a folder named 'micro_sam/models'
343    inside your default cache directory, eg:
344    * Mac: ~/Library/Caches/<AppName>
345    * Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined.
346    * Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache
347    See the pooch.os_cache() documentation for more details:
348    https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html
349
350    Args:
351        model_type: The Segment Anything model to use. Will use the 'vit_b_lm' model by default.
352            To get a list of all available model names you can call `micro_sam.util.get_model_names`.
353        device: The device for the model. If 'None' is provided, will use GPU if available.
354        checkpoint_path: The path to a file with weights that should be used instead of using the
355            weights corresponding to `model_type`. If given, `model_type` must match the architecture
356            corresponding to the weight file. e.g. if you use weights for SAM with `vit_b` encoder
357            then `model_type` must be given as 'vit_b'.
358        return_sam: Return the sam model object as well as the predictor. By default, set to 'False'.
359        return_state: Return the unpickled checkpoint state. By default, set to 'False'.
360        peft_kwargs: Keyword arguments for the PEFT wrapper class.
361            If passed 'None', it does not initialize any parameter efficient finetuning.
362        flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
363            By default, set to 'False'.
364        progress_bar_factory: A function to create a progress bar for the model download.
365        model_kwargs: Additional parameters necessary to initialize the Segment Anything model.
366
367    Returns:
368        The Segment Anything predictor.
369    """
370    device = get_device(device)
371
372    # We support passing a local filepath to a checkpoint.
373    # In this case we do not download any weights but just use the local weight file,
374    # as it is, without copying it anywhere or checking its hashes.
375
376    # checkpoint_path has not been passed, we download a known model and derive the correct
377    # URL from the model_type. If the model_type is invalid pooch will raise an error.
378    _provided_checkpoint_path = checkpoint_path is not None
379    if checkpoint_path is None:
380        checkpoint_path, model_hash, decoder_path = _download_sam_model(model_type, progress_bar_factory)
381
382    # checkpoint_path has been passed, we use it instead of downloading a model.
383    else:
384        # Check if the file exists and raise an error otherwise.
385        # We can't check any hashes here, and we don't check if the file is actually a valid weight file.
386        # (If it isn't the model creation will fail below.)
387        if not os.path.exists(checkpoint_path):
388            raise ValueError(f"Checkpoint at '{checkpoint_path}' could not be found.")
389        model_hash = _compute_hash(checkpoint_path)
390        decoder_path = None
391
392    # Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped
393    # before calling sam_model_registry.
394    abbreviated_model_type = model_type[:5]
395    if abbreviated_model_type not in _MODEL_TYPES:
396        raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}")
397    if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT:
398        raise RuntimeError(
399            "'mobile_sam' is required for the vit-tiny. "
400            "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'"
401        )
402
403    state, model_state = _load_checkpoint(checkpoint_path)
404
405    if _provided_checkpoint_path:
406        # To get the model weights, we prioritize having the correct 'checkpoint_path' over 'model_type'
407        # This is done to avoid parameter mismatch issues caused by an incompatible combination of model type and weights.
408        from micro_sam.models.build_sam import _validate_model_type
409        _provided_model_type = _validate_model_type(model_state)
410
411        # Verify whether the 'abbreviated_model_type' matches the '_provided_model_type'
412        # Otherwise, replace 'abbreviated_model_type' with the latter.
413        if abbreviated_model_type != _provided_model_type:
414            # Printing the message below to avoid any filtering of warnings on user's end.
415            print(
416                f"CRITICAL WARNING: The chosen 'model_type' is '{abbreviated_model_type}', "
417                f"however the provided model checkpoint corresponds to '{_provided_model_type}', which does not match. "
418                f"We internally switch the model type to the expected value, i.e. '{_provided_model_type}'. "
419                "However, please avoid mismatching combinations of 'model_type' and 'checkpoint_path' in the future."
420            )
421
422        # Replace the extracted 'abbreviated_model_type' according to the model weights.
423        abbreviated_model_type = _provided_model_type
424
425    # Whether to update parameters necessary to initialize the model
426    if model_kwargs:  # Checks whether model_kwargs have been provided or not
427        if abbreviated_model_type == "vit_t":
428            raise ValueError("'micro-sam' does not support changing the model parameters for 'mobile-sam'.")
429        sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs)
430
431    else:
432        sam = sam_model_registry[abbreviated_model_type]()
433
434    # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
435    # Overwrites the SAM model by freezing the backbone and allowing PEFT.
436    if peft_kwargs and isinstance(peft_kwargs, dict):
437        # NOTE: We bump out 'quantize' parameter, if found, as we do not quantize in inference.
438        peft_kwargs.pop("quantize", None)
439
440        if abbreviated_model_type == "vit_t":
441            raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")
442
443        sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam
444    # In case the model checkpoints have some issues when it is initialized with different parameters than default.
445    if flexible_load_checkpoint:
446        sam = _handle_checkpoint_loading(sam, model_state)
447    else:
448        sam.load_state_dict(model_state)
449    sam.to(device=device)
450
451    predictor = SamPredictor(sam)
452    predictor.model_type = abbreviated_model_type
453    predictor._hash = model_hash
454    predictor.model_name = model_type
455    predictor.checkpoint_path = checkpoint_path
456
457    # Add the decoder to the state if we have one and if the state is returned.
458    if decoder_path is not None and return_state:
459        state["decoder_state"] = torch.load(decoder_path, map_location=device, weights_only=False)
460
461    if return_sam and return_state:
462        return predictor, sam, state
463    if return_sam:
464        return predictor, sam
465    if return_state:
466        return predictor, state
467    return predictor

Get the Segment Anything Predictor.

This function will download the required model or load it from the cached weight file. This location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR. The name of the requested model can be set via model_type. See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models for an overview of the available models

Alternatively this function can also load a model from weights stored in a local filepath. The corresponding file path is given via checkpoint_path. In this case model_type must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for a SAM model with vit_b encoder.

By default the models are downloaded to a folder named 'micro_sam/models' inside your default cache directory, e.g.:
  • Mac: ~/Library/Caches/<AppName>
  • Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined.
  • Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache
See the pooch.os_cache() documentation for more details: https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html

Arguments:
  • model_type: The Segment Anything model to use. Will use the 'vit_b_lm' model by default. To get a list of all available model names you can call get_model_names.
  • device: The device for the model. If 'None' is provided, will use GPU if available.
  • checkpoint_path: The path to a file with weights that should be used instead of using the weights corresponding to model_type. If given, model_type must match the architecture corresponding to the weight file. e.g. if you use weights for SAM with vit_b encoder then model_type must be given as 'vit_b'.
  • return_sam: Return the sam model object as well as the predictor. By default, set to 'False'.
  • return_state: Return the unpickled checkpoint state. By default, set to 'False'.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class. If passed 'None', it does not initialize any parameter efficient finetuning.
  • flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. By default, set to 'False'.
  • progress_bar_factory: A function to create a progress bar for the model download.
  • model_kwargs: Additional parameters necessary to initialize the Segment Anything model.
Returns:

The Segment Anything predictor.
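
A minimal usage sketch, assuming the standard SamPredictor interface (set_image / predict); the random image and the point prompt are placeholders for real data:

    import numpy as np
    from micro_sam.util import get_sam_model

    # Downloads (or loads from the cache) the default 'vit_b_lm' model.
    predictor = get_sam_model(model_type="vit_b_lm")

    # A random RGB image as a stand-in for real microscopy data.
    image = np.random.randint(0, 255, size=(512, 512, 3), dtype="uint8")

    # Compute the image embeddings, then prompt the model with a single positive point.
    predictor.set_image(image)
    masks, scores, _ = predictor.predict(
        point_coords=np.array([[256, 256]]),  # (x, y) coordinates of the prompt
        point_labels=np.array([1]),           # 1 = positive point
        multimask_output=True,
    )
    print(masks.shape, scores)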

def export_custom_sam_model( checkpoint_path: Union[str, os.PathLike], model_type: str, save_path: Union[str, os.PathLike], with_segmentation_decoder: bool = False, prefix: str = 'sam.') -> None:
503def export_custom_sam_model(
504    checkpoint_path: Union[str, os.PathLike],
505    model_type: str,
506    save_path: Union[str, os.PathLike],
507    with_segmentation_decoder: bool = False,
508    prefix: str = "sam.",
509) -> None:
510    """Export a finetuned Segment Anything Model to the standard model format.
511
512    The exported model can be used by the interactive annotation tools in `micro_sam.annotator`.
513
514    Args:
515        checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
516        model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
517        save_path: Where to save the exported model.
518        with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well.
519            If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'.
520        prefix: The prefix to remove from the model parameter keys.
521    """
522    state, model_state = _load_checkpoint(checkpoint_path=checkpoint_path)
523    model_state = OrderedDict([(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()])
524
525    # Store the 'decoder_state' as well, if desired.
526    if with_segmentation_decoder:
527        if "decoder_state" not in state:
528            raise RuntimeError(f"'decoder_state' is not found in the model at '{checkpoint_path}'.")
529        decoder_state = state["decoder_state"]
530        save_state = {"model_state": model_state, "decoder_state": decoder_state}
531    else:
532        save_state = model_state
533
534    torch.save(save_state, save_path)

Export a finetuned Segment Anything Model to the standard model format.

The exported model can be used by the interactive annotation tools in micro_sam.annotator.

Arguments:
  • checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
  • model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
  • save_path: Where to save the exported model.
  • with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well. If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'.
  • prefix: The prefix to remove from the model parameter keys.
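
A hedged sketch of a typical export; the checkpoint and output paths are placeholders for your own training results:

    from micro_sam.util import export_custom_sam_model, get_sam_model

    export_custom_sam_model(
        checkpoint_path="checkpoints/vit_b_finetuned/best.pt",  # placeholder path
        model_type="vit_b",
        save_path="vit_b_finetuned_exported.pt",
        with_segmentation_decoder=True,  # also keep the instance segmentation decoder state
    )

    # The exported file can be used again within micro_sam, e.g.:
    # predictor = get_sam_model(model_type="vit_b", checkpoint_path="vit_b_finetuned_exported.pt")
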
def export_custom_qlora_model( checkpoint_path: Union[str, os.PathLike, NoneType], finetuned_path: Union[str, os.PathLike], model_type: str, save_path: Union[str, os.PathLike]) -> None:
537def export_custom_qlora_model(
538    checkpoint_path: Optional[Union[str, os.PathLike]],
539    finetuned_path: Union[str, os.PathLike],
540    model_type: str,
541    save_path: Union[str, os.PathLike],
542) -> None:
543    """Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format.
544
545    The exported model can be used with the LoRA backbone by passing the relevant `peft_kwargs` to `get_sam_model`.
546
547    Args:
548        checkpoint_path: The path to the base foundation model from which the new model has been finetuned.
549        finetuned_path: The path to the new finetuned model, using QLoRA.
550        model_type: The Segment Anything Model type corresponding to the checkpoint.
551        save_path: Where to save the exported model.
552    """
553    # Step 1: Get the base SAM model: used to start finetuning from.
554    _, sam = get_sam_model(
555        model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True,
556    )
557
558    # Step 2: Load the QLoRA-style finetuned model.
559    ft_state, ft_model_state = _load_checkpoint(finetuned_path)
560
561    # Step 3: Identify LoRA layers from QLoRA model.
562    # - differentiate between LoRA applied to the attention matrices and LoRA applied to the MLP layers.
563    # - then copy the LoRA layers from the QLoRA model to the new state dict
564    updated_model_state = {}
565
566    modified_attn_layers = set()
567    modified_mlp_layers = set()
568
569    for k, v in ft_model_state.items():
570        if "blocks." in k:
571            layer_id = int(k.split("blocks.")[1].split(".")[0])
572        if k.find("qkv.w_a_linear") != -1 or k.find("qkv.w_b_linear") != -1:
573            modified_attn_layers.add(layer_id)
574            updated_model_state[k] = v
575        if k.find("mlp.w_a_linear") != -1 or k.find("mlp.w_b_linear") != -1:
576            modified_mlp_layers.add(layer_id)
577            updated_model_state[k] = v
578
579    # Step 4: Next, we get all the remaining parameters from the base SAM model.
580    for k, v in sam.state_dict().items():
581        if "blocks." in k:
582            layer_id = int(k.split("blocks.")[1].split(".")[0])
583        if k.find("attn.qkv.") != -1:
584            if layer_id in modified_attn_layers:  # We have LoRA in QKV layers, so we need to modify the key
585                k = k.replace("qkv", "qkv.qkv_proj")
586        elif k.find("mlp") != -1 and k.find("image_encoder") != -1:
587            if layer_id in modified_mlp_layers:  # We have LoRA in MLP layers, so we need to modify the key
588                k = k.replace("mlp.", "mlp.mlp_layer.")
589        updated_model_state[k] = v
590
591    # Step 5: Finally, we replace the old model state with the new one (to retain other relevant stuff)
592    ft_state['model_state'] = updated_model_state
593
594    # Step 6: Store the new "state" to "save_path"
595    torch.save(ft_state, save_path)

Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format.

The exported model can be used with the LoRA backbone by passing the relevant peft_kwargs to get_sam_model.

Arguments:
  • checkpoint_path: The path to the base foundation model from which the new model has been finetuned.
  • finetuned_path: The path to the new finetuned model, using QLoRA.
  • model_type: The Segment Anything Model type corresponding to the checkpoint.
  • save_path: Where to save the exported model.
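
A sketch of the intended workflow; the paths are placeholders, and the 'rank' keyword passed via peft_kwargs is an assumption that has to match both the PEFT_Sam signature and the rank used during finetuning:

    from micro_sam.util import export_custom_qlora_model, get_sam_model

    export_custom_qlora_model(
        checkpoint_path=None,                              # start from the default weights for this model type
        finetuned_path="checkpoints/qlora_vit_b/best.pt",  # placeholder path to the QLoRA result
        model_type="vit_b_lm",
        save_path="vit_b_lm_lora.pt",
    )

    # Load the exported checkpoint with a LoRA backbone (rank value assumed, see above).
    predictor = get_sam_model(
        model_type="vit_b_lm", checkpoint_path="vit_b_lm_lora.pt", peft_kwargs={"rank": 4},
    )
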
def get_model_names() -> Iterable:
598def get_model_names() -> Iterable:
599    model_registry = models()
600    model_names = model_registry.registry.keys()
601    return model_names
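
For example, to print all registered names (encoders and '*_decoder' entries alike):

    from micro_sam.util import get_model_names

    print(sorted(get_model_names()))
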
def precompute_image_embeddings( predictor: mobile_sam.predictor.SamPredictor, input_: numpy.ndarray, save_path: Union[str, os.PathLike, NoneType] = None, lazy_loading: bool = False, ndim: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = True, batch_size: int = 1, mask: Optional[numpy.typing.ArrayLike] = None, pbar_init: Optional[Callable] = None, pbar_update: Optional[Callable] = None) -> Dict[str, Any]:
1124def precompute_image_embeddings(
1125    predictor: SamPredictor,
1126    input_: np.ndarray,
1127    save_path: Optional[Union[str, os.PathLike]] = None,
1128    lazy_loading: bool = False,
1129    ndim: Optional[int] = None,
1130    tile_shape: Optional[Tuple[int, int]] = None,
1131    halo: Optional[Tuple[int, int]] = None,
1132    verbose: bool = True,
1133    batch_size: int = 1,
1134    mask: Optional[np.typing.ArrayLike] = None,
1135    pbar_init: Optional[callable] = None,
1136    pbar_update: Optional[callable] = None,
1137) -> ImageEmbeddings:
1138    """Compute the image embeddings (output of the encoder) for the input.
1139
1140    If 'save_path' is given the embeddings will be loaded/saved in a zarr container.
1141
1142    Args:
1143        predictor: The Segment Anything predictor.
1144        input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
1145        save_path: Path to save the embeddings in a zarr container.
1146            By default, set to 'None', i.e. the computed embeddings will not be stored locally.
1147        lazy_loading: Whether to load all embeddings into memory or return an
1148            object to load them on demand when required. This only has an effect if 'save_path' is given
1149            and if the input is 3 dimensional. By default, set to 'False'.
1150        ndim: The dimensionality of the data. If not given will be deduced from the input data.
1151            By default, set to 'None', i.e. will be computed from the provided `input_`.
1152        tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
1153        halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
1154        verbose: Whether to be verbose in the computation. By default, set to 'True'.
1155        batch_size: The batch size for precomputing image embeddings over tiles (or planes). By default, set to '1'.
1156        mask: An optional mask to define areas that are ignored in the computation.
1157            The mask will be used within tiled embedding computation and tiles that don't contain any foreground
1158            in the mask will be excluded from the computation. It does not have any effect for non-tiled embeddings.
1159        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
1160            Can be used together with pbar_update to handle a napari progress bar in another thread.
1161            This enables using this function within a threadworker.
1162        pbar_update: Callback to update an external progress bar.
1163
1164    Returns:
1165        The image embeddings.
1166    """
1167    ndim = input_.ndim if ndim is None else ndim
1168
1169    # Handle the embedding save_path.
1170    # We don't have a save path, so we open an in-memory zarr group to hold the embeddings.
1171    if save_path is None:
1172        f = zarr.group()
1173
1174    # We have a save path and it already exists. Embeddings will be loaded from it,
1175    # check that the saved embeddings in there match the parameters of the function call.
1176    elif os.path.exists(save_path):
1177        f = zarr.open(save_path, mode="a")
1178        _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo)
1179
1180    # We have a save path and it does not exist yet. Create the zarr file to which the
1181    # embeddings will then be saved.
1182    else:
1183        f = zarr.open(save_path, mode="a")
1184
1185    _, pbar_init, pbar_update, pbar_close = handle_pbar(verbose, pbar_init, pbar_update)
1186
1187    if ndim == 2 and tile_shape is None:
1188        embeddings = _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update)
1189    elif ndim == 2 and tile_shape is not None:
1190        embeddings = _compute_tiled_2d(
1191            input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask=mask
1192        )
1193    elif ndim == 3 and tile_shape is None:
1194        embeddings = _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update, batch_size)
1195    elif ndim == 3 and tile_shape is not None:
1196        embeddings = _compute_tiled_3d(
1197            input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size, mask=mask
1198        )
1199    else:
1200        raise ValueError(f"Invalid dimensionality {input_.ndim}, expect 2 or 3 dim data.")
1201
1202    pbar_close()
1203    return embeddings

Compute the image embeddings (output of the encoder) for the input.

If 'save_path' is given the embeddings will be loaded/saved in a zarr container.

Arguments:
  • predictor: The Segment Anything predictor.
  • input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
  • save_path: Path to save the embeddings in a zarr container. By default, set to 'None', i.e. the computed embeddings will not be stored locally.
  • lazy_loading: Whether to load all embeddings into memory or return an object to load them on demand when required. This only has an effect if 'save_path' is given and if the input is 3 dimensional. By default, set to 'False'.
  • ndim: The dimensionality of the data. If not given will be deduced from the input data. By default, set to 'None', i.e. will be computed from the provided input_.
  • tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
  • halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
  • verbose: Whether to be verbose in the computation. By default, set to 'True'.
  • batch_size: The batch size for precomputing image embeddings over tiles (or planes). By default, set to '1'.
  • mask: An optional mask to define areas that are ignored in the computation. The mask will be used within tiled embedding computation and tiles that don't contain any foreground in the mask will be excluded from the computation. It does not have any effect for non-tiled embeddings.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle a napari progress bar in another thread. This enables using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
Returns:

The image embeddings.
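
A minimal usage sketch; the image path and tile sizes are placeholders:

    import imageio.v3 as imageio
    from micro_sam.util import get_sam_model, precompute_image_embeddings

    predictor = get_sam_model(model_type="vit_b_lm")
    image = imageio.imread("image.tif")  # placeholder 2D image

    # Compute tiled embeddings and cache them in a zarr container;
    # subsequent calls with the same save_path load them instead of recomputing.
    image_embeddings = precompute_image_embeddings(
        predictor, image,
        save_path="embeddings.zarr",
        tile_shape=(1024, 1024), halo=(256, 256),
    )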

def set_precomputed( predictor: mobile_sam.predictor.SamPredictor, image_embeddings: Dict[str, Any], i: Optional[int] = None, tile_id: Optional[int] = None) -> mobile_sam.predictor.SamPredictor:
1206def set_precomputed(
1207    predictor: SamPredictor, image_embeddings: ImageEmbeddings, i: Optional[int] = None, tile_id: Optional[int] = None,
1208) -> SamPredictor:
1209    """Set the precomputed image embeddings for a predictor.
1210
1211    Args:
1212        predictor: The Segment Anything predictor.
1213        image_embeddings: The precomputed image embeddings computed by `precompute_image_embeddings`.
1214        i: Index for the image data. Required if the image data has three spatial dimensions
1215            or a time dimension and two spatial dimensions.
1216        tile_id: Index for the tile. This is required if the embeddings are tiled.
1217
1218    Returns:
1219        The predictor with set features.
1220    """
1221    if tile_id is not None:
1222        tile_features = image_embeddings["features"][str(tile_id)]
1223        tile_image_embeddings = {
1224            "features": tile_features,
1225            "input_size": tile_features.attrs["input_size"],
1226            "original_size": tile_features.attrs["original_size"]
1227        }
1228        return set_precomputed(predictor, tile_image_embeddings, i=i)
1229
1230    device = predictor.device
1231    features = image_embeddings["features"]
1232    assert features.ndim in (4, 5), f"{features.ndim}"
1233    if features.ndim == 5 and i is None:
1234        raise ValueError("The data is 3D so an index i is needed.")
1235    elif features.ndim == 4 and i is not None:
1236        raise ValueError("The data is 2D so an index is not needed.")
1237
1238    if i is None:
1239        predictor.features = features.to(device) if torch.is_tensor(features) else \
1240            torch.from_numpy(features[:]).to(device)
1241    else:
1242        predictor.features = features[i].to(device) if torch.is_tensor(features) else \
1243            torch.from_numpy(features[i]).to(device)
1244
1245    predictor.original_size = image_embeddings["original_size"]
1246    predictor.input_size = image_embeddings["input_size"]
1247    predictor.is_image_set = True
1248
1249    return predictor

Set the precomputed image embeddings for a predictor.

Arguments:
  • predictor: The Segment Anything predictor.
  • image_embeddings: The precomputed image embeddings computed by precompute_image_embeddings.
  • i: Index for the image data. Required if the image data has three spatial dimensions or a time dimension and two spatial dimensions.
  • tile_id: Index for the tile. This is required if the embeddings are tiled.
Returns:

The predictor with set features.
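
A minimal sketch for reusing precomputed embeddings, using placeholder data; a 3D input requires the index i, and tiled embeddings additionally require tile_id:

    import numpy as np
    from micro_sam.util import get_sam_model, precompute_image_embeddings, set_precomputed

    predictor = get_sam_model(model_type="vit_b")
    volume = np.random.rand(4, 512, 512).astype("float32")  # placeholder volume
    embeddings = precompute_image_embeddings(predictor, volume, ndim=3)

    # Select the embeddings for plane 0 and run interactive prediction with a point prompt.
    predictor = set_precomputed(predictor, embeddings, i=0)
    masks, scores, _ = predictor.predict(
        point_coords=np.array([[256, 256]]), point_labels=np.array([1]),
    )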

def compute_iou(mask1: numpy.ndarray, mask2: numpy.ndarray) -> float:
1257def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
1258    """Compute the intersection over union of two masks.
1259
1260    Args:
1261        mask1: The first mask.
1262        mask2: The second mask.
1263
1264    Returns:
1265        The intersection over union of the two masks.
1266    """
1267    overlap = np.logical_and(mask1 == 1, mask2 == 1).sum()
1268    union = np.logical_or(mask1 == 1, mask2 == 1).sum()
1269    eps = 1e-7
1270    iou = float(overlap) / (float(union) + eps)
1271    return iou

Compute the intersection over union of two masks.

Arguments:
  • mask1: The first mask.
  • mask2: The second mask.
Returns:

The intersection over union of the two masks.
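
A toy example:

    import numpy as np
    from micro_sam.util import compute_iou

    mask1 = np.zeros((10, 10), dtype="uint8"); mask1[2:6, 2:6] = 1
    mask2 = np.zeros((10, 10), dtype="uint8"); mask2[4:8, 4:8] = 1
    print(compute_iou(mask1, mask2))  # 4 overlapping pixels / 28 in the union ~ 0.14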

def get_centers_and_bounding_boxes( segmentation: numpy.ndarray, mode: str = 'v') -> Tuple[Dict[int, numpy.ndarray], Dict[int, tuple]]:
1274def get_centers_and_bounding_boxes(
1275    segmentation: np.ndarray, mode: str = "v"
1276) -> Tuple[Dict[int, np.ndarray], Dict[int, tuple]]:
1277    """Returns the center coordinates and bounding boxes of the foreground instances in the segmentation.
1278
1279    Args:
1280        segmentation: The segmentation.
1281        mode: Determines the functionality used for computing the centers.
1282            If 'v', the object's eccentricity centers computed by vigra are used.
1283            If 'p' the object's centroids computed by skimage are used.
1284
1285    Returns:
1286        A dictionary that maps object ids to the corresponding centroid.
1287        A dictionary that maps object_ids to the corresponding bounding box.
1288    """
1289    assert mode in ["p", "v"], "Choose either 'p' for regionprops or 'v' for vigra"
1290
1291    properties = regionprops(segmentation)
1292
1293    if mode == "p":
1294        center_coordinates = {prop.label: prop.centroid for prop in properties}
1295    elif mode == "v":
1296        center_coordinates = vigra.filters.eccentricityCenters(segmentation.astype('float32'))
1297        center_coordinates = {i: coord for i, coord in enumerate(center_coordinates) if i > 0}
1298
1299    bbox_coordinates = {prop.label: prop.bbox for prop in properties}
1300
1301    assert len(bbox_coordinates) == len(center_coordinates), f"{len(bbox_coordinates)}, {len(center_coordinates)}"
1302    return center_coordinates, bbox_coordinates

Returns the center coordinates and bounding boxes of the foreground instances in the segmentation.

Arguments:
  • segmentation: The segmentation.
  • mode: Determines the functionality used for computing the centers. If 'v', the object's eccentricity centers computed by vigra are used. If 'p' the object's centroids computed by skimage are used.
Returns:

A dictionary that maps object ids to the corresponding centroid. A dictionary that maps object_ids to the corresponding bounding box.
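
A toy example with two labeled objects:

    import numpy as np
    from micro_sam.util import get_centers_and_bounding_boxes

    segmentation = np.zeros((64, 64), dtype="uint32")
    segmentation[5:15, 5:15] = 1
    segmentation[30:50, 30:40] = 2

    centers, boxes = get_centers_and_bounding_boxes(segmentation, mode="p")  # 'p': skimage centroids
    print(centers[1])  # (9.5, 9.5)
    print(boxes[1])    # (5, 5, 15, 15)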

def load_image_data( path: str, key: Optional[str] = None, lazy_loading: bool = False) -> numpy.ndarray:
1305def load_image_data(path: str, key: Optional[str] = None, lazy_loading: bool = False) -> np.ndarray:
1306    """Helper function to load image data from file.
1307
1308    Args:
1309        path: The filepath to the image data.
1310        key: The internal filepath for complex data formats like hdf5.
1311        lazy_loading: Whether to lazily load the data. Only supported for n5 and zarr data.
1312
1313    Returns:
1314        The image data.
1315    """
1316    if key is None:
1317        image_data = imageio.imread(path)
1318    else:
1319        with open_file(path, mode="r") as f:
1320            image_data = f[key]
1321            if not lazy_loading:
1322                image_data = image_data[:]
1323
1324    return image_data

Helper function to load image data from file.

Arguments:
  • path: The filepath to the image data.
  • key: The internal filepath for complex data formats like hdf5.
  • lazy_loading: Whether to lazily load the data. Only supported for n5 and zarr data.
Returns:

The image data.
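
A minimal sketch with placeholder paths:

    from micro_sam.util import load_image_data

    # Regular image formats are read with imageio.
    image = load_image_data("image.tif")

    # Container formats like hdf5, n5 or zarr need the internal key;
    # with lazy_loading=True the dataset is returned without loading it into memory.
    volume = load_image_data("data.zarr", key="raw", lazy_loading=True)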

def segmentation_to_one_hot( segmentation: numpy.ndarray, segmentation_ids: Optional[numpy.ndarray] = None) -> torch.Tensor:
1327def segmentation_to_one_hot(segmentation: np.ndarray, segmentation_ids: Optional[np.ndarray] = None) -> torch.Tensor:
1328    """Convert the segmentation to one-hot encoded masks.
1329
1330    Args:
1331        segmentation: The segmentation.
1332        segmentation_ids: Optional subset of ids that will be used to subsample the masks.
1333            By default, computes the number of ids from the provided `segmentation` masks.
1334
1335    Returns:
1336        The one-hot encoded masks.
1337    """
1338    masks = segmentation.copy()
1339    if segmentation_ids is None:
1340        n_ids = int(segmentation.max())
1341
1342    else:
1343        msg = "No foreground objects were found."
1344        if len(segmentation_ids) == 0:  # The list should not be completely empty.
1345            raise RuntimeError(msg)
1346
1347        if 0 in segmentation_ids:  # The list should not have 'zero' as a value.
1348            raise RuntimeError(msg)
1349
1350        # the segmentation ids have to be sorted
1351        segmentation_ids = np.sort(segmentation_ids)
1352
1353        # set the non selected objects to zero and relabel sequentially
1354        masks[~np.isin(masks, segmentation_ids)] = 0
1355        masks = relabel_sequential(masks)[0]
1356        n_ids = len(segmentation_ids)
1357
1358    masks = torch.from_numpy(masks)
1359
1360    one_hot_shape = (n_ids + 1,) + masks.shape
1361    masks = masks.unsqueeze(0)  # add dimension to scatter
1362    masks = torch.zeros(one_hot_shape).scatter_(0, masks, 1)[1:]
1363
1364    # add the extra singleton dimension to get shape NUM_OBJECTS x 1 x H x W
1365    masks = masks.unsqueeze(1)
1366    return masks

Convert the segmentation to one-hot encoded masks.

Arguments:
  • segmentation: The segmentation.
  • segmentation_ids: Optional subset of ids that will be used to subsample the masks. By default, computes the number of ids from the provided segmentation masks.
Returns:

The one-hot encoded masks.
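
A toy example:

    import numpy as np
    from micro_sam.util import segmentation_to_one_hot

    segmentation = np.zeros((8, 8), dtype="int64")
    segmentation[1:3, 1:3] = 1
    segmentation[5:7, 5:7] = 2

    one_hot = segmentation_to_one_hot(segmentation, segmentation_ids=np.array([1, 2]))
    print(one_hot.shape)  # torch.Size([2, 1, 8, 8]), i.e. NUM_OBJECTS x 1 x H x W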

def get_block_shape(shape: Tuple[int]) -> Tuple[int]:
1369def get_block_shape(shape: Tuple[int]) -> Tuple[int]:
1370    """Get a suitable block shape for chunking a given shape.
1371
1372    The primary use for this is determining chunk sizes for
1373    zarr arrays or block shapes for parallelization.
1374
1375    Args:
1376        shape: The image or volume shape.
1377
1378    Returns:
1379        The block shape.
1380    """
1381    ndim = len(shape)
1382    if ndim == 2:
1383        block_shape = tuple(min(bs, sh) for bs, sh in zip((1024, 1024), shape))
1384    elif ndim == 3:
1385        block_shape = tuple(min(bs, sh) for bs, sh in zip((32, 256, 256), shape))
1386    else:
1387        raise ValueError(f"Only 2 or 3 dimensional shapes are supported, got {ndim}D.")
1388
1389    return block_shape

Get a suitable block shape for chunking a given shape.

The primary use for this is determining chunk sizes for zarr arrays or block shapes for parallelization.

Arguments:
  • shape: The image or volume shape.
Returns:

The block shape.
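
For example:

    from micro_sam.util import get_block_shape

    print(get_block_shape((768, 2048)))       # (768, 1024) - capped by the image shape
    print(get_block_shape((16, 1024, 1024)))  # (16, 256, 256)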

def micro_sam_info() -> None:
1392def micro_sam_info() -> None:
1393    """Display μSAM information using a rich console."""
1394    import psutil
1395    import platform
1396    import argparse
1397    from rich import progress
1398    from rich.panel import Panel
1399    from rich.table import Table
1400    from rich.console import Console
1401
1402    import torch
1403    import micro_sam
1404
1405    parser = argparse.ArgumentParser(description="μSAM Information Booth")
1406    parser.add_argument(
1407        "--download", nargs="+", metavar=("WHAT", "KIND"),
1408        help="Downloads the pretrained SAM models. "
1409        "'--download models' -> downloads all pretrained models; "
1410        "'--download models vit_b_lm vit_b_em_organelles' -> downloads the listed models; "
1411        "'--download model/models vit_b_lm' -> downloads a single specified model."
1412    )
1413    args = parser.parse_args()
1414
1415    # Open up a new console.
1416    console = Console()
1417
1418    # The header for information CLI.
1419    console.print("[bold #0072B2]μSAM Information Booth[/bold #0072B2]", justify="center")
1420    console.print("-" * console.width)
1421
1422    # μSAM version panel.
1423    console.print(
1424        Panel(f"[bold #F0E442]Version:[/bold #F0E442] {micro_sam.__version__}", title="μSAM Version", expand=True)
1425    )
1426
1427    # The documentation link panel.
1428    console.print(
1429        Panel(
1430            "[bold #CC79A7]Tools documented at:[/bold #CC79A7]\n"
1431            "https://computational-cell-analytics.github.io/micro-sam", title="Documentation"
1432        )
1433    )
1434
1435    # The publication panel.
1436    console.print(
1437        Panel(
1438            "[bold #E69F00]Published in Nature Methods:[/bold #E69F00]\n"
1439            "https://www.nature.com/articles/s41592-024-02580-4", title="Publication"
1440        )
1441    )
1442
1443    # Create the cache directory when users run `micro_sam.info`.
1444    cache_dir = get_cache_directory()
1445    os.makedirs(cache_dir, exist_ok=True)
1446
1447    # The cache directory panel.
1448    console.print(
1449        Panel(f"[bold #009E73]Cache Directory:[/bold #009E73]\n{cache_dir}", title="Cache Directory")
1450    )
1451
1452    # We have a simple versioning logic, which we follow here for mapping model names to model versions.
1453    available_models = []
1454    for model_name, model_path in models().urls.items():  # We filter out the decoder models.
1455        if model_name.endswith("decoder"):
1456            continue
1457
1458        if "https://dl.fbaipublicfiles.com/segment_anything/" in model_path:  # Valid v1 SAM models.
1459            available_models.append(model_name)
1460
1461        if "https://owncloud.gwdg.de/" in model_path:  # Our own hosted models (often in their v1 version).
1462            if model_name == "vit_t":  # MobileSAM model.
1463                available_models.append(model_name)
1464            else:
1465                available_models.append(f"{model_name} (v1)")
1466
1467        # Now for our models, the BioImageIO ModelZoo upload structure is such that:
1468        # '/1/files' corresponds to v2 models.
1469        # '/1.1/files' corresponds to v3 models.
1470        # '/1.2/files' corresponds to v4 models.
1471        if "/1/files" in model_path:
1472            available_models.append(f"{model_name} (v2)")
1473        if "/1.1/files" in model_path:
1474            available_models.append(f"{model_name} (v3)")
1475        if "/1.2/files" in model_path:
1476            available_models.append(f"{model_name} (v4)")
1477
1478    model_list = "\n".join(available_models)
1479
1480    # The available models panel.
1481    console.print(
1482        Panel(f"[bold #D55E00]Available Models:[/bold #D55E00]\n{model_list}", title="List of Supported Models")
1483    )
1484
1485    # The system information table.
1486    total_memory = psutil.virtual_memory().total / (1024 ** 3)
1487    table = Table(title="System Information", show_header=True, header_style="bold #0072B2", expand=True)
1488    table.add_column("Property")
1489    table.add_column("Value", style="bold #56B4E9")
1490    table.add_row("System", platform.system())
1491    table.add_row("Node Name", platform.node())
1492    table.add_row("Release", platform.release())
1493    table.add_row("Version", platform.version())
1494    table.add_row("Machine", platform.machine())
1495    table.add_row("Processor", platform.processor())
1496    table.add_row("Platform", platform.platform())
1497    table.add_row("Total RAM (GB)", f"{total_memory:.2f}")
1498    console.print(table)
1499
1500    # The device information and check for available GPU acceleration.
1501    default_device = _get_default_device()
1502
1503    if default_device == "cuda":
1504        device_index = torch.cuda.current_device()
1505        device_name = torch.cuda.get_device_name(device_index)
1506        console.print(Panel(f"[bold #000000]CUDA Device:[/bold #000000] {device_name}", title="GPU Information"))
1507    elif default_device == "mps":
1508        console.print(Panel("[bold #000000]MPS Device is available[/bold #000000]", title="GPU Information"))
1509    else:
1510        console.print(
1511            Panel(
1512                "[bold #000000]No GPU acceleration device detected. Running on CPU.[/bold #000000]",
1513                title="Device Information"
1514            )
1515        )
1516
1517    # The section allowing to download models.
1518    # NOTE: In future, can be extended to download sample data.
1519    if args.download:
1520        download_provided_args = [t.lower() for t in args.download]
1521        mode, *model_types = download_provided_args
1522
1523        if mode not in {"models", "model"}:
1524            console.print(f"[red]Unknown option for --download: {mode}[/]")
1525            return
1526
1527        if mode in ["model", "models"] and not model_types:  # If the user did not specify models, we download all of them.
1528            download_list = available_models
1529        else:
1530            download_list = model_types
1531            incorrect_models = [m for m in download_list if m not in available_models]
1532            if incorrect_models:
1533                console.print(Panel("[red]Unknown model(s):[/] " + ", ".join(incorrect_models), title="Download Error"))
1534                return
1535
1536        with progress.Progress(
1537            progress.SpinnerColumn(),
1538            progress.TextColumn("[progress.description]{task.description}"),
1539            progress.BarColumn(bar_width=None),
1540            "[progress.percentage]{task.percentage:>3.0f}%",
1541            progress.TimeRemainingColumn(),
1542            console=console,
1543        ) as prog:
1544            task = prog.add_task("[green]Downloading μSAM models…", total=len(download_list))
1545            for model_type in download_list:
1546                prog.update(task, description=f"Downloading [cyan]{model_type}[/]…")
1547                _download_sam_model(model_type=model_type)
1548                prog.advance(task)
1549
1550        console.print(Panel("[bold green] Downloads complete![/]", title="Finished"))

Display μSAM information using a rich console.

def mask_data_to_segmentation( masks: List[Dict[str, Any]], shape: Optional[Tuple[int, int]] = None, min_object_size: int = 0, max_object_size: Optional[int] = None, label_masks: bool = True, with_background: bool = False, merge_exclusively: bool = True) -> numpy.ndarray:
1650def mask_data_to_segmentation(
1651    masks: List[Dict[str, Any]],
1652    shape: Optional[Tuple[int, int]] = None,
1653    min_object_size: int = 0,
1654    max_object_size: Optional[int] = None,
1655    label_masks: bool = True,
1656    with_background: bool = False,
1657    merge_exclusively: bool = True,
1658) -> np.ndarray:
1659    """Convert the output of the automatic mask generation to an instance segmentation.
1660
1661    Args:
1662        masks: The outputs generated by `AutomaticMaskGenerator`, other classes from `micro_sam.instance_segmentation`,
1663            or from `micro_sam.inference` functions. Only supported for output_mode=binary_mask.
1664        shape: The shape of the output segmentation. If None, it will be derived from the mask input.
1665            If the masks were predicted with tiling, then the shape must be given.
1666        min_object_size: The minimal size of an object in pixels. By default, set to '0'.
1667        max_object_size: The maximal size of an object in pixels.
1668        label_masks: Whether to apply connected components to the result before removing small objects.
1669            By default, set to 'True'.
1670        with_background: Whether to remove the largest object, which often covers the background for AMG.
1671        merge_exclusively: Whether to exclude previously merged masks from merging.
1672
1673    Returns:
1674        The instance segmentation.
1675    """
1676    masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
1677    if shape is None:
1678        shape = next(iter(masks))["segmentation"].shape
1679    segmentation = np.zeros(shape, dtype="uint32")
1680
1681    def require_numpy(mask):
1682        return mask.cpu().numpy() if torch.is_tensor(mask) else mask
1683
1684    seg_id = 1
1685    for mask_data in masks:
1686        area = mask_data["area"]
1687        if (area < min_object_size) or (max_object_size is not None and area > max_object_size):
1688            continue
1689
1690        this_mask = require_numpy(mask_data["segmentation"])
1691        this_seg_id = mask_data.get("seg_id", seg_id)
1692        if "global_bbox" in mask_data:
1693            bb = mask_data["bbox"]
1694            bb = np.s_[bb[1]:bb[1] + bb[3], bb[0]:bb[0] + bb[2]]
1695            global_bb = mask_data["global_bbox"]
1696            global_bb = np.s_[global_bb[1]:global_bb[1] + global_bb[3], global_bb[0]:global_bb[0] + global_bb[2]]
1697            if merge_exclusively:
1698                this_mask = np.logical_and(this_mask[bb], segmentation[global_bb] == 0)
1699            else:
1700                this_mask = this_mask[bb]
1701            segmentation[global_bb][this_mask] = this_seg_id
1702        else:
1703            if merge_exclusively:
1704                this_mask = np.logical_and(this_mask, segmentation == 0)
1705            segmentation[this_mask] = this_seg_id
1706        seg_id = this_seg_id + 1
1707
1708    block_shape = (512, 512)
1709    if label_masks:
1710        segmentation_cc = np.zeros_like(segmentation, dtype=segmentation.dtype)
1711        segmentation_cc = parallel_impl.label(segmentation, out=segmentation_cc, block_shape=block_shape)
1712        segmentation = segmentation_cc
1713
1714    seg_ids, sizes = parallel_impl.unique(segmentation, return_counts=True, block_shape=block_shape)
1715    filter_ids = seg_ids[sizes < min_object_size]
1716    if with_background:
1717        bg_id = seg_ids[np.argmax(sizes)]
1718        filter_ids = np.concatenate([filter_ids, [bg_id]])
1719
1720    filter_mask = np.zeros(segmentation.shape, dtype="bool")
1721    filter_mask = parallel_impl.isin(segmentation, filter_ids, out=filter_mask, block_shape=block_shape)
1722    segmentation[filter_mask] = 0
1723    segmentation = parallel_impl.relabel_consecutive(segmentation, block_shape=block_shape)[0]
1724
1725    return segmentation

Convert the output of the automatic mask generation to an instance segmentation.

Arguments:
  • masks: The outputs generated by AutomaticMaskGenerator, other classes from micro_sam.instance_segmentation, or from micro_sam.inference functions. Only supported for output_mode=binary_mask.
  • shape: The shape of the output segmentation. If None, it will be derived from the mask input. If the masks were predicted with tiling, then the shape must be given.
  • min_object_size: The minimal size of an object in pixels. By default, set to '0'.
  • max_object_size: The maximal size of an object in pixels.
  • label_masks: Whether to apply connected components to the result before removing small objects. By default, set to 'True'.
  • with_background: Whether to remove the largest object, which often covers the background for AMG.
  • merge_exclusively: Whether to exclude previously merged masks from merging.
Returns:

The instance segmentation.
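
A minimal sketch of the typical workflow; the image path is a placeholder and the automatic mask generation step is assumed from micro_sam.instance_segmentation:

    import imageio.v3 as imageio
    from micro_sam.instance_segmentation import AutomaticMaskGenerator
    from micro_sam.util import get_sam_model, precompute_image_embeddings, mask_data_to_segmentation

    predictor = get_sam_model(model_type="vit_b_lm")
    image = imageio.imread("image.tif")  # placeholder 2D image

    embeddings = precompute_image_embeddings(predictor, image)
    amg = AutomaticMaskGenerator(predictor)
    amg.initialize(image, image_embeddings=embeddings)
    masks = amg.generate()  # list of mask dictionaries with binary masks

    segmentation = mask_data_to_segmentation(masks, min_object_size=50, with_background=True)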

def apply_nms( predictions: List[Dict[str, Any]], min_size: int, shape: Optional[Tuple[int, int]] = None, perform_box_nms: bool = False, nms_thresh: float = 0.9, max_size: Optional[int] = None, intersection_over_min: bool = False) -> numpy.ndarray:
1728def apply_nms(
1729    predictions: List[Dict[str, Any]],
1730    min_size: int,
1731    shape: Optional[Tuple[int, int]] = None,
1732    perform_box_nms: bool = False,
1733    nms_thresh: float = 0.9,
1734    max_size: Optional[int] = None,
1735    intersection_over_min: bool = False,
1736) -> np.ndarray:
1737    """Apply non-maximum suppression to mask predictions from a segment anything model.
1738
1739    Args:
1740        predictions: The mask predictions from SAM.
1741        min_size: The minimum mask size to keep in the output.
1742        shape: The shape of the output segmentation.
1743            Has to be passed for predictions obtained from tiling.
1744        perform_box_nms: Whether to perform NMS on the box coordinates or on the masks.
1745        nms_thresh: The threshold for filtering out objects in NMS.
1746        max_size: The maximum mask size to keep in the output.
1747        intersection_over_min: Whether to compute the overlap as intersection over the minimum
1748            of the two mask areas, instead of intersection over union.
1749
1750    Returns:
1751        The segmentation obtained from merging the masks left after NMS.
1752    """
1753    data = amg_utils.MaskData(
1754        masks=torch.cat([pred["segmentation"][None] for pred in predictions], dim=0),
1755        iou_preds=torch.tensor([pred["predicted_iou"] for pred in predictions]),
1756    )
1757    data["boxes"] = torch.tensor(np.array([pred["bbox"] for pred in predictions]))
1758    data["area"] = [mask.sum() for mask in data["masks"]]
1759    data["stability_scores"] = torch.tensor([pred["stability_score"] for pred in predictions])
1760
1761    # Check if the input comes with a 'global_bbox' attribute. If it does, then the predictions are from
1762    # a tiled prediction. In this case, we have to take the coordinates w.r.t. the tiling into account.
1763    if "global_bbox" in predictions[0]:
1764        if shape is None:
1765            raise ValueError("The output shape 'shape' has to be passed for tiled predictions.")
1766        data["global_boxes"] = torch.tensor(np.array([pred["global_bbox"] for pred in predictions]))
1767        is_tiled = True
1768    else:
1769        is_tiled = False
1770
1771    if min_size > 0:
1772        keep_by_size = torch.tensor(
1773            [i for i, area in enumerate(data["area"]) if area > min_size], dtype=torch.long,
1774        )
1775        data.filter(keep_by_size)
1776
1777    if max_size is not None:
1778        keep_by_size = torch.tensor([i for i, area in enumerate(data["area"]) if area < max_size])
1779        data.filter(keep_by_size)
1780
1781    scores = data["iou_preds"] * data["stability_scores"]
1782    if perform_box_nms:
1783        assert not intersection_over_min  # not implemented
1784        keep_by_nms = batched_nms(
1785            data["global_boxes"].float() if is_tiled else data["boxes"].float(),
1786            scores,
1787            torch.zeros_like(data["boxes"][:, 0]),  # categories
1788            iou_threshold=nms_thresh,
1789        )
1790    else:
1791        keep_by_nms = _batched_mask_nms(
1792            masks=data["masks"],
1793            boxes=data["global_boxes"].float() if is_tiled else data["boxes"].float(),
1794            scores=scores,
1795            nms_thresh=nms_thresh,
1796            intersection_over_min=intersection_over_min,
1797        )
1798    data.filter(keep_by_nms)
1799
1800    if is_tiled:
1801        mask_data = [
1802            {"segmentation": mask, "area": area, "bbox": box, "global_bbox": global_box}
1803            for mask, area, box, global_box in zip(data["masks"], data["area"], data["boxes"], data["global_boxes"])
1804        ]
1805    else:
1806        mask_data = [
1807            {"segmentation": mask, "area": area, "bbox": box}
1808            for mask, area, box in zip(data["masks"], data["area"], data["boxes"])
1809        ]
1810
1811    if shape is None:
1812        shape = predictions[0]["segmentation"].shape
1813    if mask_data:
1814        segmentation = mask_data_to_segmentation(mask_data, shape=shape, min_object_size=min_size)
1815    else:  # In case all objects have been filtered out due to size filtering.
1816        segmentation = np.zeros(shape, dtype="uint32")
1817
1818    return segmentation

Apply non-maximum suppression to mask predictions from a segment anything model.

Arguments:
  • predictions: The mask predictions from SAM.
  • min_size: The minimum mask size to keep in the output.
  • shape: The shape of the output segmentation. Has to be passed for predictions obtained from tiling.
  • perform_box_nms: Whether to perform NMS on the box coordinates or on the masks.
  • nms_thresh: The threshold for filtering out objects in NMS.
  • max_size: The maximum mask size to keep in the output.
  • intersection_over_min: Whether to compute the overlap as intersection over the minimum of the two mask areas, instead of intersection over union.
Returns:

The segmentation obtained from merging the masks left after NMS.
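
A minimal sketch with two synthetic, heavily overlapping masks; in practice the 'predictions' come from micro_sam.inference or micro_sam.instance_segmentation:

    import numpy as np
    import torch
    from micro_sam.util import apply_nms

    def make_prediction(mask, iou, stability):
        ys, xs = np.where(mask)
        # Bounding box in XYWH format, as used for segment anything mask data.
        bbox = [int(xs.min()), int(ys.min()), int(xs.max() - xs.min() + 1), int(ys.max() - ys.min() + 1)]
        return {
            "segmentation": torch.from_numpy(mask),
            "bbox": bbox,
            "predicted_iou": iou,
            "stability_score": stability,
        }

    mask_a = np.zeros((64, 64), dtype=bool); mask_a[10:40, 10:40] = True
    mask_b = np.zeros((64, 64), dtype=bool); mask_b[12:42, 12:42] = True
    predictions = [make_prediction(mask_a, 0.9, 0.95), make_prediction(mask_b, 0.6, 0.7)]

    # The two masks overlap with IoU ~ 0.77 > nms_thresh, so only the higher scoring one is kept.
    segmentation = apply_nms(predictions, min_size=50, nms_thresh=0.5)
    print(np.unique(segmentation))  # background (0) plus a single object id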