micro_sam.util

Helper functions for downloading Segment Anything models and predicting image embeddings.
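
A minimal usage sketch (the image path, the zarr save path, and the chosen model name are placeholders): load an image, create a predictor with `get_sam_model`, and precompute its embeddings with `precompute_image_embeddings`.

    import imageio.v3 as imageio

    from micro_sam.util import get_sam_model, precompute_image_embeddings

    # Load the input image (placeholder path).
    image = imageio.imread("cells.tif")

    # Download the model weights (or load them from the cache) and create the predictor.
    predictor = get_sam_model(model_type="vit_b_lm")

    # Compute the image embeddings; passing save_path caches them in a zarr container.
    image_embeddings = precompute_image_embeddings(predictor, image, save_path="embeddings.zarr")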

   1"""
   2Helper functions for downloading Segment Anything models and predicting image embeddings.
   3"""
   4
   5
   6import os
   7import pickle
   8import hashlib
   9import warnings
  10from pathlib import Path
  11from collections import OrderedDict
  12from typing import Any, Dict, Iterable, Optional, Tuple, Union
  13
  14import zarr
  15import vigra
  16import torch
  17import pooch
  18import xxhash
  19import numpy as np
  20import imageio.v3 as imageio
  21from skimage.measure import regionprops
  22from skimage.segmentation import relabel_sequential
  23
  24from elf.io import open_file
  25
  26from nifty.tools import blocking
  27
  28from .__version__ import __version__
  29from . import models as custom_models
  30
  31try:
  32    # Avoid import warnings from mobile_sam
  33    with warnings.catch_warnings():
  34        warnings.simplefilter("ignore")
  35        from mobile_sam import sam_model_registry, SamPredictor
  36    VIT_T_SUPPORT = True
  37except ImportError:
  38    from segment_anything import sam_model_registry, SamPredictor
  39    VIT_T_SUPPORT = False
  40
  41try:
  42    from napari.utils import progress as tqdm
  43except ImportError:
  44    from tqdm import tqdm
  45
  46# this is the default model used in micro_sam
  47# currently set to the default vit_l
  48_DEFAULT_MODEL = "vit_l"
  49
  50# The valid model types. Each type corresponds to the architecture of the
  51# vision transformer used within SAM.
  52_MODEL_TYPES = ("vit_l", "vit_b", "vit_h", "vit_t")
  53
  54
  55# TODO define the proper type for image embeddings
  56ImageEmbeddings = Dict[str, Any]
  57"""@private"""
  58
  59
  60def get_cache_directory() -> Path:
  61    """Get micro-sam cache directory location.
  62
  63    Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.
  64    """
  65    default_cache_directory = os.path.expanduser(pooch.os_cache("micro_sam"))
  66    cache_directory = Path(os.environ.get("MICROSAM_CACHEDIR", default_cache_directory))
  67    return cache_directory
  68
  69
  70#
  71# Functionality for model download and export
  72#
  73
  74
  75def microsam_cachedir() -> Union[str, Path]:
  76    """Return the micro-sam cache directory.
  77
  78    Returns the top level cache directory for micro-sam models and sample data.
  79
  80    Every time this function is called, we check for any user updates made to
  81    the MICROSAM_CACHEDIR os environment variable since the last time.
  82    """
  83    cache_directory = os.environ.get("MICROSAM_CACHEDIR") or pooch.os_cache("micro_sam")
  84    return cache_directory
  85
  86
  87def models():
  88    """Return the segmentation models registry.
  89
  90    We recreate the model registry every time this function is called,
  91    so any user changes to the default micro-sam cache directory location
  92    are respected.
  93    """
  94
  95    # We use xxhash to compute the hash of the models, see
  96    # https://github.com/computational-cell-analytics/micro-sam/issues/283
  97    # (It is now a dependency, so we don't provide the sha256 fallback anymore.)
  98    # To generate the xxh128 hash:
  99    #     xxh128sum filename
 100    encoder_registry = {
 101        # The default segment anything models:
 102        "vit_l": "xxh128:a82beb3c660661e3dd38d999cc860e9a",
 103        "vit_h": "xxh128:97698fac30bd929c2e6d8d8cc15933c2",
 104        "vit_b": "xxh128:6923c33df3637b6a922d7682bfc9a86b",
 105        # The model with vit tiny backend from https://github.com/ChaoningZhang/MobileSAM.
 106        "vit_t": "xxh128:8eadbc88aeb9d8c7e0b4b60c3db48bd0",
 107        # The current version of our models in the modelzoo.
 108        # LM generalist models:
 109        "vit_l_lm": "xxh128:ad3afe783b0d05a788eaf3cc24b308d2",
 110        "vit_b_lm": "xxh128:61ce01ea731d89ae41a252480368f886",
 111        "vit_t_lm": "xxh128:f90e2ba3dd3d5b935aa870cf2e48f689",
 112        # EM models:
 113        "vit_l_em_organelles": "xxh128:096c9695966803ca6fde24f4c1e3c3fb",
 114        "vit_b_em_organelles": "xxh128:f6f6593aeecd0e15a07bdac86360b6cc",
 115        "vit_t_em_organelles": "xxh128:253474720c497cce605e57c9b1d18fd9",
 116    }
 117    # Additional decoders for instance segmentation.
 118    decoder_registry = {
 119        # LM generalist models:
 120        "vit_l_lm_decoder": "xxh128:40c1ae378cfdce24008b9be24889a5b1",
 121        "vit_b_lm_decoder": "xxh128:1bac305195777ba7375634ca15a3c370",
 122        "vit_t_lm_decoder": "xxh128:82d3604e64f289bb66ec46a5643da169",
 123        # EM models:
 124        "vit_l_em_organelles_decoder": "xxh128:d60fd96bd6060856f6430f29e42568fb",
 125        "vit_b_em_organelles_decoder": "xxh128:b2d4dcffb99f76d83497d39ee500088f",
 126        "vit_t_em_organelles_decoder": "xxh128:8f897c7bb93174a4d1638827c4dd6f44",
 127    }
 128    registry = {**encoder_registry, **decoder_registry}
 129
 130    encoder_urls = {
 131        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
 132        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
 133        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
 134        "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download",
 135        "vit_l_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1/files/vit_l.pt",
 136        "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1/files/vit_b.pt",
 137        "vit_t_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1/files/vit_t.pt",
 138        "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l.pt",  # noqa
 139        "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b.pt",
 140        "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt",  # noqa
 141    }
 142
 143    decoder_urls = {
 144        "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1/files/vit_l_decoder.pt",  # noqa
 145        "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1/files/vit_b_decoder.pt",  # noqa
 146        "vit_t_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1/files/vit_t_decoder.pt",  # noqa
 147        "vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l_decoder.pt",  # noqa
 148        "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b_decoder.pt",  # noqa
 149        "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt",  # noqa
 150    }
 151    urls = {**encoder_urls, **decoder_urls}
 152
 153    models = pooch.create(
 154        path=os.path.join(microsam_cachedir(), "models"),
 155        base_url="",
 156        registry=registry,
 157        urls=urls,
 158    )
 159    return models
 160
 161
 162def _get_default_device():
 163    # check that we're in CI and use the CPU if we are
 164    # otherwise the tests may run out of memory on MAC if MPS is used.
 165    if os.getenv("GITHUB_ACTIONS") == "true":
 166        return "cpu"
 167    # Use cuda enabled gpu if it's available.
 168    if torch.cuda.is_available():
 169        device = "cuda"
 170    # As second priority use mps.
 171    # See https://pytorch.org/docs/stable/notes/mps.html for details
 172    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
 173        print("Using apple MPS device.")
 174        device = "mps"
 175    # Use the CPU as fallback.
 176    else:
 177        device = "cpu"
 178    return device
 179
 180
 181def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
 182    """Get the torch device.
 183
 184    If no device is passed the default device for your system is used.
 185    Else it will be checked if the device you have passed is supported.
 186
 187    Args:
 188        device: The input device.
 189
 190    Returns:
 191        The device.
 192    """
 193    if device is None or device == "auto":
 194        device = _get_default_device()
 195    else:
 196        device_type = device if isinstance(device, str) else device.type
 197        if device_type.lower() == "cuda":
 198            if not torch.cuda.is_available():
 199                raise RuntimeError("PyTorch CUDA backend is not available.")
 200        elif device_type.lower() == "mps":
 201            if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
 202                raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
 203        elif device_type.lower() == "cpu":
 204            pass  # cpu is always available
 205        else:
 206            raise RuntimeError(f"Unsupported device: {device}\n"
 207                               "Please choose from 'cpu', 'cuda', or 'mps'.")
 208    return device
 209
 210
 211def _available_devices():
 212    available_devices = []
 213    for i in ["cuda", "mps", "cpu"]:
 214        try:
 215            device = get_device(i)
 216        except RuntimeError:
 217            pass
 218        else:
 219            available_devices.append(device)
 220    return available_devices
 221
 222
 223# We write a custom unpickler that skips objects that cannot be found instead of
 224# throwing an AttributeError or ModuleNotFoundError.
 225# NOTE: since we just want to unpickle the model to load its weights these errors don't matter.
 226# See also https://stackoverflow.com/questions/27732354/unable-to-load-files-using-pickle-and-multiple-modules
 227class _CustomUnpickler(pickle.Unpickler):
 228    def find_class(self, module, name):
 229        try:
 230            return super().find_class(module, name)
 231        except (AttributeError, ModuleNotFoundError) as e:
 232            warnings.warn(f"Did not find {module}:{name} and will skip it, due to error {e}")
 233            return None
 234
 235
 236def _compute_hash(path, chunk_size=8192):
 237    hash_obj = xxhash.xxh128()
 238    with open(path, "rb") as f:
 239        chunk = f.read(chunk_size)
 240        while chunk:
 241            hash_obj.update(chunk)
 242            chunk = f.read(chunk_size)
 243    hash_val = hash_obj.hexdigest()
 244    return f"xxh128:{hash_val}"
 245
 246
 247# Load the state from a checkpoint.
 248# The checkpoint can either contain a sam encoder state
 249# or it can be a checkpoint for model finetuning.
 250def _load_checkpoint(checkpoint_path):
 251    # Over-ride the unpickler with our custom one.
 252    # This enables imports from torch_em checkpoints even if it cannot be fully unpickled.
 253    custom_pickle = pickle
 254    custom_pickle.Unpickler = _CustomUnpickler
 255
 256    state = torch.load(checkpoint_path, map_location="cpu", pickle_module=custom_pickle)
 257    if "model_state" in state:
 258        # Copy the model weights from torch_em's training format.
 259        model_state = state["model_state"]
 260        sam_prefix = "sam."
 261        model_state = OrderedDict(
 262            [(k[len(sam_prefix):] if k.startswith(sam_prefix) else k, v) for k, v in model_state.items()]
 263        )
 264    else:
 265        model_state = state
 266
 267    return state, model_state
 268
 269
 270def get_sam_model(
 271    model_type: str = _DEFAULT_MODEL,
 272    device: Optional[Union[str, torch.device]] = None,
 273    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
 274    return_sam: bool = False,
 275    return_state: bool = False,
 276    peft_kwargs: Optional[Dict] = None,
 277    flexible_load_checkpoint: bool = False,
 278    **model_kwargs,
 279) -> SamPredictor:
 280    r"""Get the SegmentAnything Predictor.
 281
 282    This function will download the required model or load it from the cached weight file.
 283    The location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR.
 284    The name of the requested model can be set via `model_type`.
 285    See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models
 286    for an overview of the available models.
 287
 288    Alternatively this function can also load a model from weights stored in a local filepath.
 289    The corresponding file path is given via `checkpoint_path`. In this case `model_type`
 290    must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for
 291    a SAM model with vit_b encoder.
 292
 293    By default the models are downloaded to a folder named 'micro_sam/models'
 294    inside your default cache directory, e.g.:
 295    * Mac: ~/Library/Caches/<AppName>
 296    * Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined.
 297    * Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache
 298    See the pooch.os_cache() documentation for more details:
 299    https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html
 300
 301    Args:
 302        model_type: The SegmentAnything model to use. Will use the standard vit_l model by default.
 303            To get a list of all available model names you can call `get_model_names`.
 304        device: The device for the model. If none is given will use GPU if available.
 305        checkpoint_path: The path to a file with weights that should be used instead of using the
 306            weights corresponding to `model_type`. If given, `model_type` must match the architecture
 307            corresponding to the weight file. E.g. if you use weights for SAM with vit_b encoder
 308            then `model_type` must be given as "vit_b".
 309        return_sam: Return the sam model object as well as the predictor.
 310        return_state: Return the unpickled checkpoint state.
 311        peft_kwargs: Keyword arguments for the PEFT wrapper class.
 312        flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
 313
 314    Returns:
 315        The segment anything predictor.
 316    """
 317    device = get_device(device)
 318
 319    # We support passing a local filepath to a checkpoint.
 320    # In this case we do not download any weights but just use the local weight file,
 321    # as it is, without copying it over anywhere or checking its hash.
 322
 323    # checkpoint_path has not been passed, we download a known model and derive the correct
 324    # URL from the model_type. If the model_type is invalid pooch will raise an error.
 325    if checkpoint_path is None:
 326        model_registry = models()
 327        checkpoint_path = model_registry.fetch(model_type, progressbar=True)
 328        model_hash = model_registry.registry[model_type]
 329
 330        # If we have a custom model then we may also have a decoder checkpoint.
 331        # Download it here, so that we can add it to the state.
 332        decoder_name = f"{model_type}_decoder"
 333        decoder_path = model_registry.fetch(
 334            decoder_name, progressbar=True
 335        ) if decoder_name in model_registry.registry else None
 336
 337    # checkpoint_path has been passed, we use it instead of downloading a model.
 338    else:
 339        # Check if the file exists and raise an error otherwise.
 340        # We can't check any hashes here, and we don't check if the file is actually a valid weight file.
 341        # (If it isn't the model creation will fail below.)
 342        if not os.path.exists(checkpoint_path):
 343            raise ValueError(f"Checkpoint at {checkpoint_path} could not be found.")
 344        model_hash = _compute_hash(checkpoint_path)
 345        decoder_path = None
 346
 347    # Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped
 348    # before calling sam_model_registry.
 349    abbreviated_model_type = model_type[:5]
 350    if abbreviated_model_type not in _MODEL_TYPES:
 351        raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}")
 352    if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT:
 353        raise RuntimeError(
 354            "mobile_sam is required for the vit-tiny."
 355            "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'"
 356        )
 357
 358    state, model_state = _load_checkpoint(checkpoint_path)
 359
 360    # Whether to update parameters necessary to initialize the model
 361    if model_kwargs:  # Checks whether model_kwargs have been provided or not
 362        if abbreviated_model_type == "vit_t":
 363            raise ValueError("'micro-sam' does not support changing the model parameters for 'mobile-sam'.")
 364        sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs)
 365
 366    else:
 367        sam = sam_model_registry[abbreviated_model_type]()
 368
 369    # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
 370    # Overwrites the SAM model by freezing the backbone and allowing PEFT.
 371    if peft_kwargs and isinstance(peft_kwargs, dict):
 372        if abbreviated_model_type == "vit_t":
 373            raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")
 374
 375        sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam
 376
 377    # In case the model checkpoints have some issues when it is initialized with different parameters than default.
 378    if flexible_load_checkpoint:
 379        sam = _handle_checkpoint_loading(sam, model_state)
 380    else:
 381        sam.load_state_dict(model_state)
 382
 383    sam.to(device=device)
 384
 385    predictor = SamPredictor(sam)
 386    predictor.model_type = abbreviated_model_type
 387    predictor._hash = model_hash
 388    predictor.model_name = model_type
 389
 390    # Add the decoder to the state if we have one and if the state is returned.
 391    if decoder_path is not None and return_state:
 392        state["decoder_state"] = torch.load(decoder_path, map_location=device, weights_only=False)
 393
 394    if return_sam and return_state:
 395        return predictor, sam, state
 396    if return_sam:
 397        return predictor, sam
 398    if return_state:
 399        return predictor, state
 400    return predictor
 401
 402
 403def _handle_checkpoint_loading(sam, model_state):
 404    # Handle mismatching parameters in a more flexible way.
 405    # E.g. while training for multi-class semantic segmentation in the mask encoder,
 406    # parameters are updated - leading to "size mismatch" errors.
 407
 408    new_state_dict = {}  # for loading matching parameters
 409    mismatched_layers = []  # for tracking mismatching parameters
 410
 411    reference_state = sam.state_dict()
 412
 413    for k, v in model_state.items():
 414        if k in reference_state:  # This is done to get rid of unwanted layers from pretrained SAM.
 415            if reference_state[k].size() == v.size():
 416                new_state_dict[k] = v
 417            else:
 418                mismatched_layers.append(k)
 419
 420    reference_state.update(new_state_dict)
 421
 422    if len(mismatched_layers) > 0:
 423        warnings.warn(f"The layers with size mismatch: {mismatched_layers}")
 424
 425    for mlayer in mismatched_layers:
 426        if 'weight' in mlayer:
 427            torch.nn.init.kaiming_uniform_(reference_state[mlayer])
 428        elif 'bias' in mlayer:
 429            reference_state[mlayer].zero_()
 430
 431    sam.load_state_dict(reference_state)
 432
 433    return sam
 434
 435
 436def export_custom_sam_model(
 437    checkpoint_path: Union[str, os.PathLike],
 438    model_type: str,
 439    save_path: Union[str, os.PathLike],
 440) -> None:
 441    """Export a finetuned segment anything model to the standard model format.
 442
 443    The exported model can be used by the interactive annotation tools in `micro_sam.annotator`.
 444
 445    Args:
 446        checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
 447        model_type: The SegmentAnything model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
 448        save_path: Where to save the exported model.
 449    """
 450    _, state = get_sam_model(
 451        model_type=model_type, checkpoint_path=checkpoint_path, return_state=True, device="cpu",
 452    )
 453    model_state = state["model_state"]
 454    prefix = "sam."
 455    model_state = OrderedDict(
 456        [(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()]
 457    )
 458    torch.save(model_state, save_path)
 459
 460
 461def get_model_names() -> Iterable:
 462    model_registry = models()
 463    model_names = model_registry.registry.keys()
 464    return model_names
 465
 466
 467#
 468# Functionality for precomputing image embeddings.
 469#
 470
 471
 472def _to_image(input_):
 473    # we require the input to be uint8
 474    if input_.dtype != np.dtype("uint8"):
 475        # first normalize the input to [0, 1]
 476        input_ = input_.astype("float32") - input_.min()
 477        input_ = input_ / input_.max()
 478        # then bring to [0, 255] and cast to uint8
 479        input_ = (input_ * 255).astype("uint8")
 480    if input_.ndim == 2:
 481        image = np.concatenate([input_[..., None]] * 3, axis=-1)
 482    elif input_.ndim == 3 and input_.shape[-1] == 3:
 483        image = input_
 484    else:
 485        raise ValueError(f"Invalid input image of shape {input_.shape}. Expect either 2D grayscale or 3D RGB image.")
 486    return image
 487
 488
 489def _compute_tiled_features_2d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update):
 490    tiling = blocking([0, 0], input_.shape[:2], tile_shape)
 491    n_tiles = tiling.numberOfBlocks
 492
 493    features = f.require_group("features")
 494    features.attrs["shape"] = input_.shape[:2]
 495    features.attrs["tile_shape"] = tile_shape
 496    features.attrs["halo"] = halo
 497
 498    pbar_init(n_tiles, "Compute Image Embeddings 2D tiled.")
 499    for tile_id in range(n_tiles):
 500        tile = tiling.getBlockWithHalo(tile_id, list(halo))
 501        outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outerBlock.begin, tile.outerBlock.end))
 502
 503        predictor.reset_image()
 504        tile_input = _to_image(input_[outer_tile])
 505        predictor.set_image(tile_input)
 506        tile_features = predictor.get_image_embedding()
 507        original_size = predictor.original_size
 508        input_size = predictor.input_size
 509
 510        ds = features.create_dataset(
 511            str(tile_id), data=tile_features.cpu().numpy(), compression="gzip", chunks=tile_features.shape
 512        )
 513        ds.attrs["original_size"] = original_size
 514        ds.attrs["input_size"] = input_size
 515        pbar_update(1)
 516
 517    _write_embedding_signature(
 518        f, input_, predictor, tile_shape, halo, input_size=None, original_size=None,
 519    )
 520    return features
 521
 522
 523def _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update):
 524    assert input_.ndim == 3
 525
 526    shape = input_.shape[1:]
 527    tiling = blocking([0, 0], shape, tile_shape)
 528    n_tiles = tiling.numberOfBlocks
 529
 530    features = f.require_group("features")
 531    features.attrs["shape"] = shape
 532    features.attrs["tile_shape"] = tile_shape
 533    features.attrs["halo"] = halo
 534
 535    n_slices = input_.shape[0]
 536    pbar_init(n_tiles * n_slices, "Compute Image Embeddings 3D tiled.")
 537
 538    for tile_id in range(n_tiles):
 539        tile = tiling.getBlockWithHalo(tile_id, list(halo))
 540        outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outerBlock.begin, tile.outerBlock.end))
 541
 542        ds = None
 543        for z in range(n_slices):
 544            predictor.reset_image()
 545            tile_input = _to_image(input_[z][outer_tile])
 546            predictor.set_image(tile_input)
 547            tile_features = predictor.get_image_embedding()
 548
 549            if ds is None:
 550                shape = (input_.shape[0],) + tile_features.shape
 551                chunks = (1,) + tile_features.shape
 552                ds = features.create_dataset(
 553                    str(tile_id), shape=shape, dtype="float32", compression="gzip", chunks=chunks
 554                )
 555
 556            ds[z] = tile_features.cpu().numpy()
 557            pbar_update(1)
 558
 559        original_size = predictor.original_size
 560        input_size = predictor.input_size
 561
 562        ds.attrs["original_size"] = original_size
 563        ds.attrs["input_size"] = input_size
 564
 565    _write_embedding_signature(
 566        f, input_, predictor, tile_shape, halo, input_size=None, original_size=None,
 567    )
 568
 569    return features
 570
 571
 572def _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update):
 573    # Check if the embeddings are already cached.
 574    if save_path is not None and "input_size" in f.attrs:
 575        # In this case we load the embeddings.
 576        features = f["features"][:]
 577        original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
 578        image_embeddings = {
 579            "features": features, "input_size": input_size, "original_size": original_size,
 580        }
 581        # Also set the embeddings.
 582        set_precomputed(predictor, image_embeddings)
 583        return image_embeddings
 584
 585    pbar_init(1, "Compute Image Embeddings 2D.")
 586    # Otherwise we have to compute the embeddings.
 587    predictor.reset_image()
 588    predictor.set_image(_to_image(input_))
 589    features = predictor.get_image_embedding().cpu().numpy()
 590    original_size = predictor.original_size
 591    input_size = predictor.input_size
 592    pbar_update(1)
 593
 594    # Save the embeddings if we have a save_path.
 595    if save_path is not None:
 596        f.create_dataset(
 597            "features", data=features, compression="gzip", chunks=features.shape
 598        )
 599        _write_embedding_signature(
 600            f, input_, predictor, tile_shape=None, halo=None, input_size=input_size, original_size=original_size,
 601        )
 602
 603    image_embeddings = {
 604        "features": features, "input_size": input_size, "original_size": original_size,
 605    }
 606    return image_embeddings
 607
 608
 609def _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update):
 610    # Check if the features are already computed.
 611    if "input_size" in f.attrs:
 612        features = f["features"]
 613        original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
 614        image_embeddings = {
 615            "features": features, "input_size": input_size, "original_size": original_size,
 616        }
 617        return image_embeddings
 618
 619    # Otherwise compute them. Note: saving happens automatically because we
 620    # always write the features to zarr. If no save path is given we use an in-memory zarr.
 621    features = _compute_tiled_features_2d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update)
 622    image_embeddings = {"features": features, "input_size": None, "original_size": None}
 623    return image_embeddings
 624
 625
 626def _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update):
 627    # Check if the embeddings are already fully cached.
 628    if save_path is not None and "input_size" in f.attrs:
 629        # In this case we load the embeddings.
 630        features = f["features"] if lazy_loading else f["features"][:]
 631        original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
 632        image_embeddings = {
 633            "features": features, "input_size": input_size, "original_size": original_size,
 634        }
 635        return image_embeddings
 636
 637    # Otherwise we have to compute the embeddings.
 638
 639    # First check if we have a save path or not and set things up accordingly.
 640    if save_path is None:
 641        features = []
 642        save_features = False
 643        partial_features = False
 644    else:
 645        save_features = True
 646        embed_shape = (1, 256, 64, 64)
 647        shape = (input_.shape[0],) + embed_shape
 648        chunks = (1,) + embed_shape
 649        if "features" in f:
 650            partial_features = True
 651            features = f["features"]
 652            if features.shape != shape or features.chunks != chunks:
 653                raise RuntimeError("Invalid partial features")
 654        else:
 655            partial_features = False
 656            features = f.create_dataset("features", shape=shape, chunks=chunks, dtype="float32")
 657
 658    # Initialize the pbar.
 659    pbar_init(input_.shape[0], "Compute Image Embeddings 3D")
 660
 661    # Compute the embeddings for each slice.
 662    for z, z_slice in enumerate(input_):
 663        # Skip feature computation in case of partial features in non-zero slice.
 664        if partial_features and np.count_nonzero(features[z]) != 0:
 665            continue
 666
 667        predictor.reset_image()
 668        predictor.set_image(_to_image(z_slice))
 669        embedding = predictor.get_image_embedding()
 670        original_size, input_size = predictor.original_size, predictor.input_size
 671
 672        if save_features:
 673            features[z] = embedding.cpu().numpy()
 674        else:
 675            features.append(embedding[None])
 676        pbar_update(1)
 677
 678    if save_features:
 679        _write_embedding_signature(
 680            f, input_, predictor, tile_shape=None, halo=None, input_size=input_size, original_size=original_size,
 681        )
 682    else:
 683        # Concatenate across the z axis.
 684        features = torch.cat(features).cpu().numpy()
 685
 686    image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size}
 687    return image_embeddings
 688
 689
 690def _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update):
 691    # Check if the features are already computed.
 692    if "input_size" in f.attrs:
 693        features = f["features"]
 694        original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
 695        image_embeddings = {
 696            "features": features, "input_size": input_size, "original_size": original_size,
 697        }
 698        return image_embeddings
 699
 700    # Otherwise compute them. Note: saving happens automatically because we
 701    # always write the features to zarr. If no save path is given we use an in-memory zarr.
 702    features = _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update)
 703    image_embeddings = {"features": features, "input_size": None, "original_size": None}
 704    return image_embeddings
 705
 706
 707def _compute_data_signature(input_):
 708    data_signature = hashlib.sha1(np.asarray(input_).tobytes()).hexdigest()
 709    return data_signature
 710
 711
 712# Create all metadata that is stored along with the embeddings.
 713def _get_embedding_signature(input_, predictor, tile_shape, halo, data_signature=None):
 714    if data_signature is None:
 715        data_signature = _compute_data_signature(input_)
 716    signature = {
 717        "data_signature": data_signature,
 718        "tile_shape": tile_shape if tile_shape is None else list(tile_shape),
 719        "halo": halo if halo is None else list(halo),
 720        "model_type": predictor.model_type,
 721        "model_name": predictor.model_name,
 722        "micro_sam_version": __version__,
 723        "model_hash": getattr(predictor, "_hash", None),
 724    }
 725    return signature
 726
 727
 728# Note: the input size and original size are different if embeddings are tiled or not.
 729# That's why we do not include them in the main signature that is being checked
 730# (_get_embedding_signature), but just add it for serialization here.
 731def _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size, original_size):
 732    signature = _get_embedding_signature(input_, predictor, tile_shape, halo)
 733    signature.update({"input_size": input_size, "original_size": original_size})
 734    for key, val in signature.items():
 735        f.attrs[key] = val
 736
 737
 738def _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo):
 739    # We may have an empty zarr file that was already created to save the embeddings in.
 740    # In this case the embeddings will be computed and we don't need to perform any checks.
 741    if "input_size" not in f.attrs:
 742        return
 743    signature = _get_embedding_signature(input_, predictor, tile_shape, halo)
 744    for key, val in signature.items():
 745        # Check whether the key is missing from the attrs or if the value is not matching.
 746        if key not in f.attrs or f.attrs[key] != val:
 747            # These keys were recently added, so we don't want to fail yet if they don't
 748            # match in order to not invalidate previous embedding files.
 749            # Instead we just raise a warning. (For the version we probably also don't want to fail
 750            # in the future since it should not invalidate the embeddings).
 751            if key in ("micro_sam_version", "model_hash", "model_name"):
 752                warnings.warn(
 753                    f"The signature for {key} in embeddings file {save_path} has a mismatch: "
 754                    f"{f.attrs.get(key)} != {val}. This key was recently added, so your embeddings are likely correct. "
 755                    "But please recompute them if model predictions don't look as expected."
 756                )
 757            else:
 758                raise RuntimeError(
 759                    f"Embeddings file {save_path} is invalid due to mismatch in {key}: "
 760                    f"{f.attrs.get(key)} != {val}. Please recompute embeddings in a new file."
 761                )
 762
 763
 764# Helper function for optional external progress bars.
 765def handle_pbar(verbose, pbar_init, pbar_update):
 766    """@private"""
 767
 768    # Noop to provide dummy functions.
 769    def noop(*args):
 770        pass
 771
 772    if verbose and pbar_init is None:  # we are verbose and don't have an external progress bar.
 773        assert pbar_update is None  # avoid inconsistent state of callbacks
 774
 775        # Create our own progress bar and callbacks
 776        pbar = tqdm()
 777
 778        def pbar_init(total, description):
 779            pbar.total = total
 780            pbar.set_description(description)
 781
 782        def pbar_update(update):
 783            pbar.update(update)
 784
 785        def pbar_close():
 786            pbar.close()
 787
 788    elif verbose and pbar_init is not None:  # external pbar -> we don't have to do anything
 789        assert pbar_update is not None
 790        pbar = None
 791        pbar_close = noop
 792
 793    else:  # we are not verbose, do nothing
 794        pbar = None
 795        pbar_init, pbar_update, pbar_close = noop, noop, noop
 796
 797    return pbar, pbar_init, pbar_update, pbar_close
 798
 799
 800def precompute_image_embeddings(
 801    predictor: SamPredictor,
 802    input_: np.ndarray,
 803    save_path: Optional[Union[str, os.PathLike]] = None,
 804    lazy_loading: bool = False,
 805    ndim: Optional[int] = None,
 806    tile_shape: Optional[Tuple[int, int]] = None,
 807    halo: Optional[Tuple[int, int]] = None,
 808    verbose: bool = True,
 809    pbar_init: Optional[callable] = None,
 810    pbar_update: Optional[callable] = None,
 811) -> ImageEmbeddings:
 812    """Compute the image embeddings (output of the encoder) for the input.
 813
 814    If 'save_path' is given the embeddings will be loaded/saved in a zarr container.
 815
 816    Args:
 817        predictor: The SegmentAnything predictor.
 818        input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
 819        save_path: Path to save the embeddings in a zarr container.
 820        lazy_loading: Whether to load all embeddings into memory or return an
 821            object to load them on demand when required. This only has an effect if 'save_path' is given
 822            and if the input is 3 dimensional.
 823        ndim: The dimensionality of the data. If not given will be deduced from the input data.
 824        tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
 825        halo: Overlap of the tiles for tiled prediction.
 826        verbose: Whether to be verbose in the computation.
 827        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
 828            Can be used together with pbar_update to handle the napari progress bar in another thread.
 829            This enables using this function within a threadworker.
 830        pbar_update: Callback to update an external progress bar.
 831
 832    Returns:
 833        The image embeddings.
 834    """
 835    ndim = input_.ndim if ndim is None else ndim
 836
 837    # Handle the embedding save_path.
 838    # We don't have a save path, open in memory zarr file to hold tiled embeddings.
 839    if save_path is None:
 840        f = zarr.group()
 841
 842    # We have a save path and it already exists. Embeddings will be loaded from it,
 843    # check that the saved embeddings in there match the parameters of the function call.
 844    elif os.path.exists(save_path):
 845        f = zarr.open(save_path, "a")
 846        _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo)
 847
 848    # We have a save path and it does not exist yet. Create the zarr file to which the
 849    # embeddings will then be saved.
 850    else:
 851        f = zarr.open(save_path, "a")
 852
 853    _, pbar_init, pbar_update, pbar_close = handle_pbar(verbose, pbar_init, pbar_update)
 854
 855    if ndim == 2 and tile_shape is None:
 856        embeddings = _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update)
 857    elif ndim == 2 and tile_shape is not None:
 858        embeddings = _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update)
 859    elif ndim == 3 and tile_shape is None:
 860        embeddings = _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update)
 861    elif ndim == 3 and tile_shape is not None:
 862        embeddings = _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update)
 863    else:
 864        raise ValueError(f"Invalid dimensionality {input_.ndim}, expect 2 or 3 dim data.")
 865
 866    pbar_close()
 867    return embeddings
 868
 869
 870def set_precomputed(
 871    predictor: SamPredictor,
 872    image_embeddings: ImageEmbeddings,
 873    i: Optional[int] = None,
 874    tile_id: Optional[int] = None,
 875) -> SamPredictor:
 876    """Set the precomputed image embeddings for a predictor.
 877
 878    Args:
 879        predictor: The SegmentAnything predictor.
 880        image_embeddings: The precomputed image embeddings computed by `precompute_image_embeddings`.
 881        i: Index for the image data. Required if `image` has three spatial dimensions
 882            or a time dimension and two spatial dimensions.
 883        tile_id: Index for the tile. This is required if the embeddings are tiled.
 884
 885    Returns:
 886        The predictor with set features.
 887    """
 888    if tile_id is not None:
 889        tile_features = image_embeddings["features"][tile_id]
 890        tile_image_embeddings = {
 891            "features": tile_features,
 892            "input_size": tile_features.attrs["input_size"],
 893            "original_size": tile_features.attrs["original_size"]
 894        }
 895        return set_precomputed(predictor, tile_image_embeddings, i=i)
 896
 897    device = predictor.device
 898    features = image_embeddings["features"]
 899    assert features.ndim in (4, 5), f"{features.ndim}"
 900    if features.ndim == 5 and i is None:
 901        raise ValueError("The data is 3D so an index i is needed.")
 902    elif features.ndim == 4 and i is not None:
 903        raise ValueError("The data is 2D so an index is not needed.")
 904
 905    if i is None:
 906        predictor.features = features.to(device) if torch.is_tensor(features) else \
 907            torch.from_numpy(features[:]).to(device)
 908    else:
 909        predictor.features = features[i].to(device) if torch.is_tensor(features) else \
 910            torch.from_numpy(features[i]).to(device)
 911    predictor.original_size = image_embeddings["original_size"]
 912    predictor.input_size = image_embeddings["input_size"]
 913    predictor.is_image_set = True
 914
 915    return predictor
 916
 917
 918#
 919# Misc functionality
 920#
 921
 922
 923def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
 924    """Compute the intersection over union of two masks.
 925
 926    Args:
 927        mask1: The first mask.
 928        mask2: The second mask.
 929
 930    Returns:
 931        The intersection over union of the two masks.
 932    """
 933    overlap = np.logical_and(mask1 == 1, mask2 == 1).sum()
 934    union = np.logical_or(mask1 == 1, mask2 == 1).sum()
 935    eps = 1e-7
 936    iou = float(overlap) / (float(union) + eps)
 937    return iou
 938
 939
 940def get_centers_and_bounding_boxes(
 941    segmentation: np.ndarray,
 942    mode: str = "v"
 943) -> Tuple[Dict[int, np.ndarray], Dict[int, tuple]]:
 944    """Returns the center coordinates of the foreground instances in the ground-truth.
 945
 946    Args:
 947        segmentation: The segmentation.
 948        mode: Determines the functionality used for computing the centers.
 949            If 'v', the object's eccentricity centers computed by vigra are used.
 950            If 'p', the object's centroids computed by skimage are used.
 951
 952    Returns:
 953        A dictionary that maps object ids to the corresponding centroid.
 954        A dictionary that maps object_ids to the corresponding bounding box.
 955    """
 956    assert mode in ["p", "v"], "Choose either 'p' for regionprops or 'v' for vigra"
 957
 958    properties = regionprops(segmentation)
 959
 960    if mode == "p":
 961        center_coordinates = {prop.label: prop.centroid for prop in properties}
 962    elif mode == "v":
 963        center_coordinates = vigra.filters.eccentricityCenters(segmentation.astype('float32'))
 964        center_coordinates = {i: coord for i, coord in enumerate(center_coordinates) if i > 0}
 965
 966    bbox_coordinates = {prop.label: prop.bbox for prop in properties}
 967
 968    assert len(bbox_coordinates) == len(center_coordinates), f"{len(bbox_coordinates)}, {len(center_coordinates)}"
 969    return center_coordinates, bbox_coordinates
 970
 971
 972def load_image_data(
 973    path: str,
 974    key: Optional[str] = None,
 975    lazy_loading: bool = False
 976) -> np.ndarray:
 977    """Helper function to load image data from file.
 978
 979    Args:
 980        path: The filepath to the image data.
 981        key: The internal filepath for complex data formats like hdf5.
 982        lazy_loading: Whether to lazily load data. Only supported for n5 and zarr data.
 983
 984    Returns:
 985        The image data.
 986    """
 987    if key is None:
 988        image_data = imageio.imread(path)
 989    else:
 990        with open_file(path, mode="r") as f:
 991            image_data = f[key]
 992            if not lazy_loading:
 993                image_data = image_data[:]
 994    return image_data
 995
 996
 997def segmentation_to_one_hot(
 998    segmentation: np.ndarray,
 999    segmentation_ids: Optional[np.ndarray] = None,
1000) -> torch.Tensor:
1001    """Convert the segmentation to one-hot encoded masks.
1002
1003    Args:
1004        segmentation: The segmentation.
1005        segmentation_ids: Optional subset of ids that will be used to subsample the masks.
1006
1007    Returns:
1008        The one-hot encoded masks.
1009    """
1010    masks = segmentation.copy()
1011    if segmentation_ids is None:
1012        n_ids = int(segmentation.max())
1013
1014    else:
1015        assert segmentation_ids[0] != 0, "No objects were found."
1016
1017        # the segmentation ids have to be sorted
1018        segmentation_ids = np.sort(segmentation_ids)
1019
1020        # set the non selected objects to zero and relabel sequentially
1021        masks[~np.isin(masks, segmentation_ids)] = 0
1022        masks = relabel_sequential(masks)[0]
1023        n_ids = len(segmentation_ids)
1024
1025    masks = torch.from_numpy(masks)
1026
1027    one_hot_shape = (n_ids + 1,) + masks.shape
1028    masks = masks.unsqueeze(0)  # add dimension to scatter
1029    masks = torch.zeros(one_hot_shape).scatter_(0, masks, 1)[1:]
1030
1031    # add the extra singleton dimension to get shape NUM_OBJECTS x 1 x H x W
1032    masks = masks.unsqueeze(1)
1033    return masks
1034
1035
1036def get_block_shape(shape: Tuple[int]) -> Tuple[int]:
1037    """Get a suitable block shape for chunking a given shape.
1038
1039    The primary use for this is determining chunk sizes for
1040    zarr arrays or block shapes for parallelization.
1041
1042    Args:
1043        shape: The image or volume shape.
1044
1045    Returns:
1046        The block shape.
1047    """
1048    ndim = len(shape)
1049    if ndim == 2:
1050        block_shape = tuple(min(bs, sh) for bs, sh in zip((1024, 1024), shape))
1051    elif ndim == 3:
1052        block_shape = tuple(min(bs, sh) for bs, sh in zip((32, 256, 256), shape))
1053    else:
1054        raise ValueError(f"Only 2 or 3 dimensional shapes are supported, got {ndim}D.")
1055
1056    return block_shape
def get_cache_directory() -> Path:
61def get_cache_directory() -> Path:
62    """Get micro-sam cache directory location.
63
64    Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.
65    """
66    default_cache_directory = os.path.expanduser(pooch.os_cache("micro_sam"))
67    cache_directory = Path(os.environ.get("MICROSAM_CACHEDIR", default_cache_directory))
68    return cache_directory

Get micro-sam cache directory location.

Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.

def microsam_cachedir() -> Union[str, Path]:
76def microsam_cachedir() -> Union[str, Path]:
77    """Return the micro-sam cache directory.
78
79    Returns the top level cache directory for micro-sam models and sample data.
80
81    Every time this function is called, we check for any user updates made to
82    the MICROSAM_CACHEDIR os environment variable since the last time.
83    """
84    cache_directory = os.environ.get("MICROSAM_CACHEDIR") or pooch.os_cache("micro_sam")
85    return cache_directory

Return the micro-sam cache directory.

Returns the top level cache directory for micro-sam models and sample data.

Every time this function is called, we check for any user updates made to the MICROSAM_CACHEDIR os environment variable since the last time.
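
A short sketch of how the MICROSAM_CACHEDIR environment variable interacts with these helpers (the custom path is a placeholder):

    import os

    from micro_sam.util import get_cache_directory, microsam_cachedir

    print(get_cache_directory())  # default pooch cache location, e.g. ~/.cache/micro_sam on Linux

    # Point micro_sam to a custom cache directory (placeholder path).
    os.environ["MICROSAM_CACHEDIR"] = "/data/micro_sam_cache"
    print(microsam_cachedir())  # now returns the custom location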

def models():
 88def models():
 89    """Return the segmentation models registry.
 90
 91    We recreate the model registry every time this function is called,
 92    so any user changes to the default micro-sam cache directory location
 93    are respected.
 94    """
 95
 96    # We use xxhash to compute the hash of the models, see
 97    # https://github.com/computational-cell-analytics/micro-sam/issues/283
 98    # (It is now a dependency, so we don't provide the sha256 fallback anymore.)
 99    # To generate the xxh128 hash:
100    #     xxh128sum filename
101    encoder_registry = {
102        # The default segment anything models:
103        "vit_l": "xxh128:a82beb3c660661e3dd38d999cc860e9a",
104        "vit_h": "xxh128:97698fac30bd929c2e6d8d8cc15933c2",
105        "vit_b": "xxh128:6923c33df3637b6a922d7682bfc9a86b",
106        # The model with vit tiny backend from https://github.com/ChaoningZhang/MobileSAM.
107        "vit_t": "xxh128:8eadbc88aeb9d8c7e0b4b60c3db48bd0",
108        # The current version of our models in the modelzoo.
109        # LM generalist models:
110        "vit_l_lm": "xxh128:ad3afe783b0d05a788eaf3cc24b308d2",
111        "vit_b_lm": "xxh128:61ce01ea731d89ae41a252480368f886",
112        "vit_t_lm": "xxh128:f90e2ba3dd3d5b935aa870cf2e48f689",
113        # EM models:
114        "vit_l_em_organelles": "xxh128:096c9695966803ca6fde24f4c1e3c3fb",
115        "vit_b_em_organelles": "xxh128:f6f6593aeecd0e15a07bdac86360b6cc",
116        "vit_t_em_organelles": "xxh128:253474720c497cce605e57c9b1d18fd9",
117    }
118    # Additional decoders for instance segmentation.
119    decoder_registry = {
120        # LM generalist models:
121        "vit_l_lm_decoder": "xxh128:40c1ae378cfdce24008b9be24889a5b1",
122        "vit_b_lm_decoder": "xxh128:1bac305195777ba7375634ca15a3c370",
123        "vit_t_lm_decoder": "xxh128:82d3604e64f289bb66ec46a5643da169",
124        # EM models:
125        "vit_l_em_organelles_decoder": "xxh128:d60fd96bd6060856f6430f29e42568fb",
126        "vit_b_em_organelles_decoder": "xxh128:b2d4dcffb99f76d83497d39ee500088f",
127        "vit_t_em_organelles_decoder": "xxh128:8f897c7bb93174a4d1638827c4dd6f44",
128    }
129    registry = {**encoder_registry, **decoder_registry}
130
131    encoder_urls = {
132        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
133        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
134        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
135        "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download",
136        "vit_l_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1/files/vit_l.pt",
137        "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1/files/vit_b.pt",
138        "vit_t_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1/files/vit_t.pt",
139        "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l.pt",  # noqa
140        "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b.pt",
141        "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt",  # noqa
142    }
143
144    decoder_urls = {
145        "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1/files/vit_l_decoder.pt",  # noqa
146        "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1/files/vit_b_decoder.pt",  # noqa
147        "vit_t_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1/files/vit_t_decoder.pt",  # noqa
148        "vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l_decoder.pt",  # noqa
149        "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b_decoder.pt",  # noqa
150        "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt",  # noqa
151    }
152    urls = {**encoder_urls, **decoder_urls}
153
154    models = pooch.create(
155        path=os.path.join(microsam_cachedir(), "models"),
156        base_url="",
157        registry=registry,
158        urls=urls,
159    )
160    return models

Return the segmentation models registry.

We recreate the model registry every time this function is called, so any user changes to the default micro-sam cache directory location are respected.
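
A small sketch of working with the returned registry; `fetch` is the standard pooch download call and reuses cached weights if they are already present:

    from micro_sam.util import models

    model_registry = models()

    # List all registered model and decoder names.
    print(sorted(model_registry.registry.keys()))

    # Download (or reuse from the cache) the weights for one registered model.
    checkpoint_path = model_registry.fetch("vit_b_lm", progressbar=True)
    print(checkpoint_path)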

def get_device( device: Union[str, torch.device, NoneType] = None) -> Union[str, torch.device]:
182def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
183    """Get the torch device.
184
185    If no device is passed the default device for your system is used.
186    Else it will be checked if the device you have passed is supported.
187
188    Args:
189        device: The input device.
190
191    Returns:
192        The device.
193    """
194    if device is None or device == "auto":
195        device = _get_default_device()
196    else:
197        device_type = device if isinstance(device, str) else device.type
198        if device_type.lower() == "cuda":
199            if not torch.cuda.is_available():
200                raise RuntimeError("PyTorch CUDA backend is not available.")
201        elif device_type.lower() == "mps":
202            if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
203                raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
204        elif device_type.lower() == "cpu":
205            pass  # cpu is always available
206        else:
207            raise RuntimeError(f"Unsupported device: {device}\n"
208                               "Please choose from 'cpu', 'cuda', or 'mps'.")
209    return device

Get the torch device.

If no device is passed the default device for your system is used. Else it will be checked if the device you have passed is supported.

Arguments:
  • device: The input device.
Returns:

The device.
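
A minimal sketch of the expected behavior:

    from micro_sam.util import get_device

    device = get_device()       # picks "cuda", "mps", or "cpu" depending on what is available
    device = get_device("cpu")  # the CPU is always accepted
    # get_device("cuda") raises a RuntimeError if the CUDA backend is not available.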

def get_sam_model( model_type: str = 'vit_l', device: Union[str, torch.device, NoneType] = None, checkpoint_path: Union[str, os.PathLike, NoneType] = None, return_sam: bool = False, return_state: bool = False, peft_kwargs: Optional[Dict] = None, flexible_load_checkpoint: bool = False, **model_kwargs) -> mobile_sam.predictor.SamPredictor:
271def get_sam_model(
272    model_type: str = _DEFAULT_MODEL,
273    device: Optional[Union[str, torch.device]] = None,
274    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
275    return_sam: bool = False,
276    return_state: bool = False,
277    peft_kwargs: Optional[Dict] = None,
278    flexible_load_checkpoint: bool = False,
279    **model_kwargs,
280) -> SamPredictor:
281    r"""Get the SegmentAnything Predictor.
282
283    This function will download the required model or load it from the cached weight file.
284    The location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR.
285    The name of the requested model can be set via `model_type`.
286    See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models
287    for an overview of the available models.
288
289    Alternatively this function can also load a model from weights stored in a local filepath.
290    The corresponding file path is given via `checkpoint_path`. In this case `model_type`
291    must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for
292    a SAM model with vit_b encoder.
293
294    By default the models are downloaded to a folder named 'micro_sam/models'
295    inside your default cache directory, e.g.:
296    * Mac: ~/Library/Caches/<AppName>
297    * Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined.
298    * Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache
299    See the pooch.os_cache() documentation for more details:
300    https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html
301
302    Args:
303        model_type: The SegmentAnything model to use. Will use the standard vit_l model by default.
304            To get a list of all available model names you can call `get_model_names`.
305        device: The device for the model. If none is given will use GPU if available.
306        checkpoint_path: The path to a file with weights that should be used instead of using the
307            weights corresponding to `model_type`. If given, `model_type` must match the architecture
308            corresponding to the weight file. E.g. if you use weights for SAM with vit_b encoder
309            then `model_type` must be given as "vit_b".
310        return_sam: Return the sam model object as well as the predictor.
311        return_state: Return the unpickled checkpoint state.
312        peft_kwargs: Keyword arguments for the PEFT wrapper class.
313        flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
314
315    Returns:
316        The segment anything predictor.
317    """
318    device = get_device(device)
319
320    # We support passing a local filepath to a checkpoint.
321    # In this case we do not download any weights but just use the local weight file,
322    # as it is, without copying it anywhere or checking its hash.
323
324    # checkpoint_path has not been passed, we download a known model and derive the correct
325    # URL from the model_type. If the model_type is invalid pooch will raise an error.
326    if checkpoint_path is None:
327        model_registry = models()
328        checkpoint_path = model_registry.fetch(model_type, progressbar=True)
329        model_hash = model_registry.registry[model_type]
330
331        # If we have a custom model then we may also have a decoder checkpoint.
332        # Download it here, so that we can add it to the state.
333        decoder_name = f"{model_type}_decoder"
334        decoder_path = model_registry.fetch(
335            decoder_name, progressbar=True
336        ) if decoder_name in model_registry.registry else None
337
338    # checkpoint_path has been passed, we use it instead of downloading a model.
339    else:
340        # Check if the file exists and raise an error otherwise.
341        # We can't check any hashes here, and we don't check if the file is actually a valid weight file.
342        # (If it isn't the model creation will fail below.)
343        if not os.path.exists(checkpoint_path):
344            raise ValueError(f"Checkpoint at {checkpoint_path} could not be found.")
345        model_hash = _compute_hash(checkpoint_path)
346        decoder_path = None
347
348    # Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped
349    # before calling sam_model_registry.
350    abbreviated_model_type = model_type[:5]
351    if abbreviated_model_type not in _MODEL_TYPES:
352        raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}")
353    if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT:
354        raise RuntimeError(
355            "mobile_sam is required for the vit-tiny. "
356            "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'"
357        )
358
359    state, model_state = _load_checkpoint(checkpoint_path)
360
361    # Update the parameters used to initialize the model, if custom model_kwargs are given.
362    if model_kwargs:  # Checks whether model_kwargs have been provided or not
363        if abbreviated_model_type == "vit_t":
364            raise ValueError("'micro-sam' does not support changing the model parameters for 'mobile-sam'.")
365        sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs)
366
367    else:
368        sam = sam_model_registry[abbreviated_model_type]()
369
370    # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
371    # Overwrites the SAM model by freezing the backbone and allowing PEFT.
372    if peft_kwargs and isinstance(peft_kwargs, dict):
373        if abbreviated_model_type == "vit_t":
374            raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")
375
376        sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam
377
378    # In case the model checkpoint has issues when it is initialized with non-default parameters.
379    if flexible_load_checkpoint:
380        sam = _handle_checkpoint_loading(sam, model_state)
381    else:
382        sam.load_state_dict(model_state)
383
384    sam.to(device=device)
385
386    predictor = SamPredictor(sam)
387    predictor.model_type = abbreviated_model_type
388    predictor._hash = model_hash
389    predictor.model_name = model_type
390
391    # Add the decoder to the state if we have one and if the state is returned.
392    if decoder_path is not None and return_state:
393        state["decoder_state"] = torch.load(decoder_path, map_location=device, weights_only=False)
394
395    if return_sam and return_state:
396        return predictor, sam, state
397    if return_sam:
398        return predictor, sam
399    if return_state:
400        return predictor, state
401    return predictor

Get the SegmentAnything Predictor.

This function will download the required model or load it from the cached weight file. This location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR. The name of the requested model can be set via model_type. See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models for an overview of the available models.

Alternatively this function can also load a model from weights stored in a local filepath. The corresponding file path is given via checkpoint_path. In this case model_type must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for a SAM model with vit_b encoder.

By default the models are downloaded to a folder named 'micro_sam/models' inside your default cache directory, e.g.:

  • Mac: ~/Library/Caches/<AppName>
  • Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined.
  • Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache

See the pooch.os_cache() documentation for more details: https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html

Arguments:
  • model_type: The SegmentAnything model to use. Will use the standard vit_l model by default. To get a list of all available model names you can call get_model_names.
  • device: The device for the model. If none is given will use GPU if available.
  • checkpoint_path: The path to a file with weights that should be used instead of using the weights corresponding to model_type. If given, model_type must match the architecture corresponding to the weight file. E.g. if you use weights for SAM with vit_b encoder then model_type must be given as "vit_b".
  • return_sam: Return the sam model object as well as the predictor.
  • return_state: Return the unpickled checkpoint state.
  • peft_kwargs: Keyword arguments for the PEFT wrapper class.
  • flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
Returns:

The segment anything predictor.
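
A short usage sketch (the checkpoint path below is hypothetical):

    from micro_sam.util import get_sam_model

    # Download the model (or load it from the cache) and wrap it in a predictor.
    predictor = get_sam_model(model_type="vit_b")

    # Load weights from a local file instead; model_type must match the encoder architecture.
    predictor = get_sam_model(model_type="vit_b", checkpoint_path="finetuned_vit_b.pth")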

def export_custom_sam_model( checkpoint_path: Union[str, os.PathLike], model_type: str, save_path: Union[str, os.PathLike]) -> None:
437def export_custom_sam_model(
438    checkpoint_path: Union[str, os.PathLike],
439    model_type: str,
440    save_path: Union[str, os.PathLike],
441) -> None:
442    """Export a finetuned segment anything model to the standard model format.
443
444    The exported model can be used by the interactive annotation tools in `micro_sam.annotator`.
445
446    Args:
447        checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
448        model_type: The SegmentAnything model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
449        save_path: Where to save the exported model.
450    """
451    _, state = get_sam_model(
452        model_type=model_type, checkpoint_path=checkpoint_path, return_state=True, device="cpu",
453    )
454    model_state = state["model_state"]
455    prefix = "sam."
456    model_state = OrderedDict(
457        [(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()]
458    )
459    torch.save(model_state, save_path)

Export a finetuned segment anything model to the standard model format.

The exported model can be used by the interactive annotation tools in micro_sam.annotator.

Arguments:
  • checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
  • model_type: The SegmentAnything model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
  • save_path: Where to save the exported model.
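
A minimal usage sketch (the paths are hypothetical):

    from micro_sam.util import export_custom_sam_model

    export_custom_sam_model(
        checkpoint_path="checkpoints/vit_b/best.pt",  # checkpoint produced by finetuning
        model_type="vit_b",
        save_path="finetuned_vit_b.pth",
    )
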
def get_model_names() -> Iterable:
462def get_model_names() -> Iterable:
463    model_registry = models()
464    model_names = model_registry.registry.keys()
465    return model_names
def precompute_image_embeddings( predictor: mobile_sam.predictor.SamPredictor, input_: numpy.ndarray, save_path: Union[str, os.PathLike, NoneType] = None, lazy_loading: bool = False, ndim: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = True, pbar_init: Optional[callable] = None, pbar_update: Optional[callable] = None) -> Dict[str, Any]:
801def precompute_image_embeddings(
802    predictor: SamPredictor,
803    input_: np.ndarray,
804    save_path: Optional[Union[str, os.PathLike]] = None,
805    lazy_loading: bool = False,
806    ndim: Optional[int] = None,
807    tile_shape: Optional[Tuple[int, int]] = None,
808    halo: Optional[Tuple[int, int]] = None,
809    verbose: bool = True,
810    pbar_init: Optional[callable] = None,
811    pbar_update: Optional[callable] = None,
812) -> ImageEmbeddings:
813    """Compute the image embeddings (output of the encoder) for the input.
814
815    If 'save_path' is given the embeddings will be loaded/saved in a zarr container.
816
817    Args:
818        predictor: The SegmentAnything predictor.
819        input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
820        save_path: Path to save the embeddings in a zarr container.
821        lazy_loading: Whether to load all embeddings into memory or return an
822            object to load them on demand when required. This only has an effect if 'save_path' is given
823            and if the input is 3 dimensional.
824        ndim: The dimensionality of the data. If not given will be deduced from the input data.
825        tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
826        halo: Overlap of the tiles for tiled prediction.
827        verbose: Whether to be verbose in the computation.
828        pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
829            Can be used together with pbar_update to handle a napari progress bar in another thread.
830            This enables using this function within a threadworker.
831        pbar_update: Callback to update an external progress bar.
832
833    Returns:
834        The image embeddings.
835    """
836    ndim = input_.ndim if ndim is None else ndim
837
838    # Handle the embedding save_path.
839    # We don't have a save path, so we open an in-memory zarr file to hold the embeddings.
840    if save_path is None:
841        f = zarr.group()
842
843    # We have a save path and it already exists. Embeddings will be loaded from it,
844    # check that the saved embeddings in there match the parameters of the function call.
845    elif os.path.exists(save_path):
846        f = zarr.open(save_path, "a")
847        _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo)
848
849    # We have a save path and it does not exist yet. Create the zarr file to which the
850    # embeddings will then be saved.
851    else:
852        f = zarr.open(save_path, "a")
853
854    _, pbar_init, pbar_update, pbar_close = handle_pbar(verbose, pbar_init, pbar_update)
855
856    if ndim == 2 and tile_shape is None:
857        embeddings = _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update)
858    elif ndim == 2 and tile_shape is not None:
859        embeddings = _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update)
860    elif ndim == 3 and tile_shape is None:
861        embeddings = _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update)
862    elif ndim == 3 and tile_shape is not None:
863        embeddings = _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update)
864    else:
865        raise ValueError(f"Invalid dimensionality {input_.ndim}, expect 2 or 3 dim data.")
866
867    pbar_close()
868    return embeddings

Compute the image embeddings (output of the encoder) for the input.

If 'save_path' is given the embeddings will be loaded/saved in a zarr container.

Arguments:
  • predictor: The SegmentAnything predictor.
  • input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
  • save_path: Path to save the embeddings in a zarr container.
  • lazy_loading: Whether to load all embeddings into memory or return an object to load them on demand when required. This only has an effect if 'save_path' is given and if the input is 3 dimensional.
  • ndim: The dimensionality of the data. If not given will be deduced from the input data.
  • tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
  • halo: Overlap of the tiles for tiled prediction.
  • verbose: Whether to be verbose in the computation.
  • pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. Can be used together with pbar_update to handle a napari progress bar in another thread. This enables using this function within a threadworker.
  • pbar_update: Callback to update an external progress bar.
Returns:

The image embeddings.
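
For example (a sketch; the image and zarr paths are hypothetical):

    import imageio.v3 as imageio
    from micro_sam.util import get_sam_model, precompute_image_embeddings

    predictor = get_sam_model(model_type="vit_b")
    image = imageio.imread("image.tif")  # a 2D image

    # The embeddings are written to the zarr container and re-used on subsequent calls.
    embeddings = precompute_image_embeddings(predictor, image, save_path="embeddings.zarr", ndim=2)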

def set_precomputed( predictor: mobile_sam.predictor.SamPredictor, image_embeddings: Dict[str, Any], i: Optional[int] = None, tile_id: Optional[int] = None) -> mobile_sam.predictor.SamPredictor:
871def set_precomputed(
872    predictor: SamPredictor,
873    image_embeddings: ImageEmbeddings,
874    i: Optional[int] = None,
875    tile_id: Optional[int] = None,
876) -> SamPredictor:
877    """Set the precomputed image embeddings for a predictor.
878
879    Args:
880        predictor: The SegmentAnything predictor.
881        image_embeddings: The precomputed image embeddings computed by `precompute_image_embeddings`.
882        i: Index for the image data. Required if `image` has three spatial dimensions
883            or a time dimension and two spatial dimensions.
884        tile_id: Index for the tile. This is required if the embeddings are tiled.
885
886    Returns:
887        The predictor with set features.
888    """
889    if tile_id is not None:
890        tile_features = image_embeddings["features"][tile_id]
891        tile_image_embeddings = {
892            "features": tile_features,
893            "input_size": tile_features.attrs["input_size"],
894            "original_size": tile_features.attrs["original_size"]
895        }
896        return set_precomputed(predictor, tile_image_embeddings, i=i)
897
898    device = predictor.device
899    features = image_embeddings["features"]
900    assert features.ndim in (4, 5), f"{features.ndim}"
901    if features.ndim == 5 and i is None:
902        raise ValueError("The data is 3D so an index i is needed.")
903    elif features.ndim == 4 and i is not None:
904        raise ValueError("The data is 2D so an index is not needed.")
905
906    if i is None:
907        predictor.features = features.to(device) if torch.is_tensor(features) else \
908            torch.from_numpy(features[:]).to(device)
909    else:
910        predictor.features = features[i].to(device) if torch.is_tensor(features) else \
911            torch.from_numpy(features[i]).to(device)
912    predictor.original_size = image_embeddings["original_size"]
913    predictor.input_size = image_embeddings["input_size"]
914    predictor.is_image_set = True
915
916    return predictor

Set the precomputed image embeddings for a predictor.

Arguments:
  • predictor: The SegmentAnything predictor.
  • image_embeddings: The precomputed image embeddings computed by precompute_image_embeddings.
  • i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
  • tile_id: Index for the tile. This is required if the embeddings are tiled.
Returns:

The predictor with set features.
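
Continuing the precompute_image_embeddings sketch above, the cached features can be attached to the predictor before prompting it:

    import numpy as np

    predictor = set_precomputed(predictor, embeddings)
    masks, scores, logits = predictor.predict(
        point_coords=np.array([[128, 128]]),  # one point prompt in (x, y) order
        point_labels=np.array([1]),           # 1 marks a positive prompt
        multimask_output=False,
    )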

def compute_iou(mask1: numpy.ndarray, mask2: numpy.ndarray) -> float:
924def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
925    """Compute the intersection over union of two masks.
926
927    Args:
928        mask1: The first mask.
929        mask2: The second mask.
930
931    Returns:
932        The intersection over union of the two masks.
933    """
934    overlap = np.logical_and(mask1 == 1, mask2 == 1).sum()
935    union = np.logical_or(mask1 == 1, mask2 == 1).sum()
936    eps = 1e-7
937    iou = float(overlap) / (float(union) + eps)
938    return iou

Compute the intersection over union of two masks.

Arguments:
  • mask1: The first mask.
  • mask2: The second mask.
Returns:

The intersection over union of the two masks.
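
For example, for two masks that overlap in 4 of 6 foreground pixels:

    import numpy as np
    from micro_sam.util import compute_iou

    mask1 = np.zeros((4, 4), dtype=int)
    mask1[:2, :2] = 1  # 4 foreground pixels
    mask2 = np.zeros((4, 4), dtype=int)
    mask2[:3, :2] = 1  # 6 foreground pixels, fully containing mask1
    print(compute_iou(mask1, mask2))  # ~0.667 (4 / 6, minus a tiny eps correction)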

def get_centers_and_bounding_boxes( segmentation: numpy.ndarray, mode: str = 'v') -> Tuple[Dict[int, numpy.ndarray], Dict[int, tuple]]:
941def get_centers_and_bounding_boxes(
942    segmentation: np.ndarray,
943    mode: str = "v"
944) -> Tuple[Dict[int, np.ndarray], Dict[int, tuple]]:
945    """Returns the center coordinates of the foreground instances in the ground-truth.
946
947    Args:
948        segmentation: The segmentation.
949        mode: Determines the functionality used for computing the centers.
950            If 'v', the object's eccentricity centers computed by vigra are used.
951            If 'p', the object's centroids computed by skimage are used.
952
953    Returns:
954        A dictionary that maps object ids to the corresponding centroid.
955        A dictionary that maps object_ids to the corresponding bounding box.
956    """
957    assert mode in ["p", "v"], "Choose either 'p' for regionprops or 'v' for vigra"
958
959    properties = regionprops(segmentation)
960
961    if mode == "p":
962        center_coordinates = {prop.label: prop.centroid for prop in properties}
963    elif mode == "v":
964        center_coordinates = vigra.filters.eccentricityCenters(segmentation.astype('float32'))
965        center_coordinates = {i: coord for i, coord in enumerate(center_coordinates) if i > 0}
966
967    bbox_coordinates = {prop.label: prop.bbox for prop in properties}
968
969    assert len(bbox_coordinates) == len(center_coordinates), f"{len(bbox_coordinates)}, {len(center_coordinates)}"
970    return center_coordinates, bbox_coordinates

Returns the center coordinates of the foreground instances in the ground-truth.

Arguments:
  • segmentation: The segmentation.
  • mode: Determines the functionality used for computing the centers. If 'v', the object's eccentricity centers computed by vigra are used. If 'p', the object's centroids computed by skimage are used.
Returns:

A dictionary that maps object ids to the corresponding centroid. A dictionary that maps object_ids to the corresponding bounding box.
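
A small sketch with a toy segmentation containing two objects, using the skimage-based mode 'p':

    import numpy as np
    from micro_sam.util import get_centers_and_bounding_boxes

    seg = np.zeros((10, 10), dtype="uint32")
    seg[2:5, 2:5] = 1
    seg[6:9, 6:9] = 2
    centers, boxes = get_centers_and_bounding_boxes(seg, mode="p")
    # centers[1] == (3.0, 3.0); boxes[1] == (2, 2, 5, 5)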

def load_image_data( path: str, key: Optional[str] = None, lazy_loading: bool = False) -> numpy.ndarray:
973def load_image_data(
974    path: str,
975    key: Optional[str] = None,
976    lazy_loading: bool = False
977) -> np.ndarray:
978    """Helper function to load image data from file.
979
980    Args:
981        path: The filepath to the image data.
982        key: The internal filepath for complex data formats like hdf5.
983        lazy_loading: Whether to lazily load data. Only supported for n5 and zarr data.
984
985    Returns:
986        The image data.
987    """
988    if key is None:
989        image_data = imageio.imread(path)
990    else:
991        with open_file(path, mode="r") as f:
992            image_data = f[key]
993            if not lazy_loading:
994                image_data = image_data[:]
995    return image_data

Helper function to load image data from file.

Arguments:
  • path: The filepath to the image data.
  • key: The internal filepath for complex data formats like hdf5.
  • lazy_loading: Whether to lazily load data. Only supported for n5 and zarr data.
Returns:

The image data.
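
Usage sketches (the file paths and the internal key are hypothetical):

    from micro_sam.util import load_image_data

    image = load_image_data("image.tif")              # plain image formats are read with imageio
    volume = load_image_data("volume.h5", key="raw")  # container formats require the internal key
    lazy = load_image_data("volume.zarr", key="raw", lazy_loading=True)  # returns a dataset handle, not an array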

def segmentation_to_one_hot( segmentation: numpy.ndarray, segmentation_ids: Optional[numpy.ndarray] = None) -> torch.Tensor:
 998def segmentation_to_one_hot(
 999    segmentation: np.ndarray,
1000    segmentation_ids: Optional[np.ndarray] = None,
1001) -> torch.Tensor:
1002    """Convert the segmentation to one-hot encoded masks.
1003
1004    Args:
1005        segmentation: The segmentation.
1006        segmentation_ids: Optional subset of ids that will be used to subsample the masks.
1007
1008    Returns:
1009        The one-hot encoded masks.
1010    """
1011    masks = segmentation.copy()
1012    if segmentation_ids is None:
1013        n_ids = int(segmentation.max())
1014
1015    else:
1016        assert segmentation_ids[0] != 0, "No objects were found."
1017
1018        # the segmentation ids have to be sorted
1019        segmentation_ids = np.sort(segmentation_ids)
1020
1021        # set the non selected objects to zero and relabel sequentially
1022        masks[~np.isin(masks, segmentation_ids)] = 0
1023        masks = relabel_sequential(masks)[0]
1024        n_ids = len(segmentation_ids)
1025
1026    masks = torch.from_numpy(masks)
1027
1028    one_hot_shape = (n_ids + 1,) + masks.shape
1029    masks = masks.unsqueeze(0)  # add dimension to scatter
1030    masks = torch.zeros(one_hot_shape).scatter_(0, masks, 1)[1:]
1031
1032    # add the extra singleton dimension to get shape NUM_OBJECTS x 1 x H x W
1033    masks = masks.unsqueeze(1)
1034    return masks

Convert the segmentation to one-hot encoded masks.

Arguments:
  • segmentation: The segmentation.
  • segmentation_ids: Optional subset of ids that will be used to subsample the masks.
Returns:

The one-hot encoded masks.
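
For example, a segmentation with two objects yields a tensor of shape (2, 1, H, W):

    import numpy as np
    from micro_sam.util import segmentation_to_one_hot

    seg = np.zeros((8, 8), dtype="int64")
    seg[:4, :4] = 1
    seg[4:, 4:] = 2
    masks = segmentation_to_one_hot(seg)
    print(masks.shape)  # torch.Size([2, 1, 8, 8])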

def get_block_shape(shape: Tuple[int]) -> Tuple[int]:
1037def get_block_shape(shape: Tuple[int]) -> Tuple[int]:
1038    """Get a suitable block shape for chunking a given shape.
1039
1040    The primary use for this is determining chunk sizes for
1041    zarr arrays or block shapes for parallelization.
1042
1043    Args:
1044        shape: The image or volume shape.
1045
1046    Returns:
1047        The block shape.
1048    """
1049    ndim = len(shape)
1050    if ndim == 2:
1051        block_shape = tuple(min(bs, sh) for bs, sh in zip((1024, 1024), shape))
1052    elif ndim == 3:
1053        block_shape = tuple(min(bs, sh) for bs, sh in zip((32, 256, 256), shape))
1054    else:
1055        raise ValueError(f"Only 2 or 3 dimensional shapes are supported, got {ndim}D.")
1056
1057    return block_shape

Get a suitable block shape for chunking a given shape.

The primary use for this is determining chunk sizes for zarr arrays or block shapes for parallelization.

Arguments:
  • shape: The image or volume shape.
Returns:

The block shape.
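
For example:

    from micro_sam.util import get_block_shape

    print(get_block_shape((512, 2048)))       # (512, 1024) for 2D data
    print(get_block_shape((16, 4096, 4096)))  # (16, 256, 256) for 3D data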