micro_sam.util

Helper functions for downloading Segment Anything models and predicting image embeddings.
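
The two entry points used most often are `get_sam_model` (download or load a Segment Anything model) and `precompute_image_embeddings` (compute and cache the image encoder output). A minimal sketch; the model name and the random placeholder image below are illustrative assumptions, not part of this module:

    import numpy as np
    from micro_sam.util import get_sam_model, precompute_image_embeddings

    image = np.random.randint(0, 255, (512, 512), dtype="uint8")  # placeholder 2D grayscale image
    predictor = get_sam_model(model_type="vit_b_lm")              # downloads to the micro-sam cache if needed
    embeddings = precompute_image_embeddings(predictor, image, ndim=2)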

   1"""
   2Helper functions for downloading Segment Anything models and predicting image embeddings.
   3"""
   4
   5import os
   6import pickle
   7import hashlib
   8import warnings
   9from pathlib import Path
  10from collections import OrderedDict
  11from typing import Any, Dict, Iterable, Optional, Tuple, Union
  12
  13import zarr
  14import vigra
  15import torch
  16import pooch
  17import xxhash
  18import numpy as np
  19import imageio.v3 as imageio
  20from skimage.measure import regionprops
  21from skimage.segmentation import relabel_sequential
  22
  23from elf.io import open_file
  24
  25from nifty.tools import blocking
  26
  27from .__version__ import __version__
  28from . import models as custom_models
  29
  30try:
  31    # Avoid import warnings from mobile_sam
  32    with warnings.catch_warnings():
  33        warnings.simplefilter("ignore")
  34        from mobile_sam import sam_model_registry, SamPredictor
  35    VIT_T_SUPPORT = True
  36except ImportError:
  37    from segment_anything import sam_model_registry, SamPredictor
  38    VIT_T_SUPPORT = False
  39
  40try:
  41    from napari.utils import progress as tqdm
  42except ImportError:
  43    from tqdm import tqdm
  44
  45# This is the default model used in micro_sam
  46# Currently it is set to vit_b_lm
  47_DEFAULT_MODEL = "vit_b_lm"
  48
  49# The valid model types. Each type corresponds to the architecture of the
  50# vision transformer used within SAM.
  51_MODEL_TYPES = ("vit_l", "vit_b", "vit_h", "vit_t")
  52
  53
  54# TODO define the proper type for image embeddings
  55ImageEmbeddings = Dict[str, Any]
  56"""@private"""
  57
  58
  59def get_cache_directory() -> Path:
  60    """Get micro-sam cache directory location.
  61
  62    Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.
  63    """
  64    default_cache_directory = os.path.expanduser(pooch.os_cache("micro_sam"))
  65    cache_directory = Path(os.environ.get("MICROSAM_CACHEDIR", default_cache_directory))
  66    return cache_directory
  67
  68
  69#
  70# Functionality for model download and export
  71#
  72
  73
  74def microsam_cachedir() -> Union[str, Path]:
  75    """Return the micro-sam cache directory.
  76
  77    Returns the top level cache directory for micro-sam models and sample data.
  78
  79    Every time this function is called, we check for any user updates made to
  80    the MICROSAM_CACHEDIR os environment variable since the last time.
  81    """
  82    cache_directory = os.environ.get("MICROSAM_CACHEDIR") or pooch.os_cache("micro_sam")
  83    return cache_directory
  84
  85
  86def models():
  87    """Return the segmentation models registry.
  88
  89    We recreate the model registry every time this function is called,
  90    so any user changes to the default micro-sam cache directory location
  91    are respected.
  92    """
  93
  94    # We use xxhash to compute the hash of the models, see
  95    # https://github.com/computational-cell-analytics/micro-sam/issues/283
  96    # (It is now a dependency, so we don't provide the sha256 fallback anymore.)
  97    # To generate the xxh128 hash:
  98    #     xxh128sum filename
  99    encoder_registry = {
 100        # The default segment anything models:
 101        "vit_l": "xxh128:a82beb3c660661e3dd38d999cc860e9a",
 102        "vit_h": "xxh128:97698fac30bd929c2e6d8d8cc15933c2",
 103        "vit_b": "xxh128:6923c33df3637b6a922d7682bfc9a86b",
 104        # The model with vit tiny backend from https://github.com/ChaoningZhang/MobileSAM.
 105        "vit_t": "xxh128:8eadbc88aeb9d8c7e0b4b60c3db48bd0",
 106        # The current version of our models in the modelzoo.
 107        # LM generalist models:
 108        "vit_l_lm": "xxh128:fc32ea6f7fcc7eb02737d1304f81f5f2",
 109        "vit_b_lm": "xxh128:8fd5806be3c3ba213e19a709d6d1495f",
 110        "vit_t_lm": "xxh128:72ec5074774761a6e5c05a08942f981e",
 111        # EM models:
 112        "vit_l_em_organelles": "xxh128:096c9695966803ca6fde24f4c1e3c3fb",
 113        "vit_b_em_organelles": "xxh128:f6f6593aeecd0e15a07bdac86360b6cc",
 114        "vit_t_em_organelles": "xxh128:253474720c497cce605e57c9b1d18fd9",
 115        # Histopathology models:
 116        "vit_b_histopathology": "xxh128:ffd1a2cd84570458b257bd95fdd8f974",
 117        "vit_l_histopathology": "xxh128:b591833c89754271023e901281dee3f2",
 118        "vit_h_histopathology": "xxh128:bd1856dafc156a43fb3aa705f1a6e92e",
 119        # Medical Imaging models:
 120        "vit_b_medical_imaging": "xxh128:5be672f1458263a9edc9fd40d7f56ac1",
 121    }
 122    # Additional decoders for instance segmentation.
 123    decoder_registry = {
 124        # LM generalist models:
 125        "vit_l_lm_decoder": "xxh128:779b5a50ecc6d46d495753fba8717f2f",
 126        "vit_b_lm_decoder": "xxh128:9f580a96984b3085389ced5d9a4ae75d",
 127        "vit_t_lm_decoder": "xxh128:3e914a5f397b0312cdd36813031f8823",
 128        # EM models:
 129        "vit_l_em_organelles_decoder": "xxh128:d60fd96bd6060856f6430f29e42568fb",
 130        "vit_b_em_organelles_decoder": "xxh128:b2d4dcffb99f76d83497d39ee500088f",
 131        "vit_t_em_organelles_decoder": "xxh128:8f897c7bb93174a4d1638827c4dd6f44",
 132        # Histopathology models:
 133        "vit_b_histopathology_decoder": "xxh128:6a66194dcb6e36199cbee2214ecf7213",
 134        "vit_l_histopathology_decoder": "xxh128:46aab7765d4400e039772d5a50b55c04",
 135        "vit_h_histopathology_decoder": "xxh128:3ed9f87e46ad5e16935bd8d722c8dc47",
 136    }
 137    registry = {**encoder_registry, **decoder_registry}
 138
 139    encoder_urls = {
 140        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
 141        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
 142        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
 143        "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download",
 144        "vit_l_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.1/files/vit_l.pt",
 145        "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.1/files/vit_b.pt",
 146        "vit_t_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t.pt",
 147        "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l.pt",  # noqa
 148        "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b.pt",
 149        "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt",  # noqa
 150        "vit_b_histopathology": "https://owncloud.gwdg.de/index.php/s/sBB4H8CTmIoBZsQ/download",
 151        "vit_l_histopathology": "https://owncloud.gwdg.de/index.php/s/IZgnn1cpBq2PHod/download",
 152        "vit_h_histopathology": "https://owncloud.gwdg.de/index.php/s/L7AcvVz7DoWJ2RZ/download",
 153        "vit_b_medical_imaging": "https://owncloud.gwdg.de/index.php/s/AB69HGhj8wuozXQ/download",
 154    }
 155
 156    decoder_urls = {
 157        "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.1/files/vit_l_decoder.pt",  # noqa
 158        "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.1/files/vit_b_decoder.pt",  # noqa
 159        "vit_t_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t_decoder.pt",  # noqa
 160        "vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l_decoder.pt",  # noqa
 161        "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b_decoder.pt",  # noqa
 162        "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt",  # noqa
 163        "vit_b_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/KO9AWqynI7SFOBj/download",
 164        "vit_l_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/oIs6VSmkOp7XrKF/download",
 165        "vit_h_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/1qAKxy5H0jgwZvM/download",
 166    }
 167    urls = {**encoder_urls, **decoder_urls}
 168
 169    models = pooch.create(
 170        path=os.path.join(microsam_cachedir(), "models"),
 171        base_url="",
 172        registry=registry,
 173        urls=urls,
 174    )
 175    return models
 176
 177
 178def _get_default_device():
 179    # Check if we're running in CI and use the CPU if we are,
 180    # otherwise the tests may run out of memory on Mac if MPS is used.
 181    if os.getenv("GITHUB_ACTIONS") == "true":
 182        return "cpu"
 183    # Use cuda enabled gpu if it's available.
 184    if torch.cuda.is_available():
 185        device = "cuda"
 186    # As second priority use mps.
 187    # See https://pytorch.org/docs/stable/notes/mps.html for details
 188    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
 189        print("Using apple MPS device.")
 190        device = "mps"
 191    # Use the CPU as fallback.
 192    else:
 193        device = "cpu"
 194    return device
 195
 196
 197def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
 198    """Get the torch device.
 199
 200    If no device is passed, the default device for your system is used.
 201    Otherwise, it is checked whether the passed device is supported.
 202
 203    Args:
 204        device: The input device.
 205
 206    Returns:
 207        The device.
 208    """
 209    if device is None or device == "auto":
 210        device = _get_default_device()
 211    else:
 212        device_type = device if isinstance(device, str) else device.type
 213        if device_type.lower() == "cuda":
 214            if not torch.cuda.is_available():
 215                raise RuntimeError("PyTorch CUDA backend is not available.")
 216        elif device_type.lower() == "mps":
 217            if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
 218                raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
 219        elif device_type.lower() == "cpu":
 220            pass  # cpu is always available
 221        else:
 222            raise RuntimeError(
 223                f"Unsupported device: {device}\n"
 224                "Please choose from 'cpu', 'cuda', or 'mps'."
 225            )
 226
 227    return device
 228
 229
 230def _available_devices():
 231    available_devices = []
 232    for i in ["cuda", "mps", "cpu"]:
 233        try:
 234            device = get_device(i)
 235        except RuntimeError:
 236            pass
 237        else:
 238            available_devices.append(device)
 239    return available_devices
 240
 241
 242# We write a custom unpickler that skips objects that cannot be found instead of
 243# throwing an AttributeError or ModuleNotFoundError.
 244# NOTE: since we just want to unpickle the model to load its weights these errors don't matter.
 245# See also https://stackoverflow.com/questions/27732354/unable-to-load-files-using-pickle-and-multiple-modules
 246class _CustomUnpickler(pickle.Unpickler):
 247    def find_class(self, module, name):
 248        try:
 249            return super().find_class(module, name)
 250        except (AttributeError, ModuleNotFoundError) as e:
 251            warnings.warn(f"Did not find {module}:{name} and will skip it, due to error {e}")
 252            return None
 253
 254
 255def _compute_hash(path, chunk_size=8192):
 256    hash_obj = xxhash.xxh128()
 257    with open(path, "rb") as f:
 258        chunk = f.read(chunk_size)
 259        while chunk:
 260            hash_obj.update(chunk)
 261            chunk = f.read(chunk_size)
 262    hash_val = hash_obj.hexdigest()
 263    return f"xxh128:{hash_val}"
 264
 265
 266# Load the state from a checkpoint.
 267# The checkpoint can either contain a sam encoder state
 268# or it can be a checkpoint for model finetuning.
 269def _load_checkpoint(checkpoint_path):
 270    # Override the unpickler with our custom one.
 271    # This enables imports from torch_em checkpoints even if the checkpoint cannot be fully unpickled.
 272    custom_pickle = pickle
 273    custom_pickle.Unpickler = _CustomUnpickler
 274
 275    state = torch.load(checkpoint_path, map_location="cpu", pickle_module=custom_pickle)
 276    if "model_state" in state:
 277        # Copy the model weights from torch_em's training format.
 278        model_state = state["model_state"]
 279        sam_prefix = "sam."
 280        model_state = OrderedDict(
 281            [(k[len(sam_prefix):] if k.startswith(sam_prefix) else k, v) for k, v in model_state.items()]
 282        )
 283    else:
 284        model_state = state
 285
 286    return state, model_state
 287
 288
 289def get_sam_model(
 290    model_type: str = _DEFAULT_MODEL,
 291    device: Optional[Union[str, torch.device]] = None,
 292    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
 293    return_sam: bool = False,
 294    return_state: bool = False,
 295    peft_kwargs: Optional[Dict] = None,
 296    flexible_load_checkpoint: bool = False,
 297    **model_kwargs,
 298) -> SamPredictor:
 299    r"""Get the SegmentAnything Predictor.
 300
 301    This function will download the required model or load it from the cached weight file.
 302    The location of the cache can be changed by setting the environment variable MICROSAM_CACHEDIR.
 303    The name of the requested model can be set via `model_type`.
 304    See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models
 305    for an overview of the available models.
 306
 307    Alternatively this function can also load a model from weights stored in a local filepath.
 308    The corresponding file path is given via `checkpoint_path`. In this case `model_type`
 309    must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for
 310    a SAM model with vit_b encoder.
 311
 312    By default the models are downloaded to a folder named 'micro_sam/models'
 313    inside your default cache directory, e.g.:
 314    * Mac: ~/Library/Caches/<AppName>
 315    * Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined.
 316    * Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache
 317    See the pooch.os_cache() documentation for more details:
 318    https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html
 319
 320    Args:
 321        model_type: The Segment Anything model to use. Will use the `vit_b_lm` model by default.
 322            To get a list of all available model names you can call `get_model_names`.
 323        device: The device for the model. If none is given, will use the GPU if available.
 324        checkpoint_path: The path to a file with weights that should be used instead of using the
 325            weights corresponding to `model_type`. If given, `model_type` must match the architecture
 326            corresponding to the weight file. e.g. if you use weights for SAM with `vit_b` encoder
 327            then `model_type` must be given as "vit_b".
 328        return_sam: Return the sam model object as well as the predictor.
 329        return_state: Return the unpickled checkpoint state.
 330        peft_kwargs: Keyword arguments for the PEFT wrapper class.
 331        flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
 332        model_kwargs: Additional parameters necessary to initialize the Segment Anything model.
 333
 334    Returns:
 335        The segment anything predictor.
 336    """
 337    device = get_device(device)
 338
 339    # We support passing a local filepath to a checkpoint.
 340    # In this case we do not download any weights but just use the local weight file,
 341    # as it is, without copying it over anywhere or checking its hash.
 342
 343    # checkpoint_path has not been passed, we download a known model and derive the correct
 344    # URL from the model_type. If the model_type is invalid pooch will raise an error.
 345    if checkpoint_path is None:
 346        model_registry = models()
 347        checkpoint_path = model_registry.fetch(model_type, progressbar=True)
 348        model_hash = model_registry.registry[model_type]
 349
 350        # If we have a custom model then we may also have a decoder checkpoint.
 351        # Download it here, so that we can add it to the state.
 352        decoder_name = f"{model_type}_decoder"
 353        decoder_path = model_registry.fetch(
 354            decoder_name, progressbar=True
 355        ) if decoder_name in model_registry.registry else None
 356
 357    # checkpoint_path has been passed, we use it instead of downloading a model.
 358    else:
 359        # Check if the file exists and raise an error otherwise.
 360        # We can't check any hashes here, and we don't check if the file is actually a valid weight file.
 361        # (If it isn't the model creation will fail below.)
 362        if not os.path.exists(checkpoint_path):
 363            raise ValueError(f"Checkpoint at {checkpoint_path} could not be found.")
 364        model_hash = _compute_hash(checkpoint_path)
 365        decoder_path = None
 366
 367    # Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped
 368    # before calling sam_model_registry.
 369    abbreviated_model_type = model_type[:5]
 370    if abbreviated_model_type not in _MODEL_TYPES:
 371        raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}")
 372    if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT:
 373        raise RuntimeError(
 374            "'mobile_sam' is required for the vit-tiny. "
 375            "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'"
 376        )
 377
 378    state, model_state = _load_checkpoint(checkpoint_path)
 379
 380    # Whether to update parameters necessary to initialize the model
 381    if model_kwargs:  # Checks whether model_kwargs have been provided or not
 382        if abbreviated_model_type == "vit_t":
 383            raise ValueError("'micro-sam' does not support changing the model parameters for 'mobile-sam'.")
 384        sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs)
 385
 386    else:
 387        sam = sam_model_registry[abbreviated_model_type]()
 388
 389    # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
 390    # Overwrites the SAM model by freezing the backbone and allowing PEFT.
 391    if peft_kwargs and isinstance(peft_kwargs, dict):
 392        # NOTE: We bump out 'quantize' parameter, if found, as we do not quantize in inference.
 393        peft_kwargs.pop("quantize", None)
 394
 395        if abbreviated_model_type == "vit_t":
 396            raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")
 397
 398        sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam
 399    # Handles the case where the model checkpoint has issues when initialized with non-default parameters.
 400    if flexible_load_checkpoint:
 401        sam = _handle_checkpoint_loading(sam, model_state)
 402    else:
 403        sam.load_state_dict(model_state)
 404    sam.to(device=device)
 405
 406    predictor = SamPredictor(sam)
 407    predictor.model_type = abbreviated_model_type
 408    predictor._hash = model_hash
 409    predictor.model_name = model_type
 410
 411    # Add the decoder to the state if we have one and if the state is returned.
 412    if decoder_path is not None and return_state:
 413        state["decoder_state"] = torch.load(decoder_path, map_location=device, weights_only=False)
 414
 415    if return_sam and return_state:
 416        return predictor, sam, state
 417    if return_sam:
 418        return predictor, sam
 419    if return_state:
 420        return predictor, state
 421    return predictor
 422
 423
 424def _handle_checkpoint_loading(sam, model_state):
 425    # Handle parameter mismatch issues in a more graceful way,
 426    # e.g. while training for multi-class semantic segmentation in the mask decoder,
 427    # parameters are updated, leading to "size mismatch" errors.
 428
 429    new_state_dict = {}  # for loading matching parameters
 430    mismatched_layers = []  # for tracking mismatching parameters
 431
 432    reference_state = sam.state_dict()
 433
 434    for k, v in model_state.items():
 435        if k in reference_state:  # This is done to get rid of unwanted layers from pretrained SAM.
 436            if reference_state[k].size() == v.size():
 437                new_state_dict[k] = v
 438            else:
 439                mismatched_layers.append(k)
 440
 441    reference_state.update(new_state_dict)
 442
 443    if len(mismatched_layers) > 0:
 444        warnings.warn(f"The layers with size mismatch: {mismatched_layers}")
 445
 446    for mlayer in mismatched_layers:
 447        if 'weight' in mlayer:
 448            torch.nn.init.kaiming_uniform_(reference_state[mlayer])
 449        elif 'bias' in mlayer:
 450            reference_state[mlayer].zero_()
 451
 452    sam.load_state_dict(reference_state)
 453
 454    return sam
 455
 456
 457def export_custom_sam_model(
 458    checkpoint_path: Union[str, os.PathLike],
 459    model_type: str,
 460    save_path: Union[str, os.PathLike],
 461    with_segmentation_decoder: bool = False,
 462) -> None:
 463    """Export a finetuned Segment Anything Model to the standard model format.
 464
 465    The exported model can be used by the interactive annotation tools in `micro_sam.annotator`.
 466
 467    Args:
 468        checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
 469        model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
 470        save_path: Where to save the exported model.
 471        with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well.
 472            If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'.
 473    """
 474    _, state = get_sam_model(
 475        model_type=model_type, checkpoint_path=checkpoint_path, return_state=True, device="cpu",
 476    )
 477    model_state = state["model_state"]
 478    prefix = "sam."
 479    model_state = OrderedDict(
 480        [(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()]
 481    )
 482
 483    # Store the 'decoder_state' as well, if desired.
 484    if with_segmentation_decoder:
 485        if "decoder_state" not in state:
 486            raise RuntimeError(f"'decoder_state' is not found in the model at '{checkpoint_path}'.")
 487        decoder_state = state["decoder_state"]
 488        save_state = {"model_state": model_state, "decoder_state": decoder_state}
 489    else:
 490        save_state = model_state
 491
 492    torch.save(save_state, save_path)
 493
 494
 495def export_custom_qlora_model(
 496    checkpoint_path: Union[str, os.PathLike],
 497    finetuned_path: Union[str, os.PathLike],
 498    model_type: str,
 499    save_path: Union[str, os.PathLike],
 500) -> None:
 501    """Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format.
 502
 503    The exported model can be used with the LoRA backbone by passing the relevant `peft_kwargs` to `get_sam_model`.
 504
 505    Args:
 506        checkpoint_path: The path to the base foundation model from which the new model has been finetuned.
 507        finetuned_path: The path to the new finetuned model, using QLoRA.
 508        model_type: The Segment Anything Model type corresponding to the checkpoint.
 509        save_path: Where to save the exported model.
 510    """
 511    # Step 1: Get the base SAM model: used to start finetuning from.
 512    _, sam = get_sam_model(
 513        model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True,
 514    )
 515
 516    # Step 2: Load the QLoRA-style finetuned model.
 517    ft_state, ft_model_state = _load_checkpoint(finetuned_path)
 518
 519    # Step 3: Get LoRA weights from QLoRA and retain all original parameters from the base SAM model.
 520    updated_model_state = {}
 521
 522    # - At first, we get all LoRA layers from the QLoRA-style finetuned model checkpoint.
 523    for k, v in ft_model_state.items():
 524        if k.find("w_b_linear") != -1 or k.find("w_a_linear") != -1:
 525            updated_model_state[k] = v
 526
 527    # - Next, we get all the remaining parameters from the base SAM model.
 528    for k, v in sam.state_dict().items():
 529        if k.find("attn.qkv.") != -1:
 530            k = k.replace("qkv", "qkv.qkv_proj")
 531            updated_model_state[k] = v
 532        else:
 533            # Keep all other parameters unchanged.
 534            updated_model_state[k] = v
 535
 536    # - Finally, we replace the old model state with the new one (to retain other relevant stuff)
 537    ft_state['model_state'] = updated_model_state
 538
 539    # Step 4: Store the new "state" to "save_path"
 540    torch.save(ft_state, save_path)
 541
 542
 543def get_model_names() -> Iterable:
 544    model_registry = models()
 545    model_names = model_registry.registry.keys()
 546    return model_names
 547
 548
 549#
 550# Functionality for precomputing image embeddings.
 551#
 552
 553
 554def _to_image(input_):
 555    # we require the input to be uint8
 556    if input_.dtype != np.dtype("uint8"):
 557        # first normalize the input to [0, 1]
 558        input_ = input_.astype("float32") - input_.min()
 559        input_ = input_ / input_.max()
 560        # then bring to [0, 255] and cast to uint8
 561        input_ = (input_ * 255).astype("uint8")
 562
 563    if input_.ndim == 2:
 564        image = np.concatenate([input_[..., None]] * 3, axis=-1)
 565    elif input_.ndim == 3 and input_.shape[-1] == 3:
 566        image = input_
 567    else:
 568        raise ValueError(f"Invalid input image of shape {input_.shape}. Expect either 2D grayscale or 3D RGB image.")
 569
 570    return image
 571
 572
 573def _compute_tiled_features_2d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update):
 574    tiling = blocking([0, 0], input_.shape[:2], tile_shape)
 575    n_tiles = tiling.numberOfBlocks
 576
 577    features = f.require_group("features")
 578    features.attrs["shape"] = input_.shape[:2]
 579    features.attrs["tile_shape"] = tile_shape
 580    features.attrs["halo"] = halo
 581
 582    pbar_init(n_tiles, "Compute Image Embeddings 2D tiled.")
 583    for tile_id in range(n_tiles):
 584        tile = tiling.getBlockWithHalo(tile_id, list(halo))
 585        outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outerBlock.begin, tile.outerBlock.end))
 586
 587        predictor.reset_image()
 588        tile_input = _to_image(input_[outer_tile])
 589        predictor.set_image(tile_input)
 590        tile_features = predictor.get_image_embedding()
 591        original_size = predictor.original_size
 592        input_size = predictor.input_size
 593
 594        ds = features.create_dataset(
 595            str(tile_id), data=tile_features.cpu().numpy(), compression="gzip", chunks=tile_features.shape
 596        )
 597        ds.attrs["original_size"] = original_size
 598        ds.attrs["input_size"] = input_size
 599        pbar_update(1)
 600
 601    _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size=None, original_size=None)
 602
 603    return features
 604
 605
 606def _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update):
 607    assert input_.ndim == 3
 608
 609    shape = input_.shape[1:]
 610    tiling = blocking([0, 0], shape, tile_shape)
 611    n_tiles = tiling.numberOfBlocks
 612
 613    features = f.require_group("features")
 614    features.attrs["shape"] = shape
 615    features.attrs["tile_shape"] = tile_shape
 616    features.attrs["halo"] = halo
 617
 618    n_slices = input_.shape[0]
 619    pbar_init(n_tiles * n_slices, "Compute Image Embeddings 3D tiled.")
 620
 621    for tile_id in range(n_tiles):
 622        tile = tiling.getBlockWithHalo(tile_id, list(halo))
 623        outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outerBlock.begin, tile.outerBlock.end))
 624
 625        ds = None
 626        for z in range(n_slices):
 627            predictor.reset_image()
 628            tile_input = _to_image(input_[z][outer_tile])
 629            predictor.set_image(tile_input)
 630            tile_features = predictor.get_image_embedding()
 631
 632            if ds is None:
 633                shape = (input_.shape[0],) + tile_features.shape
 634                chunks = (1,) + tile_features.shape
 635                ds = features.create_dataset(
 636                    str(tile_id), shape=shape, dtype="float32", compression="gzip", chunks=chunks
 637                )
 638
 639            ds[z] = tile_features.cpu().numpy()
 640            pbar_update(1)
 641
 642        original_size = predictor.original_size
 643        input_size = predictor.input_size
 644
 645        ds.attrs["original_size"] = original_size
 646        ds.attrs["input_size"] = input_size
 647
 648    _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size=None, original_size=None)
 649
 650    return features
 651
 652
 653def _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update):
 654    # Check if the embeddings are already cached.
 655    if save_path is not None and "input_size" in f.attrs:
 656        # In this case we load the embeddings.
 657        features = f["features"][:]
 658        original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
 659        image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size}
 660        # Also set the embeddings.
 661        set_precomputed(predictor, image_embeddings)
 662        return image_embeddings
 663
 664    pbar_init(1, "Compute Image Embeddings 2D.")
 665    # Otherwise we have to compute the embeddings.
 666    predictor.reset_image()
 667    predictor.set_image(_to_image(input_))
 668    features = predictor.get_image_embedding().cpu().numpy()
 669    original_size = predictor.original_size
 670    input_size = predictor.input_size
 671    pbar_update(1)
 672
 673    # Save the embeddings if we have a save_path.
 674    if save_path is not None:
 675        f.create_dataset("features", data=features, compression="gzip", chunks=features.shape)
 676        _write_embedding_signature(
 677            f, input_, predictor, tile_shape=None, halo=None, input_size=input_size, original_size=original_size,
 678        )
 679
 680    image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size}
 681    return image_embeddings
 682
 683
 684def _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update):
 685    # Check if the features are already computed.
 686    if "input_size" in f.attrs:
 687        features = f["features"]
 688        original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
 689        image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size}
 690        return image_embeddings
 691
 692    # Otherwise compute them. Note: saving happens automatically because we
 693    # always write the features to zarr. If no save path is given we use an in-memory zarr.
 694    features = _compute_tiled_features_2d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update)
 695    image_embeddings = {"features": features, "input_size": None, "original_size": None}
 696    return image_embeddings
 697
 698
 699def _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update):
 700    # Check if the embeddings are already fully cached.
 701    if save_path is not None and "input_size" in f.attrs:
 702        # In this case we load the embeddings.
 703        features = f["features"] if lazy_loading else f["features"][:]
 704        original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
 705        image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size}
 706        return image_embeddings
 707
 708    # Otherwise we have to compute the embeddings.
 709
 710    # First check if we have a save path or not and set things up accordingly.
 711    if save_path is None:
 712        features = []
 713        save_features = False
 714        partial_features = False
 715    else:
 716        save_features = True
 717        embed_shape = (1, 256, 64, 64)
 718        shape = (input_.shape[0],) + embed_shape
 719        chunks = (1,) + embed_shape
 720        if "features" in f:
 721            partial_features = True
 722            features = f["features"]
 723            if features.shape != shape or features.chunks != chunks:
 724                raise RuntimeError("Invalid partial features")
 725        else:
 726            partial_features = False
 727            features = f.create_dataset("features", shape=shape, chunks=chunks, dtype="float32")
 728
 729    # Initialize the pbar.
 730    pbar_init(input_.shape[0], "Compute Image Embeddings 3D")
 731
 732    # Compute the embeddings for each slice.
 733    for z, z_slice in enumerate(input_):
 734        # Skip feature computation in case of partial features in non-zero slice.
 735        if partial_features and np.count_nonzero(features[z]) != 0:
 736            continue
 737
 738        predictor.reset_image()
 739        predictor.set_image(_to_image(z_slice))
 740        embedding = predictor.get_image_embedding()
 741        original_size, input_size = predictor.original_size, predictor.input_size
 742
 743        if save_features:
 744            features[z] = embedding.cpu().numpy()
 745        else:
 746            features.append(embedding[None])
 747        pbar_update(1)
 748
 749    if save_features:
 750        _write_embedding_signature(
 751            f, input_, predictor, tile_shape=None, halo=None, input_size=input_size, original_size=original_size,
 752        )
 753    else:
 754        # Concatenate across the z axis.
 755        features = torch.cat(features).cpu().numpy()
 756
 757    image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size}
 758    return image_embeddings
 759
 760
 761def _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update):
 762    # Check if the features are already computed.
 763    if "input_size" in f.attrs:
 764        features = f["features"]
 765        original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
 766        image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size}
 767        return image_embeddings
 768
 769    # Otherwise compute them. Note: saving happens automatically because we
 770    # always write the features to zarr. If no save path is given we use an in-memory zarr.
 771    features = _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update)
 772    image_embeddings = {"features": features, "input_size": None, "original_size": None}
 773    return image_embeddings
 774
 775
 776def _compute_data_signature(input_):
 777    data_signature = hashlib.sha1(np.asarray(input_).tobytes()).hexdigest()
 778    return data_signature
 779
 780
 781# Create all metadata that is stored along with the embeddings.
 782def _get_embedding_signature(input_, predictor, tile_shape, halo, data_signature=None):
 783    if data_signature is None:
 784        data_signature = _compute_data_signature(input_)
 785
 786    signature = {
 787        "data_signature": data_signature,
 788        "tile_shape": tile_shape if tile_shape is None else list(tile_shape),
 789        "halo": halo if halo is None else list(halo),
 790        "model_type": predictor.model_type,
 791        "model_name": predictor.model_name,
 792        "micro_sam_version": __version__,
 793        "model_hash": getattr(predictor, "_hash", None),
 794    }
 795    return signature
 796
 797
 798# Note: the input size and original size differ depending on whether the embeddings are tiled or not.
 799# That's why we do not include them in the main signature that is being checked
 800# (_get_embedding_signature), but just add it for serialization here.
 801def _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size, original_size):
 802    signature = _get_embedding_signature(input_, predictor, tile_shape, halo)
 803    signature.update({"input_size": input_size, "original_size": original_size})
 804    for key, val in signature.items():
 805        f.attrs[key] = val
 806
 807
 808def _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo):
 809    # We may have an empty zarr file that was already created to save the embeddings in.
 810    # In this case the embeddings will be computed and we don't need to perform any checks.
 811    if "input_size" not in f.attrs:
 812        return
 813
 814    signature = _get_embedding_signature(input_, predictor, tile_shape, halo)
 815    for key, val in signature.items():
 816        # Check whether the key is missing from the attrs or if the value is not matching.
 817        if key not in f.attrs or f.attrs[key] != val:
 818            # These keys were recently added, so we don't want to fail yet if they don't
 819            # match in order to not invalidate previous embedding files.
 820            # Instead we just raise a warning. (For the version we probably also don't want to fail
 821            # in the future since it should not invalidate the embeddings).
 822            if key in ("micro_sam_version", "model_hash", "model_name"):
 823                warnings.warn(
 824                    f"The signature for {key} in embeddings file {save_path} has a mismatch: "
 825                    f"{f.attrs.get(key)} != {val}. This key was recently added, so your embeddings are likely correct. "
 826                    "But please recompute them if model predictions don't look as expected."
 827                )
 828            else:
 829                raise RuntimeError(
 830                    f"Embeddings file {save_path} is invalid due to mismatch in {key}: "
 831                    f"{f.attrs.get(key)} != {val}. Please recompute embeddings in a new file."
 832                )
 833
 834
 835# Helper function for optional external progress bars.
 836def handle_pbar(verbose, pbar_init, pbar_update):
 837    """@private"""
 838
 839    # Noop to provide dummy functions.
 840    def noop(*args):
 841        pass
 842
 843    if verbose and pbar_init is None:  # we are verbose and don't have an external progress bar.
 844        assert pbar_update is None  # avoid inconsistent state of callbacks
 845
 846        # Create our own progress bar and callbacks
 847        pbar = tqdm()
 848
 849        def pbar_init(total, description):
 850            pbar.total = total
 851            pbar.set_description(description)
 852
 853        def pbar_update(update):
 854            pbar.update(update)
 855
 856        def pbar_close():
 857            pbar.close()
 858
 859    elif verbose and pbar_init is not None:  # external pbar -> we don't have to do anything
 860        assert pbar_update is not None
 861        pbar = None
 862        pbar_close = noop
 863
 864    else:  # we are not verbose, do nothing
 865        pbar = None
 866        pbar_init, pbar_update, pbar_close = noop, noop, noop
 867
 868    return pbar, pbar_init, pbar_update, pbar_close
 869
 870
 871def precompute_image_embeddings(
 872    predictor: SamPredictor,
 873    input_: np.ndarray,
 874    save_path: Optional[Union[str, os.PathLike]] = None,
 875    lazy_loading: bool = False,
 876    ndim: Optional[int] = None,
 877    tile_shape: Optional[Tuple[int, int]] = None,
 878    halo: Optional[Tuple[int, int]] = None,
 879    verbose: bool = True,
 880    pbar_init: Optional[callable] = None,
 881    pbar_update: Optional[callable] = None,
 882) -> ImageEmbeddings:
 883    """Compute the image embeddings (output of the encoder) for the input.
 884
 885    If 'save_path' is given the embeddings will be loaded/saved in a zarr container.
 886
 887    Args:
 888        predictor: The SegmentAnything predictor.
 889        input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
 890        save_path: Path to save the embeddings in a zarr container.
 891        lazy_loading: Whether to load all embeddings into memory or return an
 892            object to load them on demand when required. This only has an effect if 'save_path' is given
 893            and if the input is 3 dimensional.
 894        ndim: The dimensionality of the data. If not given will be deduced from the input data.
 895        tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
 896        halo: Overlap of the tiles for tiled prediction.
 897        verbose: Whether to be verbose in the computation.
 898        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 899            Can be used together with pbar_update to handle a napari progress bar in another thread.
 900            This enables using this function within a threadworker.
 901        pbar_update: Callback to update an external progress bar.
 902
 903    Returns:
 904        The image embeddings.
 905    """
 906    ndim = input_.ndim if ndim is None else ndim
 907
 908    # Handle the embedding save_path.
 909    # We don't have a save path, so we open an in-memory zarr file to hold tiled embeddings.
 910    if save_path is None:
 911        f = zarr.group()
 912
 913    # We have a save path and it already exists. Embeddings will be loaded from it,
 914    # check that the saved embeddings in there match the parameters of the function call.
 915    elif os.path.exists(save_path):
 916        f = zarr.open(save_path, "a")
 917        _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo)
 918
 919    # We have a save path and it does not exist yet. Create the zarr file to which the
 920    # embeddings will then be saved.
 921    else:
 922        f = zarr.open(save_path, "a")
 923
 924    _, pbar_init, pbar_update, pbar_close = handle_pbar(verbose, pbar_init, pbar_update)
 925
 926    if ndim == 2 and tile_shape is None:
 927        embeddings = _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update)
 928    elif ndim == 2 and tile_shape is not None:
 929        embeddings = _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update)
 930    elif ndim == 3 and tile_shape is None:
 931        embeddings = _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update)
 932    elif ndim == 3 and tile_shape is not None:
 933        embeddings = _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update)
 934    else:
 935        raise ValueError(f"Invalid dimensionality {input_.ndim}, expect 2 or 3 dim data.")
 936
 937    pbar_close()
 938    return embeddings
 939
 940
 941def set_precomputed(
 942    predictor: SamPredictor, image_embeddings: ImageEmbeddings, i: Optional[int] = None, tile_id: Optional[int] = None,
 943) -> SamPredictor:
 944    """Set the precomputed image embeddings for a predictor.
 945
 946    Args:
 947        predictor: The SegmentAnything predictor.
 948        image_embeddings: The precomputed image embeddings computed by `precompute_image_embeddings`.
 949        i: Index for the image data. Required if `image` has three spatial dimensions
 950            or a time dimension and two spatial dimensions.
 951        tile_id: Index for the tile. This is required if the embeddings are tiled.
 952
 953    Returns:
 954        The predictor with set features.
 955    """
 956    if tile_id is not None:
 957        tile_features = image_embeddings["features"][tile_id]
 958        tile_image_embeddings = {
 959            "features": tile_features,
 960            "input_size": tile_features.attrs["input_size"],
 961            "original_size": tile_features.attrs["original_size"]
 962        }
 963        return set_precomputed(predictor, tile_image_embeddings, i=i)
 964
 965    device = predictor.device
 966    features = image_embeddings["features"]
 967    assert features.ndim in (4, 5), f"{features.ndim}"
 968    if features.ndim == 5 and i is None:
 969        raise ValueError("The data is 3D so an index i is needed.")
 970    elif features.ndim == 4 and i is not None:
 971        raise ValueError("The data is 2D so an index is not needed.")
 972
 973    if i is None:
 974        predictor.features = features.to(device) if torch.is_tensor(features) else \
 975            torch.from_numpy(features[:]).to(device)
 976    else:
 977        predictor.features = features[i].to(device) if torch.is_tensor(features) else \
 978            torch.from_numpy(features[i]).to(device)
 979
 980    predictor.original_size = image_embeddings["original_size"]
 981    predictor.input_size = image_embeddings["input_size"]
 982    predictor.is_image_set = True
 983
 984    return predictor
 985
 986
 987#
 988# Misc functionality
 989#
 990
 991
 992def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
 993    """Compute the intersection over union of two masks.
 994
 995    Args:
 996        mask1: The first mask.
 997        mask2: The second mask.
 998
 999    Returns:
1000        The intersection over union of the two masks.
1001    """
1002    overlap = np.logical_and(mask1 == 1, mask2 == 1).sum()
1003    union = np.logical_or(mask1 == 1, mask2 == 1).sum()
1004    eps = 1e-7
1005    iou = float(overlap) / (float(union) + eps)
1006    return iou
1007
1008
1009def get_centers_and_bounding_boxes(
1010    segmentation: np.ndarray, mode: str = "v"
1011) -> Tuple[Dict[int, np.ndarray], Dict[int, tuple]]:
1012    """Returns the center coordinates of the foreground instances in the ground-truth.
1013
1014    Args:
1015        segmentation: The segmentation.
1016        mode: Determines the functionality used for computing the centers.
1017            If 'v', the object's eccentricity centers computed by vigra are used.
1018            If 'p' the object's centroids computed by skimage are used.
1019
1020    Returns:
1021        A dictionary that maps object ids to the corresponding centroid.
1022        A dictionary that maps object_ids to the corresponding bounding box.
1023    """
1024    assert mode in ["p", "v"], "Choose either 'p' for regionprops or 'v' for vigra"
1025
1026    properties = regionprops(segmentation)
1027
1028    if mode == "p":
1029        center_coordinates = {prop.label: prop.centroid for prop in properties}
1030    elif mode == "v":
1031        center_coordinates = vigra.filters.eccentricityCenters(segmentation.astype('float32'))
1032        center_coordinates = {i: coord for i, coord in enumerate(center_coordinates) if i > 0}
1033
1034    bbox_coordinates = {prop.label: prop.bbox for prop in properties}
1035
1036    assert len(bbox_coordinates) == len(center_coordinates), f"{len(bbox_coordinates)}, {len(center_coordinates)}"
1037    return center_coordinates, bbox_coordinates
1038
1039
1040def load_image_data(path: str, key: Optional[str] = None, lazy_loading: bool = False) -> np.ndarray:
1041    """Helper function to load image data from file.
1042
1043    Args:
1044        path: The filepath to the image data.
1045        key: The internal filepath for complex data formats like hdf5.
1046        lazy_loading: Whether to lazily load data. Only supported for n5 and zarr data.
1047
1048    Returns:
1049        The image data.
1050    """
1051    if key is None:
1052        image_data = imageio.imread(path)
1053    else:
1054        with open_file(path, mode="r") as f:
1055            image_data = f[key]
1056            if not lazy_loading:
1057                image_data = image_data[:]
1058
1059    return image_data
1060
1061
1062def segmentation_to_one_hot(segmentation: np.ndarray, segmentation_ids: Optional[np.ndarray] = None) -> torch.Tensor:
1063    """Convert the segmentation to one-hot encoded masks.
1064
1065    Args:
1066        segmentation: The segmentation.
1067        segmentation_ids: Optional subset of ids that will be used to subsample the masks.
1068
1069    Returns:
1070        The one-hot encoded masks.
1071    """
1072    masks = segmentation.copy()
1073    if segmentation_ids is None:
1074        n_ids = int(segmentation.max())
1075
1076    else:
1077        msg = "No foreground objects were found."
1078        if len(segmentation_ids) == 0:  # The list should not be completely empty.
1079            raise RuntimeError(msg)
1080
1081        if 0 in segmentation_ids:  # The list should not have 'zero' as a value.
1082            raise RuntimeError(msg)
1083
1084        # the segmentation ids have to be sorted
1085        segmentation_ids = np.sort(segmentation_ids)
1086
1087        # set the non selected objects to zero and relabel sequentially
1088        masks[~np.isin(masks, segmentation_ids)] = 0
1089        masks = relabel_sequential(masks)[0]
1090        n_ids = len(segmentation_ids)
1091
1092    masks = torch.from_numpy(masks)
1093
1094    one_hot_shape = (n_ids + 1,) + masks.shape
1095    masks = masks.unsqueeze(0)  # add dimension to scatter
1096    masks = torch.zeros(one_hot_shape).scatter_(0, masks, 1)[1:]
1097
1098    # add the extra singleton dimension to get shape NUM_OBJECTS x 1 x H x W
1099    masks = masks.unsqueeze(1)
1100    return masks
1101
1102
1103def get_block_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
1104    """Get a suitable block shape for chunking a given shape.
1105
1106    The primary use for this is determining chunk sizes for
1107    zarr arrays or block shapes for parallelization.
1108
1109    Args:
1110        shape: The image or volume shape.
1111
1112    Returns:
1113        The block shape.
1114    """
1115    ndim = len(shape)
1116    if ndim == 2:
1117        block_shape = tuple(min(bs, sh) for bs, sh in zip((1024, 1024), shape))
1118    elif ndim == 3:
1119        block_shape = tuple(min(bs, sh) for bs, sh in zip((32, 256, 256), shape))
1120    else:
1121        raise ValueError(f"Only 2 or 3 dimensional shapes are supported, got {ndim}D.")
1122
1123    return block_shape
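
To illustrate how the embedding helpers above fit together: embeddings can be cached in a zarr container via `save_path` and re-attached to a predictor with `set_precomputed`. A minimal sketch, where the image, the zarr file name and the tile parameters are assumed placeholders:

    import numpy as np
    from micro_sam.util import get_sam_model, precompute_image_embeddings, set_precomputed

    image = np.random.randint(0, 255, (1024, 1024), dtype="uint8")  # placeholder 2D image
    predictor = get_sam_model(model_type="vit_b_lm")

    # Embeddings are computed once and cached in "embeddings.zarr"; later calls re-load them.
    embeddings = precompute_image_embeddings(predictor, image, save_path="embeddings.zarr", ndim=2)
    set_precomputed(predictor, embeddings)

    # With tiling, pass tile_shape and halo; individual tiles are then selected via tile_id.
    # embeddings = precompute_image_embeddings(predictor, image, tile_shape=(512, 512), halo=(64, 64), ndim=2)
    # set_precomputed(predictor, embeddings, tile_id=0)
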
def get_cache_directory() -> Path:

Get micro-sam cache directory location.

Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.

def microsam_cachedir() -> Union[str, Path]:

Return the micro-sam cache directory.

Returns the top level cache directory for micro-sam models and sample data.

Every time this function is called, we check for any user updates made to the MICROSAM_CACHEDIR os environment variable since the last time.
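
For example, both cache helpers read the MICROSAM_CACHEDIR environment variable at call time; a small sketch with a hypothetical custom location:

    import os
    os.environ["MICROSAM_CACHEDIR"] = "/tmp/micro_sam_cache"  # hypothetical custom location

    from micro_sam.util import get_cache_directory, microsam_cachedir

    print(microsam_cachedir())    # /tmp/micro_sam_cache
    print(get_cache_directory())  # /tmp/micro_sam_cache (as a pathlib.Path)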

def models():

Return the segmentation models registry.

We recreate the model registry every time this function is called, so any user changes to the default micro-sam cache directory location are respected.
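
For example, the returned pooch registry can be used to inspect or pre-download models; a minimal sketch, mirroring the calls used internally:

    from micro_sam.util import models

    registry = models()
    print(sorted(registry.registry.keys()))  # all known encoder and decoder names
    # Fetch a model explicitly; this downloads it to the cache or reuses the cached copy.
    path = registry.fetch("vit_b_lm", progressbar=True)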

def get_device( device: Union[str, torch.device, NoneType] = None) -> Union[str, torch.device]:
198def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
199    """Get the torch device.
200
201    If no device is passed, the default device for your system is used.
202    Otherwise, it is checked whether the device you have passed is supported.
203
204    Args:
205        device: The input device.
206
207    Returns:
208        The device.
209    """
210    if device is None or device == "auto":
211        device = _get_default_device()
212    else:
213        device_type = device if isinstance(device, str) else device.type
214        if device_type.lower() == "cuda":
215            if not torch.cuda.is_available():
216                raise RuntimeError("PyTorch CUDA backend is not available.")
217        elif device_type.lower() == "mps":
218            if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
219                raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
220        elif device_type.lower() == "cpu":
221            pass  # cpu is always available
222        else:
223            raise RuntimeError(
224                f"Unsupported device: {device}\n"
225                "Please choose from 'cpu', 'cuda', or 'mps'."
226            )
227
228    return device

Get the torch device.

If no device is passed, the default device for your system is used. Otherwise, it is checked whether the device you have passed is supported.

Arguments:
  • device: The input device.
Returns:

The device.
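
A minimal usage sketch:

    from micro_sam.util import get_device

    device = get_device()        # resolves to "cuda", "mps" or "cpu" depending on your system
    device = get_device("cuda")  # request a backend explicitly; raises RuntimeError if it is unavailable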

def get_sam_model( model_type: str = 'vit_b_lm', device: Union[str, torch.device, NoneType] = None, checkpoint_path: Union[str, os.PathLike, NoneType] = None, return_sam: bool = False, return_state: bool = False, peft_kwargs: Optional[Dict] = None, flexible_load_checkpoint: bool = False, **model_kwargs) -> mobile_sam.predictor.SamPredictor:
290def get_sam_model(
291    model_type: str = _DEFAULT_MODEL,
292    device: Optional[Union[str, torch.device]] = None,
293    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
294    return_sam: bool = False,
295    return_state: bool = False,
296    peft_kwargs: Optional[Dict] = None,
297    flexible_load_checkpoint: bool = False,
298    **model_kwargs,
299) -> SamPredictor:
300    r"""Get the SegmentAnything Predictor.
301
302    This function will download the required model or load it from the cached weight file.
303    The location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR.
304    The name of the requested model can be set via `model_type`.
305    See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models
306    for an overview of the available models.
307
308    Alternatively this function can also load a model from weights stored in a local filepath.
309    The corresponding file path is given via `checkpoint_path`. In this case `model_type`
310    must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for
311    a SAM model with vit_b encoder.
312
313    By default the models are downloaded to a folder named 'micro_sam/models'
314    inside your default cache directory, eg:
315    * Mac: ~/Library/Caches/<AppName>
316    * Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined.
317    * Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache
318    See the pooch.os_cache() documentation for more details:
319    https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html
320
321    Args:
322        model_type: The Segment Anything model to use. Will use the default model (`vit_b_lm`) if not given.
323            To get a list of all available model names you can call `get_model_names`.
324        device: The device for the model. If none is given, the GPU will be used if available.
325        checkpoint_path: The path to a file with weights that should be used instead of using the
326            weights corresponding to `model_type`. If given, `model_type` must match the architecture
327            corresponding to the weight file. e.g. if you use weights for SAM with `vit_b` encoder
328            then `model_type` must be given as "vit_b".
329        return_sam: Return the sam model object as well as the predictor.
330        return_state: Return the unpickled checkpoint state.
331        peft_kwargs: Keyword arguments for the PEFT wrapper class.
332        flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
333        model_kwargs: Additional parameters necessary to initialize the Segment Anything model.
334
335    Returns:
336        The segment anything predictor.
337    """
338    device = get_device(device)
339
340    # We support passing a local filepath to a checkpoint.
341    # In this case we do not download any weights but just use the local weight file,
342    # as it is, without copying it anywhere or checking its hash.
343
344    # If checkpoint_path has not been passed, we download a known model and derive the correct
345    # URL from the model_type. If the model_type is invalid pooch will raise an error.
346    if checkpoint_path is None:
347        model_registry = models()
348        checkpoint_path = model_registry.fetch(model_type, progressbar=True)
349        model_hash = model_registry.registry[model_type]
350
351        # If we have a custom model then we may also have a decoder checkpoint.
352        # Download it here, so that we can add it to the state.
353        decoder_name = f"{model_type}_decoder"
354        decoder_path = model_registry.fetch(
355            decoder_name, progressbar=True
356        ) if decoder_name in model_registry.registry else None
357
358    # If checkpoint_path has been passed, we use it instead of downloading a model.
359    else:
360        # Check if the file exists and raise an error otherwise.
361        # We can't check any hashes here, and we don't check if the file is actually a valid weight file.
362        # (If it isn't the model creation will fail below.)
363        if not os.path.exists(checkpoint_path):
364            raise ValueError(f"Checkpoint at {checkpoint_path} could not be found.")
365        model_hash = _compute_hash(checkpoint_path)
366        decoder_path = None
367
368    # Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped
369    # before calling sam_model_registry.
370    abbreviated_model_type = model_type[:5]
371    if abbreviated_model_type not in _MODEL_TYPES:
372        raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}")
373    if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT:
374        raise RuntimeError(
375            "'mobile_sam' is required for the vit-tiny model. "
376            "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'"
377        )
378
379    state, model_state = _load_checkpoint(checkpoint_path)
380
381    # Whether to update parameters necessary to initialize the model
382    if model_kwargs:  # Checks whether model_kwargs have been provided or not
383        if abbreviated_model_type == "vit_t":
384            raise ValueError("'micro-sam' does not support changing the model parameters for 'mobile-sam'.")
385        sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs)
386
387    else:
388        sam = sam_model_registry[abbreviated_model_type]()
389
390    # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
391    # Overwrites the SAM model by freezing the backbone and allowing PEFT.
392    if peft_kwargs and isinstance(peft_kwargs, dict):
393        # NOTE: We bump out 'quantize' parameter, if found, as we do not quantize in inference.
394        peft_kwargs.pop("quantize", None)
395
396        if abbreviated_model_type == "vit_t":
397            raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")
398
399        sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam
400    # In case the model checkpoint has issues when it is initialized with parameters different from the defaults.
401    if flexible_load_checkpoint:
402        sam = _handle_checkpoint_loading(sam, model_state)
403    else:
404        sam.load_state_dict(model_state)
405    sam.to(device=device)
406
407    predictor = SamPredictor(sam)
408    predictor.model_type = abbreviated_model_type
409    predictor._hash = model_hash
410    predictor.model_name = model_type
411
412    # Add the decoder to the state if we have one and if the state is returned.
413    if decoder_path is not None and return_state:
414        state["decoder_state"] = torch.load(decoder_path, map_location=device, weights_only=False)
415
416    if return_sam and return_state:
417        return predictor, sam, state
418    if return_sam:
419        return predictor, sam
420    if return_state:
421        return predictor, state
422    return predictor

Get the SegmentAnything Predictor.

This function will download the required model or load it from the cached weight file. The location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR. The name of the requested model can be set via model_type. See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models for an overview of the available models.

Alternatively this function can also load a model from weights stored in a local filepath. The corresponding file path is given via checkpoint_path. In this case model_type must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for a SAM model with vit_b encoder.

By default the models are downloaded to a folder named 'micro_sam/models' inside your default cache directory, e.g.:
  • Mac: ~/Library/Caches/<AppName>
  • Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined.
  • Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache
See the pooch.os_cache() documentation for more details: https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html

Arguments:
  • model_type: The Segment Anything model to use. Will use the default model (vit_b_lm) if not given. To get a list of all available model names you can call get_model_names.
  • device: The device for the model. If none is given, the GPU will be used if available.
  • checkpoint_path: The path to a file with weights that should be used instead of using the weights corresponding to model_type. If given, model_type must match the architecture corresponding to the weight file. e.g. if you use weights for SAM with vit_b encoder then model_type must be given as "vit_b".
  • return_sam: Return the sam model object as well as the predictor.
  • return_state: Return the unpickled checkpoint state.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class.
  • flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
  • model_kwargs: Additional parameters necessary to initialize the Segment Anything model.
Returns:

The segment anything predictor.
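
A minimal usage sketch; the checkpoint path in the second call is a placeholder for your own weight file:

    from micro_sam.util import get_sam_model

    # Download the default model (or load it from the cache) and wrap it in a predictor.
    predictor = get_sam_model(model_type="vit_b_lm")

    # Load local weights instead; model_type must name the matching encoder architecture.
    predictor = get_sam_model(model_type="vit_b", checkpoint_path="finetuned_vit_b.pt")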

def export_custom_sam_model( checkpoint_path: Union[str, os.PathLike], model_type: str, save_path: Union[str, os.PathLike], with_segmentation_decoder: bool = False) -> None:
458def export_custom_sam_model(
459    checkpoint_path: Union[str, os.PathLike],
460    model_type: str,
461    save_path: Union[str, os.PathLike],
462    with_segmentation_decoder: bool = False,
463) -> None:
464    """Export a finetuned Segment Anything Model to the standard model format.
465
466    The exported model can be used by the interactive annotation tools in `micro_sam.annotator`.
467
468    Args:
469        checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
470        model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
471        save_path: Where to save the exported model.
472        with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well.
473            If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'.
474    """
475    _, state = get_sam_model(
476        model_type=model_type, checkpoint_path=checkpoint_path, return_state=True, device="cpu",
477    )
478    model_state = state["model_state"]
479    prefix = "sam."
480    model_state = OrderedDict(
481        [(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()]
482    )
483
484    # Store the 'decoder_state' as well, if desired.
485    if with_segmentation_decoder:
486        if "decoder_state" not in state:
487            raise RuntimeError(f"'decoder_state' is not found in the model at '{checkpoint_path}'.")
488        decoder_state = state["decoder_state"]
489        save_state = {"model_state": model_state, "decoder_state": decoder_state}
490    else:
491        save_state = model_state
492
493    torch.save(save_state, save_path)

Export a finetuned Segment Anything Model to the standard model format.

The exported model can be used by the interactive annotation tools in micro_sam.annotator.

Arguments:
  • checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
  • model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
  • save_path: Where to save the exported model.
  • with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well. If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'.
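
A minimal sketch with placeholder paths; with_segmentation_decoder=True only works if the checkpoint state contains a 'decoder_state':

    from micro_sam.util import export_custom_sam_model

    export_custom_sam_model(
        checkpoint_path="checkpoints/vit_b/best.pt",  # placeholder: your training checkpoint
        model_type="vit_b",
        save_path="exported_vit_b.pth",
        with_segmentation_decoder=True,
    )
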
def export_custom_qlora_model( checkpoint_path: Union[str, os.PathLike], finetuned_path: Union[str, os.PathLike], model_type: str, save_path: Union[str, os.PathLike]) -> None:
496def export_custom_qlora_model(
497    checkpoint_path: Union[str, os.PathLike],
498    finetuned_path: Union[str, os.PathLike],
499    model_type: str,
500    save_path: Union[str, os.PathLike],
501) -> None:
502    """Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format.
503
504    The exported model can be used with the LoRA backbone by passing the relevant `peft_kwargs` to `get_sam_model`.
505
506    Args:
507        checkpoint_path: The path to the base foundation model from which the new model has been finetuned.
508        finetuned_path: The path to the new finetuned model, using QLoRA.
509        model_type: The Segment Anything Model type corresponding to the checkpoint.
510        save_path: Where to save the exported model.
511    """
512    # Step 1: Get the base SAM model: used to start finetuning from.
513    _, sam = get_sam_model(
514        model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True,
515    )
516
517    # Step 2: Load the QLoRA-style finetuned model.
518    ft_state, ft_model_state = _load_checkpoint(finetuned_path)
519
520    # Step 3: Get LoRA weights from QLoRA and retain all original parameters from the base SAM model.
521    updated_model_state = {}
522
523    # - At first, we get all LoRA layers from the QLoRA-style finetuned model checkpoint.
524    for k, v in ft_model_state.items():
525        if k.find("w_b_linear") != -1 or k.find("w_a_linear") != -1:
526            updated_model_state[k] = v
527
528    # - Next, we get all the remaining parameters from the base SAM model.
529    for k, v in sam.state_dict().items():
530        if k.find("attn.qkv.") != -1:
531            k = k.replace("qkv", "qkv.qkv_proj")
532            updated_model_state[k] = v
533        else:
534
535            updated_model_state[k] = v
536
537    # - Finally, we replace the old model state with the new one (to retain the other relevant state entries).
538    ft_state['model_state'] = updated_model_state
539
540    # Step 4: Store the new "state" to "save_path"
541    torch.save(ft_state, save_path)

Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format.

The exported model can be used with the LoRA backbone by passing the relevant peft_kwargs to get_sam_model.

Arguments:
  • checkpoint_path: The path to the base foundation model from which the new model has been finetuned.
  • finetuned_path: The path to the new finetuned model, using QLoRA.
  • model_type: The Segment Anything Model type corresponding to the checkpoint.
  • save_path: Where to save the exported model.
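
A minimal sketch with placeholder paths; the peft_kwargs mentioned in the comment are illustrative and must match the PEFT configuration used during finetuning:

    from micro_sam.util import export_custom_qlora_model

    export_custom_qlora_model(
        checkpoint_path="vit_b_base.pt",  # placeholder: base model the finetuning started from
        finetuned_path="qlora_vit_b.pt",  # placeholder: QLoRA-style finetuned checkpoint
        model_type="vit_b",
        save_path="lora_vit_b.pt",
    )
    # The exported checkpoint can then be loaded via get_sam_model with the matching peft_kwargs,
    # e.g. get_sam_model(model_type="vit_b", checkpoint_path="lora_vit_b.pt", peft_kwargs={"rank": 4}).
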
def get_model_names() -> Iterable:
544def get_model_names() -> Iterable:
545    model_registry = models()
546    model_names = model_registry.registry.keys()
547    return model_names
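
Returns the names of all models and decoders in the registry. A minimal sketch:

    from micro_sam.util import get_model_names

    for name in get_model_names():
        print(name)  # e.g. "vit_b", "vit_b_lm", "vit_b_lm_decoder", ...
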
def precompute_image_embeddings( predictor: mobile_sam.predictor.SamPredictor, input_: numpy.ndarray, save_path: Union[str, os.PathLike, NoneType] = None, lazy_loading: bool = False, ndim: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = True, pbar_init: Optional[callable] = None, pbar_update: Optional[callable] = None) -> Dict[str, Any]:
872def precompute_image_embeddings(
873    predictor: SamPredictor,
874    input_: np.ndarray,
875    save_path: Optional[Union[str, os.PathLike]] = None,
876    lazy_loading: bool = False,
877    ndim: Optional[int] = None,
878    tile_shape: Optional[Tuple[int, int]] = None,
879    halo: Optional[Tuple[int, int]] = None,
880    verbose: bool = True,
881    pbar_init: Optional[callable] = None,
882    pbar_update: Optional[callable] = None,
883) -> ImageEmbeddings:
884    """Compute the image embeddings (output of the encoder) for the input.
885
886    If 'save_path' is given the embeddings will be loaded/saved in a zarr container.
887
888    Args:
889        predictor: The SegmentAnything predictor.
890        input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
891        save_path: Path to save the embeddings in a zarr container.
892        lazy_loading: Whether to load all embeddings into memory or return an
893            object to load them on demand when required. This only has an effect if 'save_path' is given
894            and if the input is 3 dimensional.
895        ndim: The dimensionality of the data. If not given will be deduced from the input data.
896        tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
897        halo: Overlap of the tiles for tiled prediction.
898        verbose: Whether to be verbose in the computation.
899        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
900            Can be used together with pbar_update to handle the napari progress bar in another thread.
901            This enables using this function within a thread worker.
902        pbar_update: Callback to update an external progress bar.
903
904    Returns:
905        The image embeddings.
906    """
907    ndim = input_.ndim if ndim is None else ndim
908
909    # Handle the embedding save_path.
910    # We don't have a save path, so we open an in-memory zarr file to hold the tiled embeddings.
911    if save_path is None:
912        f = zarr.group()
913
914    # We have a save path and it already exists. Embeddings will be loaded from it,
915    # check that the saved embeddings in there match the parameters of the function call.
916    elif os.path.exists(save_path):
917        f = zarr.open(save_path, "a")
918        _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo)
919
920    # We have a save path and it does not exist yet. Create the zarr file to which the
921    # embeddings will then be saved.
922    else:
923        f = zarr.open(save_path, "a")
924
925    _, pbar_init, pbar_update, pbar_close = handle_pbar(verbose, pbar_init, pbar_update)
926
927    if ndim == 2 and tile_shape is None:
928        embeddings = _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update)
929    elif ndim == 2 and tile_shape is not None:
930        embeddings = _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update)
931    elif ndim == 3 and tile_shape is None:
932        embeddings = _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update)
933    elif ndim == 3 and tile_shape is not None:
934        embeddings = _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update)
935    else:
936        raise ValueError(f"Invalid dimensionality {input_.ndim}, expect 2 or 3 dim data.")
937
938    pbar_close()
939    return embeddings

Compute the image embeddings (output of the encoder) for the input.

If 'save_path' is given the embeddings will be loaded/saved in a zarr container.

Arguments:
  • predictor: The SegmentAnything predictor.
  • input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
  • save_path: Path to save the embeddings in a zarr container.
  • lazy_loading: Whether to load all embeddings into memory or return an object to load them on demand when required. This only has an effect if 'save_path' is given and if the input is 3 dimensional.
  • ndim: The dimensionality of the data. If not given will be deduced from the input data.
  • tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
  • halo: Overlap of the tiles for tiled prediction.
  • verbose: Whether to be verbose in the computation.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle the napari progress bar in another thread. This enables using this function within a thread worker.
  • pbar_update: Callback to update an external progress bar.
Returns:

The image embeddings.
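
A minimal usage sketch with random data standing in for a real image; 'embeddings.zarr' is a placeholder path:

    import numpy as np
    from micro_sam.util import get_sam_model, precompute_image_embeddings

    predictor = get_sam_model(model_type="vit_b_lm")
    image = np.random.randint(0, 255, size=(512, 512), dtype="uint8")

    # Cache the embeddings in a zarr container, so repeated runs can skip the encoder pass.
    embeddings = precompute_image_embeddings(predictor, image, save_path="embeddings.zarr", ndim=2)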

def set_precomputed( predictor: mobile_sam.predictor.SamPredictor, image_embeddings: Dict[str, Any], i: Optional[int] = None, tile_id: Optional[int] = None) -> mobile_sam.predictor.SamPredictor:
942def set_precomputed(
943    predictor: SamPredictor, image_embeddings: ImageEmbeddings, i: Optional[int] = None, tile_id: Optional[int] = None,
944) -> SamPredictor:
945    """Set the precomputed image embeddings for a predictor.
946
947    Args:
948        predictor: The SegmentAnything predictor.
949        image_embeddings: The precomputed image embeddings computed by `precompute_image_embeddings`.
950        i: Index for the image data. Required if `input_` has three spatial dimensions
951            or a time dimension and two spatial dimensions.
952        tile_id: Index for the tile. This is required if the embeddings are tiled.
953
954    Returns:
955        The predictor with set features.
956    """
957    if tile_id is not None:
958        tile_features = image_embeddings["features"][tile_id]
959        tile_image_embeddings = {
960            "features": tile_features,
961            "input_size": tile_features.attrs["input_size"],
962            "original_size": tile_features.attrs["original_size"]
963        }
964        return set_precomputed(predictor, tile_image_embeddings, i=i)
965
966    device = predictor.device
967    features = image_embeddings["features"]
968    assert features.ndim in (4, 5), f"{features.ndim}"
969    if features.ndim == 5 and i is None:
970        raise ValueError("The data is 3D so an index i is needed.")
971    elif features.ndim == 4 and i is not None:
972        raise ValueError("The data is 2D so an index is not needed.")
973
974    if i is None:
975        predictor.features = features.to(device) if torch.is_tensor(features) else \
976            torch.from_numpy(features[:]).to(device)
977    else:
978        predictor.features = features[i].to(device) if torch.is_tensor(features) else \
979            torch.from_numpy(features[i]).to(device)
980
981    predictor.original_size = image_embeddings["original_size"]
982    predictor.input_size = image_embeddings["input_size"]
983    predictor.is_image_set = True
984
985    return predictor

Set the precomputed image embeddings for a predictor.

Arguments:
  • predictor: The SegmentAnything predictor.
  • image_embeddings: The precomputed image embeddings computed by precompute_image_embeddings.
  • i: Index for the image data. Required if input_ has three spatial dimensions or a time dimension and two spatial dimensions.
  • tile_id: Index for the tile. This is required if the embeddings are tiled.
Returns:

The predictor with set features.
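
A minimal sketch for the 2D, untiled case, with random data standing in for a real image:

    import numpy as np
    from micro_sam.util import get_sam_model, precompute_image_embeddings, set_precomputed

    predictor = get_sam_model(model_type="vit_b_lm")
    image = np.random.randint(0, 255, size=(512, 512), dtype="uint8")
    embeddings = precompute_image_embeddings(predictor, image, ndim=2)

    # Attach the (possibly cached) embeddings to the predictor before prompting it.
    predictor = set_precomputed(predictor, embeddings)
    # For a volume or timeseries pass the (z or t) index, e.g. set_precomputed(predictor, embeddings, i=0).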

def compute_iou(mask1: numpy.ndarray, mask2: numpy.ndarray) -> float:
 993def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
 994    """Compute the intersection over union of two masks.
 995
 996    Args:
 997        mask1: The first mask.
 998        mask2: The second mask.
 999
1000    Returns:
1001        The intersection over union of the two masks.
1002    """
1003    overlap = np.logical_and(mask1 == 1, mask2 == 1).sum()
1004    union = np.logical_or(mask1 == 1, mask2 == 1).sum()
1005    eps = 1e-7
1006    iou = float(overlap) / (float(union) + eps)
1007    return iou

Compute the intersection over union of two masks.

Arguments:
  • mask1: The first mask.
  • mask2: The second mask.
Returns:

The intersection over union of the two masks.
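
A quick sketch on two toy binary masks:

    import numpy as np
    from micro_sam.util import compute_iou

    mask1 = np.zeros((4, 4), dtype="uint8")
    mask2 = np.zeros((4, 4), dtype="uint8")
    mask1[:2, :2] = 1  # 4 foreground pixels
    mask2[:2, :] = 1   # 8 foreground pixels, 4 of them shared with mask1
    print(compute_iou(mask1, mask2))  # ~0.5, i.e. 4 / 8 (up to the small epsilon)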

def get_centers_and_bounding_boxes( segmentation: numpy.ndarray, mode: str = 'v') -> Tuple[Dict[int, numpy.ndarray], Dict[int, tuple]]:
1010def get_centers_and_bounding_boxes(
1011    segmentation: np.ndarray, mode: str = "v"
1012) -> Tuple[Dict[int, np.ndarray], Dict[int, tuple]]:
1013    """Returns the center coordinates and bounding boxes of the foreground instances in the ground-truth.
1014
1015    Args:
1016        segmentation: The segmentation.
1017        mode: Determines the functionality used for computing the centers.
1018            If 'v', the object's eccentricity centers computed by vigra are used.
1019            If 'p' the object's centroids computed by skimage are used.
1020
1021    Returns:
1022        A dictionary that maps object ids to the corresponding centroid.
1023        A dictionary that maps object_ids to the corresponding bounding box.
1024    """
1025    assert mode in ["p", "v"], "Choose either 'p' for regionprops or 'v' for vigra"
1026
1027    properties = regionprops(segmentation)
1028
1029    if mode == "p":
1030        center_coordinates = {prop.label: prop.centroid for prop in properties}
1031    elif mode == "v":
1032        center_coordinates = vigra.filters.eccentricityCenters(segmentation.astype('float32'))
1033        center_coordinates = {i: coord for i, coord in enumerate(center_coordinates) if i > 0}
1034
1035    bbox_coordinates = {prop.label: prop.bbox for prop in properties}
1036
1037    assert len(bbox_coordinates) == len(center_coordinates), f"{len(bbox_coordinates)}, {len(center_coordinates)}"
1038    return center_coordinates, bbox_coordinates

Returns the center coordinates and bounding boxes of the foreground instances in the ground-truth.

Arguments:
  • segmentation: The segmentation.
  • mode: Determines the functionality used for computing the centers. If 'v', the object's eccentricity centers computed by vigra are used. If 'p' the object's centroids computed by skimage are used.
Returns:

A dictionary that maps object ids to the corresponding centroid. A dictionary that maps object_ids to the corresponding bounding box.
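
A quick sketch on a toy label image with two objects:

    import numpy as np
    from micro_sam.util import get_centers_and_bounding_boxes

    segmentation = np.zeros((10, 10), dtype="int32")
    segmentation[1:4, 1:4] = 1
    segmentation[6:9, 5:9] = 2

    centers, boxes = get_centers_and_bounding_boxes(segmentation, mode="p")
    print(centers[1])  # centroid of object 1, here (2.0, 2.0)
    print(boxes[1])    # bounding box as (min_row, min_col, max_row, max_col), here (1, 1, 4, 4)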

def load_image_data( path: str, key: Optional[str] = None, lazy_loading: bool = False) -> numpy.ndarray:
1041def load_image_data(path: str, key: Optional[str] = None, lazy_loading: bool = False) -> np.ndarray:
1042    """Helper function to load image data from file.
1043
1044    Args:
1045        path: The filepath to the image data.
1046        key: The internal filepath for complex data formats like hdf5.
1047        lazy_loading: Whether to lazily load data. Only supported for n5 and zarr data.
1048
1049    Returns:
1050        The image data.
1051    """
1052    if key is None:
1053        image_data = imageio.imread(path)
1054    else:
1055        with open_file(path, mode="r") as f:
1056            image_data = f[key]
1057            if not lazy_loading:
1058                image_data = image_data[:]
1059
1060    return image_data

Helper function to load image data from file.

Arguments:
  • path: The filepath to the image data.
  • key: The internal filepath for complex data formats like hdf5.
  • lazy_loading: Whether to lazily load data. Only supported for n5 and zarr data.
Returns:

The image data.
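
A minimal sketch with placeholder file paths:

    from micro_sam.util import load_image_data

    image = load_image_data("image.tif")              # plain image formats are read via imageio
    volume = load_image_data("volume.h5", key="raw")  # container formats (hdf5, n5, zarr) via elf.io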

def segmentation_to_one_hot( segmentation: numpy.ndarray, segmentation_ids: Optional[numpy.ndarray] = None) -> torch.Tensor:
1063def segmentation_to_one_hot(segmentation: np.ndarray, segmentation_ids: Optional[np.ndarray] = None) -> torch.Tensor:
1064    """Convert the segmentation to one-hot encoded masks.
1065
1066    Args:
1067        segmentation: The segmentation.
1068        segmentation_ids: Optional subset of ids that will be used to subsample the masks.
1069
1070    Returns:
1071        The one-hot encoded masks.
1072    """
1073    masks = segmentation.copy()
1074    if segmentation_ids is None:
1075        n_ids = int(segmentation.max())
1076
1077    else:
1078        msg = "No foreground objects were found."
1079        if len(segmentation_ids) == 0:  # The list should not be completely empty.
1080            raise RuntimeError(msg)
1081
1082        if 0 in segmentation_ids:  # The list should not have 'zero' as a value.
1083            raise RuntimeError(msg)
1084
1085        # the segmentation ids have to be sorted
1086        segmentation_ids = np.sort(segmentation_ids)
1087
1088        # set the non selected objects to zero and relabel sequentially
1089        masks[~np.isin(masks, segmentation_ids)] = 0
1090        masks = relabel_sequential(masks)[0]
1091        n_ids = len(segmentation_ids)
1092
1093    masks = torch.from_numpy(masks)
1094
1095    one_hot_shape = (n_ids + 1,) + masks.shape
1096    masks = masks.unsqueeze(0)  # add dimension to scatter
1097    masks = torch.zeros(one_hot_shape).scatter_(0, masks, 1)[1:]
1098
1099    # add the extra singleton dimension to get shape NUM_OBJECTS x 1 x H x W
1100    masks = masks.unsqueeze(1)
1101    return masks

Convert the segmentation to one-hot encoded masks.

Arguments:
  • segmentation: The segmentation.
  • segmentation_ids: Optional subset of ids that will be used to subsample the masks.
Returns:

The one-hot encoded masks.
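
A quick sketch converting a toy label image with two objects:

    import numpy as np
    from micro_sam.util import segmentation_to_one_hot

    segmentation = np.zeros((8, 8), dtype="int64")
    segmentation[:3, :3] = 1
    segmentation[5:, 5:] = 2

    masks = segmentation_to_one_hot(segmentation)
    print(masks.shape)  # torch.Size([2, 1, 8, 8]), i.e. NUM_OBJECTS x 1 x H x W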

def get_block_shape(shape: Tuple[int]) -> Tuple[int]:
1104def get_block_shape(shape: Tuple[int]) -> Tuple[int]:
1105    """Get a suitable block shape for chunking a given shape.
1106
1107    The primary use for this is determining chunk sizes for
1108    zarr arrays or block shapes for parallelization.
1109
1110    Args:
1111        shape: The image or volume shape.
1112
1113    Returns:
1114        The block shape.
1115    """
1116    ndim = len(shape)
1117    if ndim == 2:
1118        block_shape = tuple(min(bs, sh) for bs, sh in zip((1024, 1024), shape))
1119    elif ndim == 3:
1120        block_shape = tuple(min(bs, sh) for bs, sh in zip((32, 256, 256), shape))
1121    else:
1122        raise ValueError(f"Only 2 or 3 dimensional shapes are supported, got {ndim}D.")
1123
1124    return block_shape

Get a suitable block shape for chunking a given shape.

The primary use for this is determining chunk sizes for zarr arrays or block shapes for parallelization.

Arguments:
  • shape: The image or volume shape.
Returns:

The block shape.
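
A minimal sketch:

    from micro_sam.util import get_block_shape

    print(get_block_shape((768, 2048)))     # (768, 1024)
    print(get_block_shape((16, 512, 512)))  # (16, 256, 256)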