micro_sam.util
Helper functions for downloading Segment Anything models and predicting image embeddings.
1""" 2Helper functions for downloading Segment Anything models and predicting image embeddings. 3""" 4 5 6import os 7import pickle 8import hashlib 9import warnings 10from pathlib import Path 11from collections import OrderedDict 12from typing import Any, Dict, Iterable, Optional, Tuple, Union 13 14import zarr 15import vigra 16import torch 17import pooch 18import xxhash 19import numpy as np 20import imageio.v3 as imageio 21from skimage.measure import regionprops 22from skimage.segmentation import relabel_sequential 23 24from elf.io import open_file 25 26from nifty.tools import blocking 27 28from .__version__ import __version__ 29from . import models as custom_models 30 31try: 32 # Avoid import warnigns from mobile_sam 33 with warnings.catch_warnings(): 34 warnings.simplefilter("ignore") 35 from mobile_sam import sam_model_registry, SamPredictor 36 VIT_T_SUPPORT = True 37except ImportError: 38 from segment_anything import sam_model_registry, SamPredictor 39 VIT_T_SUPPORT = False 40 41try: 42 from napari.utils import progress as tqdm 43except ImportError: 44 from tqdm import tqdm 45 46# this is the default model used in micro_sam 47# currently set to the default vit_l 48_DEFAULT_MODEL = "vit_l" 49 50# The valid model types. Each type corresponds to the architecture of the 51# vision transformer used within SAM. 52_MODEL_TYPES = ("vit_l", "vit_b", "vit_h", "vit_t") 53 54 55# TODO define the proper type for image embeddings 56ImageEmbeddings = Dict[str, Any] 57"""@private""" 58 59 60def get_cache_directory() -> None: 61 """Get micro-sam cache directory location. 62 63 Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory. 64 """ 65 default_cache_directory = os.path.expanduser(pooch.os_cache("micro_sam")) 66 cache_directory = Path(os.environ.get("MICROSAM_CACHEDIR", default_cache_directory)) 67 return cache_directory 68 69 70# 71# Functionality for model download and export 72# 73 74 75def microsam_cachedir() -> None: 76 """Return the micro-sam cache directory. 77 78 Returns the top level cache directory for micro-sam models and sample data. 79 80 Every time this function is called, we check for any user updates made to 81 the MICROSAM_CACHEDIR os environment variable since the last time. 82 """ 83 cache_directory = os.environ.get("MICROSAM_CACHEDIR") or pooch.os_cache("micro_sam") 84 return cache_directory 85 86 87def models(): 88 """Return the segmentation models registry. 89 90 We recreate the model registry every time this function is called, 91 so any user changes to the default micro-sam cache directory location 92 are respected. 93 """ 94 95 # We use xxhash to compute the hash of the models, see 96 # https://github.com/computational-cell-analytics/micro-sam/issues/283 97 # (It is now a dependency, so we don't provide the sha256 fallback anymore.) 98 # To generate the xxh128 hash: 99 # xxh128sum filename 100 encoder_registry = { 101 # The default segment anything models: 102 "vit_l": "xxh128:a82beb3c660661e3dd38d999cc860e9a", 103 "vit_h": "xxh128:97698fac30bd929c2e6d8d8cc15933c2", 104 "vit_b": "xxh128:6923c33df3637b6a922d7682bfc9a86b", 105 # The model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM. 106 "vit_t": "xxh128:8eadbc88aeb9d8c7e0b4b60c3db48bd0", 107 # The current version of our models in the modelzoo. 
108 # LM generalist models: 109 "vit_l_lm": "xxh128:ad3afe783b0d05a788eaf3cc24b308d2", 110 "vit_b_lm": "xxh128:61ce01ea731d89ae41a252480368f886", 111 "vit_t_lm": "xxh128:f90e2ba3dd3d5b935aa870cf2e48f689", 112 # EM models: 113 "vit_l_em_organelles": "xxh128:096c9695966803ca6fde24f4c1e3c3fb", 114 "vit_b_em_organelles": "xxh128:f6f6593aeecd0e15a07bdac86360b6cc", 115 "vit_t_em_organelles": "xxh128:253474720c497cce605e57c9b1d18fd9", 116 } 117 # Additional decoders for instance segmentation. 118 decoder_registry = { 119 # LM generalist models: 120 "vit_l_lm_decoder": "xxh128:40c1ae378cfdce24008b9be24889a5b1", 121 "vit_b_lm_decoder": "xxh128:1bac305195777ba7375634ca15a3c370", 122 "vit_t_lm_decoder": "xxh128:82d3604e64f289bb66ec46a5643da169", 123 # EM models: 124 "vit_l_em_organelles_decoder": "xxh128:d60fd96bd6060856f6430f29e42568fb", 125 "vit_b_em_organelles_decoder": "xxh128:b2d4dcffb99f76d83497d39ee500088f", 126 "vit_t_em_organelles_decoder": "xxh128:8f897c7bb93174a4d1638827c4dd6f44", 127 } 128 registry = {**encoder_registry, **decoder_registry} 129 130 encoder_urls = { 131 "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", 132 "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", 133 "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", 134 "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download", 135 "vit_l_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1/files/vit_l.pt", 136 "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1/files/vit_b.pt", 137 "vit_t_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1/files/vit_t.pt", 138 "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l.pt", # noqa 139 "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b.pt", 140 "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt", # noqa 141 } 142 143 decoder_urls = { 144 "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1/files/vit_l_decoder.pt", # noqa 145 "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1/files/vit_b_decoder.pt", # noqa 146 "vit_t_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1/files/vit_t_decoder.pt", # noqa 147 "vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l_decoder.pt", # noqa 148 "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b_decoder.pt", # noqa 149 "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt", # noqa 150 } 151 urls = {**encoder_urls, **decoder_urls} 152 153 models = pooch.create( 154 path=os.path.join(microsam_cachedir(), "models"), 155 base_url="", 156 registry=registry, 157 urls=urls, 158 ) 159 return models 160 161 162def _get_default_device(): 163 # check that we're in CI and use the CPU if we are 164 # otherwise the tests may run out of memory on MAC if MPS is used. 165 if os.getenv("GITHUB_ACTIONS") == "true": 166 return "cpu" 167 # Use cuda enabled gpu if it's available. 
168 if torch.cuda.is_available(): 169 device = "cuda" 170 # As second priority use mps. 171 # See https://pytorch.org/docs/stable/notes/mps.html for details 172 elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): 173 print("Using apple MPS device.") 174 device = "mps" 175 # Use the CPU as fallback. 176 else: 177 device = "cpu" 178 return device 179 180 181def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]: 182 """Get the torch device. 183 184 If no device is passed the default device for your system is used. 185 Else it will be checked if the device you have passed is supported. 186 187 Args: 188 device: The input device. 189 190 Returns: 191 The device. 192 """ 193 if device is None or device == "auto": 194 device = _get_default_device() 195 else: 196 device_type = device if isinstance(device, str) else device.type 197 if device_type.lower() == "cuda": 198 if not torch.cuda.is_available(): 199 raise RuntimeError("PyTorch CUDA backend is not available.") 200 elif device_type.lower() == "mps": 201 if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()): 202 raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.") 203 elif device_type.lower() == "cpu": 204 pass # cpu is always available 205 else: 206 raise RuntimeError(f"Unsupported device: {device}\n" 207 "Please choose from 'cpu', 'cuda', or 'mps'.") 208 return device 209 210 211def _available_devices(): 212 available_devices = [] 213 for i in ["cuda", "mps", "cpu"]: 214 try: 215 device = get_device(i) 216 except RuntimeError: 217 pass 218 else: 219 available_devices.append(device) 220 return available_devices 221 222 223# We write a custom unpickler that skips objects that cannot be found instead of 224# throwing an AttributeError or ModueNotFoundError. 225# NOTE: since we just want to unpickle the model to load its weights these errors don't matter. 226# See also https://stackoverflow.com/questions/27732354/unable-to-load-files-using-pickle-and-multiple-modules 227class _CustomUnpickler(pickle.Unpickler): 228 def find_class(self, module, name): 229 try: 230 return super().find_class(module, name) 231 except (AttributeError, ModuleNotFoundError) as e: 232 warnings.warn(f"Did not find {module}:{name} and will skip it, due to error {e}") 233 return None 234 235 236def _compute_hash(path, chunk_size=8192): 237 hash_obj = xxhash.xxh128() 238 with open(path, "rb") as f: 239 chunk = f.read(chunk_size) 240 while chunk: 241 hash_obj.update(chunk) 242 chunk = f.read(chunk_size) 243 hash_val = hash_obj.hexdigest() 244 return f"xxh128:{hash_val}" 245 246 247# Load the state from a checkpoint. 248# The checkpoint can either contain a sam encoder state 249# or it can be a checkpoint for model finetuning. 250def _load_checkpoint(checkpoint_path): 251 # Over-ride the unpickler with our custom one. 252 # This enables imports from torch_em checkpoints even if it cannot be fully unpickled. 253 custom_pickle = pickle 254 custom_pickle.Unpickler = _CustomUnpickler 255 256 state = torch.load(checkpoint_path, map_location="cpu", pickle_module=custom_pickle) 257 if "model_state" in state: 258 # Copy the model weights from torch_em's training format. 259 model_state = state["model_state"] 260 sam_prefix = "sam." 
261 model_state = OrderedDict( 262 [(k[len(sam_prefix):] if k.startswith(sam_prefix) else k, v) for k, v in model_state.items()] 263 ) 264 else: 265 model_state = state 266 267 return state, model_state 268 269 270def get_sam_model( 271 model_type: str = _DEFAULT_MODEL, 272 device: Optional[Union[str, torch.device]] = None, 273 checkpoint_path: Optional[Union[str, os.PathLike]] = None, 274 return_sam: bool = False, 275 return_state: bool = False, 276 peft_kwargs: Optional[Dict] = None, 277 flexible_load_checkpoint: bool = False, 278 **model_kwargs, 279) -> SamPredictor: 280 r"""Get the SegmentAnything Predictor. 281 282 This function will download the required model or load it from the cached weight file. 283 This location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR. 284 The name of the requested model can be set via `model_type`. 285 See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models 286 for an overview of the available models 287 288 Alternatively this function can also load a model from weights stored in a local filepath. 289 The corresponding file path is given via `checkpoint_path`. In this case `model_type` 290 must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for 291 a SAM model with vit_b encoder. 292 293 By default the models are downloaded to a folder named 'micro_sam/models' 294 inside your default cache directory, eg: 295 * Mac: ~/Library/Caches/<AppName> 296 * Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined. 297 * Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache 298 See the pooch.os_cache() documentation for more details: 299 https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html 300 301 Args: 302 model_type: The SegmentAnything model to use. Will use the standard vit_h model by default. 303 To get a list of all available model names you can call `get_model_names`. 304 device: The device for the model. If none is given will use GPU if available. 305 checkpoint_path: The path to a file with weights that should be used instead of using the 306 weights corresponding to `model_type`. If given, `model_type` must match the architecture 307 corresponding to the weight file. E.g. if you use weights for SAM with vit_b encoder 308 then `model_type` must be given as "vit_b". 309 return_sam: Return the sam model object as well as the predictor. 310 return_state: Return the unpickled checkpoint state. 311 peft_kwargs: Keyword arguments for th PEFT wrapper class. 312 flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. 313 314 Returns: 315 The segment anything predictor. 316 """ 317 device = get_device(device) 318 319 # We support passing a local filepath to a checkpoint. 320 # In this case we do not download any weights but just use the local weight file, 321 # as it is, without copying it over anywhere or checking it's hashes. 322 323 # checkpoint_path has not been passed, we download a known model and derive the correct 324 # URL from the model_type. If the model_type is invalid pooch will raise an error. 325 if checkpoint_path is None: 326 model_registry = models() 327 checkpoint_path = model_registry.fetch(model_type, progressbar=True) 328 model_hash = model_registry.registry[model_type] 329 330 # If we have a custom model then we may also have a decoder checkpoint. 331 # Download it here, so that we can add it to the state. 
332 decoder_name = f"{model_type}_decoder" 333 decoder_path = model_registry.fetch( 334 decoder_name, progressbar=True 335 ) if decoder_name in model_registry.registry else None 336 337 # checkpoint_path has been passed, we use it instead of downloading a model. 338 else: 339 # Check if the file exists and raise an error otherwise. 340 # We can't check any hashes here, and we don't check if the file is actually a valid weight file. 341 # (If it isn't the model creation will fail below.) 342 if not os.path.exists(checkpoint_path): 343 raise ValueError(f"Checkpoint at {checkpoint_path} could not be found.") 344 model_hash = _compute_hash(checkpoint_path) 345 decoder_path = None 346 347 # Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped 348 # before calling sam_model_registry. 349 abbreviated_model_type = model_type[:5] 350 if abbreviated_model_type not in _MODEL_TYPES: 351 raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}") 352 if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT: 353 raise RuntimeError( 354 "mobile_sam is required for the vit-tiny." 355 "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'" 356 ) 357 358 state, model_state = _load_checkpoint(checkpoint_path) 359 360 # Whether to update parameters necessary to initialize the model 361 if model_kwargs: # Checks whether model_kwargs have been provided or not 362 if abbreviated_model_type == "vit_t": 363 raise ValueError("'micro-sam' does not support changing the model parameters for 'mobile-sam'.") 364 sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs) 365 366 else: 367 sam = sam_model_registry[abbreviated_model_type]() 368 369 # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything. 370 # Overwrites the SAM model by freezing the backbone and allow PEFT. 371 if peft_kwargs and isinstance(peft_kwargs, dict): 372 if abbreviated_model_type == "vit_t": 373 raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.") 374 375 sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam 376 377 # In case the model checkpoints have some issues when it is initialized with different parameters than default. 378 if flexible_load_checkpoint: 379 sam = _handle_checkpoint_loading(sam, model_state) 380 else: 381 sam.load_state_dict(model_state) 382 383 sam.to(device=device) 384 385 predictor = SamPredictor(sam) 386 predictor.model_type = abbreviated_model_type 387 predictor._hash = model_hash 388 predictor.model_name = model_type 389 390 # Add the decoder to the state if we have one and if the state is returned. 391 if decoder_path is not None and return_state: 392 state["decoder_state"] = torch.load(decoder_path, map_location=device, weights_only=False) 393 394 if return_sam and return_state: 395 return predictor, sam, state 396 if return_sam: 397 return predictor, sam 398 if return_state: 399 return predictor, state 400 return predictor 401 402 403def _handle_checkpoint_loading(sam, model_state): 404 # Whether to handle the mismatch issues in a bit more elegant way. 405 # eg. 
while training for multi-class semantic segmentation in the mask encoder, 406 # parameters are updated - leading to "size mismatch" errors 407 408 new_state_dict = {} # for loading matching parameters 409 mismatched_layers = [] # for tracking mismatching parameters 410 411 reference_state = sam.state_dict() 412 413 for k, v in model_state.items(): 414 if k in reference_state: # This is done to get rid of unwanted layers from pretrained SAM. 415 if reference_state[k].size() == v.size(): 416 new_state_dict[k] = v 417 else: 418 mismatched_layers.append(k) 419 420 reference_state.update(new_state_dict) 421 422 if len(mismatched_layers) > 0: 423 warnings.warn(f"The layers with size mismatch: {mismatched_layers}") 424 425 for mlayer in mismatched_layers: 426 if 'weight' in mlayer: 427 torch.nn.init.kaiming_uniform_(reference_state[mlayer]) 428 elif 'bias' in mlayer: 429 reference_state[mlayer].zero_() 430 431 sam.load_state_dict(reference_state) 432 433 return sam 434 435 436def export_custom_sam_model( 437 checkpoint_path: Union[str, os.PathLike], 438 model_type: str, 439 save_path: Union[str, os.PathLike], 440) -> None: 441 """Export a finetuned segment anything model to the standard model format. 442 443 The exported model can be used by the interactive annotation tools in `micro_sam.annotator`. 444 445 Args: 446 checkpoint_path: The path to the corresponding checkpoint if not in the default model folder. 447 model_type: The SegmentAnything model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t). 448 save_path: Where to save the exported model. 449 """ 450 _, state = get_sam_model( 451 model_type=model_type, checkpoint_path=checkpoint_path, return_state=True, device="cpu", 452 ) 453 model_state = state["model_state"] 454 prefix = "sam." 455 model_state = OrderedDict( 456 [(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()] 457 ) 458 torch.save(model_state, save_path) 459 460 461def get_model_names() -> Iterable: 462 model_registry = models() 463 model_names = model_registry.registry.keys() 464 return model_names 465 466 467# 468# Functionality for precomputing image embeddings. 469# 470 471 472def _to_image(input_): 473 # we require the input to be uint8 474 if input_.dtype != np.dtype("uint8"): 475 # first normalize the input to [0, 1] 476 input_ = input_.astype("float32") - input_.min() 477 input_ = input_ / input_.max() 478 # then bring to [0, 255] and cast to uint8 479 input_ = (input_ * 255).astype("uint8") 480 if input_.ndim == 2: 481 image = np.concatenate([input_[..., None]] * 3, axis=-1) 482 elif input_.ndim == 3 and input_.shape[-1] == 3: 483 image = input_ 484 else: 485 raise ValueError(f"Invalid input image of shape {input_.shape}. 
Expect either 2D grayscale or 3D RGB image.") 486 return image 487 488 489def _compute_tiled_features_2d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update): 490 tiling = blocking([0, 0], input_.shape[:2], tile_shape) 491 n_tiles = tiling.numberOfBlocks 492 493 features = f.require_group("features") 494 features.attrs["shape"] = input_.shape[:2] 495 features.attrs["tile_shape"] = tile_shape 496 features.attrs["halo"] = halo 497 498 pbar_init(n_tiles, "Compute Image Embeddings 2D tiled.") 499 for tile_id in range(n_tiles): 500 tile = tiling.getBlockWithHalo(tile_id, list(halo)) 501 outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outerBlock.begin, tile.outerBlock.end)) 502 503 predictor.reset_image() 504 tile_input = _to_image(input_[outer_tile]) 505 predictor.set_image(tile_input) 506 tile_features = predictor.get_image_embedding() 507 original_size = predictor.original_size 508 input_size = predictor.input_size 509 510 ds = features.create_dataset( 511 str(tile_id), data=tile_features.cpu().numpy(), compression="gzip", chunks=tile_features.shape 512 ) 513 ds.attrs["original_size"] = original_size 514 ds.attrs["input_size"] = input_size 515 pbar_update(1) 516 517 _write_embedding_signature( 518 f, input_, predictor, tile_shape, halo, input_size=None, original_size=None, 519 ) 520 return features 521 522 523def _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update): 524 assert input_.ndim == 3 525 526 shape = input_.shape[1:] 527 tiling = blocking([0, 0], shape, tile_shape) 528 n_tiles = tiling.numberOfBlocks 529 530 features = f.require_group("features") 531 features.attrs["shape"] = shape 532 features.attrs["tile_shape"] = tile_shape 533 features.attrs["halo"] = halo 534 535 n_slices = input_.shape[0] 536 pbar_init(n_tiles * n_slices, "Compute Image Embeddings 3D tiled.") 537 538 for tile_id in range(n_tiles): 539 tile = tiling.getBlockWithHalo(tile_id, list(halo)) 540 outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outerBlock.begin, tile.outerBlock.end)) 541 542 ds = None 543 for z in range(n_slices): 544 predictor.reset_image() 545 tile_input = _to_image(input_[z][outer_tile]) 546 predictor.set_image(tile_input) 547 tile_features = predictor.get_image_embedding() 548 549 if ds is None: 550 shape = (input_.shape[0],) + tile_features.shape 551 chunks = (1,) + tile_features.shape 552 ds = features.create_dataset( 553 str(tile_id), shape=shape, dtype="float32", compression="gzip", chunks=chunks 554 ) 555 556 ds[z] = tile_features.cpu().numpy() 557 pbar_update(1) 558 559 original_size = predictor.original_size 560 input_size = predictor.input_size 561 562 ds.attrs["original_size"] = original_size 563 ds.attrs["input_size"] = input_size 564 565 _write_embedding_signature( 566 f, input_, predictor, tile_shape, halo, input_size=None, original_size=None, 567 ) 568 569 return features 570 571 572def _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update): 573 # Check if the embeddings are already cached. 574 if save_path is not None and "input_size" in f.attrs: 575 # In this case we load the embeddings. 576 features = f["features"][:] 577 original_size, input_size = f.attrs["original_size"], f.attrs["input_size"] 578 image_embeddings = { 579 "features": features, "input_size": input_size, "original_size": original_size, 580 } 581 # Also set the embeddings. 
582 set_precomputed(predictor, image_embeddings) 583 return image_embeddings 584 585 pbar_init(1, "Compute Image Embeddings 2D.") 586 # Otherwise we have to compute the embeddings. 587 predictor.reset_image() 588 predictor.set_image(_to_image(input_)) 589 features = predictor.get_image_embedding().cpu().numpy() 590 original_size = predictor.original_size 591 input_size = predictor.input_size 592 pbar_update(1) 593 594 # Save the embeddings if we have a save_path. 595 if save_path is not None: 596 f.create_dataset( 597 "features", data=features, compression="gzip", chunks=features.shape 598 ) 599 _write_embedding_signature( 600 f, input_, predictor, tile_shape=None, halo=None, input_size=input_size, original_size=original_size, 601 ) 602 603 image_embeddings = { 604 "features": features, "input_size": input_size, "original_size": original_size, 605 } 606 return image_embeddings 607 608 609def _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update): 610 # Check if the features are already computed. 611 if "input_size" in f.attrs: 612 features = f["features"] 613 original_size, input_size = f.attrs["original_size"], f.attrs["input_size"] 614 image_embeddings = { 615 "features": features, "input_size": input_size, "original_size": original_size, 616 } 617 return image_embeddings 618 619 # Otherwise compute them. Note: saving happens automatically because we 620 # always write the features to zarr. If no save path is given we use an in-memory zarr. 621 features = _compute_tiled_features_2d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update) 622 image_embeddings = {"features": features, "input_size": None, "original_size": None} 623 return image_embeddings 624 625 626def _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update): 627 # Check if the embeddings are already fully cached. 628 if save_path is not None and "input_size" in f.attrs: 629 # In this case we load the embeddings. 630 features = f["features"] if lazy_loading else f["features"][:] 631 original_size, input_size = f.attrs["original_size"], f.attrs["input_size"] 632 image_embeddings = { 633 "features": features, "input_size": input_size, "original_size": original_size, 634 } 635 return image_embeddings 636 637 # Otherwise we have to compute the embeddings. 638 639 # First check if we have a save path or not and set things up accordingly. 640 if save_path is None: 641 features = [] 642 save_features = False 643 partial_features = False 644 else: 645 save_features = True 646 embed_shape = (1, 256, 64, 64) 647 shape = (input_.shape[0],) + embed_shape 648 chunks = (1,) + embed_shape 649 if "features" in f: 650 partial_features = True 651 features = f["features"] 652 if features.shape != shape or features.chunks != chunks: 653 raise RuntimeError("Invalid partial features") 654 else: 655 partial_features = False 656 features = f.create_dataset("features", shape=shape, chunks=chunks, dtype="float32") 657 658 # Initialize the pbar. 659 pbar_init(input_.shape[0], "Compute Image Embeddings 3D") 660 661 # Compute the embeddings for each slice. 662 for z, z_slice in enumerate(input_): 663 # Skip feature computation in case of partial features in non-zero slice. 
664 if partial_features and np.count_nonzero(features[z]) != 0: 665 continue 666 667 predictor.reset_image() 668 predictor.set_image(_to_image(z_slice)) 669 embedding = predictor.get_image_embedding() 670 original_size, input_size = predictor.original_size, predictor.input_size 671 672 if save_features: 673 features[z] = embedding.cpu().numpy() 674 else: 675 features.append(embedding[None]) 676 pbar_update(1) 677 678 if save_features: 679 _write_embedding_signature( 680 f, input_, predictor, tile_shape=None, halo=None, input_size=input_size, original_size=original_size, 681 ) 682 else: 683 # Concatenate across the z axis. 684 features = torch.cat(features).cpu().numpy() 685 686 image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size} 687 return image_embeddings 688 689 690def _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update): 691 # Check if the features are already computed. 692 if "input_size" in f.attrs: 693 features = f["features"] 694 original_size, input_size = f.attrs["original_size"], f.attrs["input_size"] 695 image_embeddings = { 696 "features": features, "input_size": input_size, "original_size": original_size, 697 } 698 return image_embeddings 699 700 # Otherwise compute them. Note: saving happens automatically because we 701 # always write the features to zarr. If no save path is given we use an in-memory zarr. 702 features = _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update) 703 image_embeddings = {"features": features, "input_size": None, "original_size": None} 704 return image_embeddings 705 706 707def _compute_data_signature(input_): 708 data_signature = hashlib.sha1(np.asarray(input_).tobytes()).hexdigest() 709 return data_signature 710 711 712# Create all metadata that is stored along with the embeddings. 713def _get_embedding_signature(input_, predictor, tile_shape, halo, data_signature=None): 714 if data_signature is None: 715 data_signature = _compute_data_signature(input_) 716 signature = { 717 "data_signature": data_signature, 718 "tile_shape": tile_shape if tile_shape is None else list(tile_shape), 719 "halo": halo if halo is None else list(halo), 720 "model_type": predictor.model_type, 721 "model_name": predictor.model_name, 722 "micro_sam_version": __version__, 723 "model_hash": getattr(predictor, "_hash", None), 724 } 725 return signature 726 727 728# Note: the input size and orginal size are different if embeddings are tiled or not. 729# That's why we do not include them in the main signature that is being checked 730# (_get_embedding_signature), but just add it for serialization here. 731def _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size, original_size): 732 signature = _get_embedding_signature(input_, predictor, tile_shape, halo) 733 signature.update({"input_size": input_size, "original_size": original_size}) 734 for key, val in signature.items(): 735 f.attrs[key] = val 736 737 738def _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo): 739 # We may have an empty zarr file that was already created to save the embeddings in. 740 # In this case the embeddings will be computed and we don't need to perform any checks. 741 if "input_size" not in f.attrs: 742 return 743 signature = _get_embedding_signature(input_, predictor, tile_shape, halo) 744 for key, val in signature.items(): 745 # Check whether the key is missing from the attrs or if the value is not matching. 
746 if key not in f.attrs or f.attrs[key] != val: 747 # These keys were recently added, so we don't want to fail yet if they don't 748 # match in order to not invalidate previous embedding files. 749 # Instead we just raise a warning. (For the version we probably also don't want to fail 750 # i the future since it should not invalidate the embeddings). 751 if key in ("micro_sam_version", "model_hash", "model_name"): 752 warnings.warn( 753 f"The signature for {key} in embeddings file {save_path} has a mismatch: " 754 f"{f.attrs.get(key)} != {val}. This key was recently added, so your embeddings are likely correct. " 755 "But please recompute them if model predictions don't look as expected." 756 ) 757 else: 758 raise RuntimeError( 759 f"Embeddings file {save_path} is invalid due to mismatch in {key}: " 760 f"{f.attrs.get(key)} != {val}. Please recompute embeddings in a new file." 761 ) 762 763 764# Helper function for optional external progress bars. 765def handle_pbar(verbose, pbar_init, pbar_update): 766 """@private""" 767 768 # Noop to provide dummy functions. 769 def noop(*args): 770 pass 771 772 if verbose and pbar_init is None: # we are verbose and don't have an external progress bar. 773 assert pbar_update is None # avoid inconsistent state of callbacks 774 775 # Create our own progress bar and callbacks 776 pbar = tqdm() 777 778 def pbar_init(total, description): 779 pbar.total = total 780 pbar.set_description(description) 781 782 def pbar_update(update): 783 pbar.update(update) 784 785 def pbar_close(): 786 pbar.close() 787 788 elif verbose and pbar_init is not None: # external pbar -> we don't have to do anything 789 assert pbar_update is not None 790 pbar = None 791 pbar_close = noop 792 793 else: # we are not verbose, do nothing 794 pbar = None 795 pbar_init, pbar_update, pbar_close = noop, noop, noop 796 797 return pbar, pbar_init, pbar_update, pbar_close 798 799 800def precompute_image_embeddings( 801 predictor: SamPredictor, 802 input_: np.ndarray, 803 save_path: Optional[Union[str, os.PathLike]] = None, 804 lazy_loading: bool = False, 805 ndim: Optional[int] = None, 806 tile_shape: Optional[Tuple[int, int]] = None, 807 halo: Optional[Tuple[int, int]] = None, 808 verbose: bool = True, 809 pbar_init: Optional[callable] = None, 810 pbar_update: Optional[callable] = None, 811) -> ImageEmbeddings: 812 """Compute the image embeddings (output of the encoder) for the input. 813 814 If 'save_path' is given the embeddings will be loaded/saved in a zarr container. 815 816 Args: 817 predictor: The SegmentAnything predictor. 818 input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries. 819 save_path: Path to save the embeddings in a zarr container. 820 lazy_loading: Whether to load all embeddings into memory or return an 821 object to load them on demand when required. This only has an effect if 'save_path' is given 822 and if the input is 3 dimensional. 823 ndim: The dimensionality of the data. If not given will be deduced from the input data. 824 tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling. 825 halo: Overlap of the tiles for tiled prediction. 826 verbose: Whether to be verbose in the computation. 827 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 828 Can be used together with pbar_update to handle napari progress bar in other thread. 829 To enables using this function within a threadworker. 
830 pbar_update: Callback to update an external progress bar. 831 832 Returns: 833 The image embeddings. 834 """ 835 ndim = input_.ndim if ndim is None else ndim 836 837 # Handle the embedding save_path. 838 # We don't have a save path, open in memory zarr file to hold tiled embeddings. 839 if save_path is None: 840 f = zarr.group() 841 842 # We have a save path and it already exists. Embeddings will be loaded from it, 843 # check that the saved embeddings in there match the parameters of the function call. 844 elif os.path.exists(save_path): 845 f = zarr.open(save_path, "a") 846 _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo) 847 848 # We have a save path and it does not exist yet. Create the zarr file to which the 849 # embeddings will then be saved. 850 else: 851 f = zarr.open(save_path, "a") 852 853 _, pbar_init, pbar_update, pbar_close = handle_pbar(verbose, pbar_init, pbar_update) 854 855 if ndim == 2 and tile_shape is None: 856 embeddings = _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update) 857 elif ndim == 2 and tile_shape is not None: 858 embeddings = _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update) 859 elif ndim == 3 and tile_shape is None: 860 embeddings = _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update) 861 elif ndim == 3 and tile_shape is not None: 862 embeddings = _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update) 863 else: 864 raise ValueError(f"Invalid dimesionality {input_.ndim}, expect 2 or 3 dim data.") 865 866 pbar_close() 867 return embeddings 868 869 870def set_precomputed( 871 predictor: SamPredictor, 872 image_embeddings: ImageEmbeddings, 873 i: Optional[int] = None, 874 tile_id: Optional[int] = None, 875) -> SamPredictor: 876 """Set the precomputed image embeddings for a predictor. 877 878 Args: 879 predictor: The SegmentAnything predictor. 880 image_embeddings: The precomputed image embeddings computed by `precompute_image_embeddings`. 881 i: Index for the image data. Required if `image` has three spatial dimensions 882 or a time dimension and two spatial dimensions. 883 tile_id: Index for the tile. This is required if the embeddings are tiled. 884 885 Returns: 886 The predictor with set features. 
887 """ 888 if tile_id is not None: 889 tile_features = image_embeddings["features"][tile_id] 890 tile_image_embeddings = { 891 "features": tile_features, 892 "input_size": tile_features.attrs["input_size"], 893 "original_size": tile_features.attrs["original_size"] 894 } 895 return set_precomputed(predictor, tile_image_embeddings, i=i) 896 897 device = predictor.device 898 features = image_embeddings["features"] 899 assert features.ndim in (4, 5), f"{features.ndim}" 900 if features.ndim == 5 and i is None: 901 raise ValueError("The data is 3D so an index i is needed.") 902 elif features.ndim == 4 and i is not None: 903 raise ValueError("The data is 2D so an index is not needed.") 904 905 if i is None: 906 predictor.features = features.to(device) if torch.is_tensor(features) else \ 907 torch.from_numpy(features[:]).to(device) 908 else: 909 predictor.features = features[i].to(device) if torch.is_tensor(features) else \ 910 torch.from_numpy(features[i]).to(device) 911 predictor.original_size = image_embeddings["original_size"] 912 predictor.input_size = image_embeddings["input_size"] 913 predictor.is_image_set = True 914 915 return predictor 916 917 918# 919# Misc functionality 920# 921 922 923def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float: 924 """Compute the intersection over union of two masks. 925 926 Args: 927 mask1: The first mask. 928 mask2: The second mask. 929 930 Returns: 931 The intersection over union of the two masks. 932 """ 933 overlap = np.logical_and(mask1 == 1, mask2 == 1).sum() 934 union = np.logical_or(mask1 == 1, mask2 == 1).sum() 935 eps = 1e-7 936 iou = float(overlap) / (float(union) + eps) 937 return iou 938 939 940def get_centers_and_bounding_boxes( 941 segmentation: np.ndarray, 942 mode: str = "v" 943) -> Tuple[Dict[int, np.ndarray], Dict[int, tuple]]: 944 """Returns the center coordinates of the foreground instances in the ground-truth. 945 946 Args: 947 segmentation: The segmentation. 948 mode: Determines the functionality used for computing the centers. 949 If 'v', the object's eccentricity centers computed by vigra are used. 950 If 'p' the object's centroids computed by skimage are used. 951 952 Returns: 953 A dictionary that maps object ids to the corresponding centroid. 954 A dictionary that maps object_ids to the corresponding bounding box. 955 """ 956 assert mode in ["p", "v"], "Choose either 'p' for regionprops or 'v' for vigra" 957 958 properties = regionprops(segmentation) 959 960 if mode == "p": 961 center_coordinates = {prop.label: prop.centroid for prop in properties} 962 elif mode == "v": 963 center_coordinates = vigra.filters.eccentricityCenters(segmentation.astype('float32')) 964 center_coordinates = {i: coord for i, coord in enumerate(center_coordinates) if i > 0} 965 966 bbox_coordinates = {prop.label: prop.bbox for prop in properties} 967 968 assert len(bbox_coordinates) == len(center_coordinates), f"{len(bbox_coordinates)}, {len(center_coordinates)}" 969 return center_coordinates, bbox_coordinates 970 971 972def load_image_data( 973 path: str, 974 key: Optional[str] = None, 975 lazy_loading: bool = False 976) -> np.ndarray: 977 """Helper function to load image data from file. 978 979 Args: 980 path: The filepath to the image data. 981 key: The internal filepath for complex data formats like hdf5. 982 lazy_loading: Whether to lazyly load data. Only supported for n5 and zarr data. 983 984 Returns: 985 The image data. 
986 """ 987 if key is None: 988 image_data = imageio.imread(path) 989 else: 990 with open_file(path, mode="r") as f: 991 image_data = f[key] 992 if not lazy_loading: 993 image_data = image_data[:] 994 return image_data 995 996 997def segmentation_to_one_hot( 998 segmentation: np.ndarray, 999 segmentation_ids: Optional[np.ndarray] = None, 1000) -> torch.Tensor: 1001 """Convert the segmentation to one-hot encoded masks. 1002 1003 Args: 1004 segmentation: The segmentation. 1005 segmentation_ids: Optional subset of ids that will be used to subsample the masks. 1006 1007 Returns: 1008 The one-hot encoded masks. 1009 """ 1010 masks = segmentation.copy() 1011 if segmentation_ids is None: 1012 n_ids = int(segmentation.max()) 1013 1014 else: 1015 assert segmentation_ids[0] != 0, "No objects were found." 1016 1017 # the segmentation ids have to be sorted 1018 segmentation_ids = np.sort(segmentation_ids) 1019 1020 # set the non selected objects to zero and relabel sequentially 1021 masks[~np.isin(masks, segmentation_ids)] = 0 1022 masks = relabel_sequential(masks)[0] 1023 n_ids = len(segmentation_ids) 1024 1025 masks = torch.from_numpy(masks) 1026 1027 one_hot_shape = (n_ids + 1,) + masks.shape 1028 masks = masks.unsqueeze(0) # add dimension to scatter 1029 masks = torch.zeros(one_hot_shape).scatter_(0, masks, 1)[1:] 1030 1031 # add the extra singleton dimenion to get shape NUM_OBJECTS x 1 x H x W 1032 masks = masks.unsqueeze(1) 1033 return masks 1034 1035 1036def get_block_shape(shape: Tuple[int]) -> Tuple[int]: 1037 """Get a suitable block shape for chunking a given shape. 1038 1039 The primary use for this is determining chunk sizes for 1040 zarr arrays or block shapes for parallelization. 1041 1042 Args: 1043 shape: The image or volume shape. 1044 1045 Returns: 1046 The block shape. 1047 """ 1048 ndim = len(shape) 1049 if ndim == 2: 1050 block_shape = tuple(min(bs, sh) for bs, sh in zip((1024, 1024), shape)) 1051 elif ndim == 3: 1052 block_shape = tuple(min(bs, sh) for bs, sh in zip((32, 256, 256), shape)) 1053 else: 1054 raise ValueError(f"Only 2 or 3 dimensional shapes are supported, got {ndim}D.") 1055 1056 return block_shape
61def get_cache_directory() -> None: 62 """Get micro-sam cache directory location. 63 64 Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory. 65 """ 66 default_cache_directory = os.path.expanduser(pooch.os_cache("micro_sam")) 67 cache_directory = Path(os.environ.get("MICROSAM_CACHEDIR", default_cache_directory)) 68 return cache_directory
Get micro-sam cache directory location.
Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.
76def microsam_cachedir() -> None: 77 """Return the micro-sam cache directory. 78 79 Returns the top level cache directory for micro-sam models and sample data. 80 81 Every time this function is called, we check for any user updates made to 82 the MICROSAM_CACHEDIR os environment variable since the last time. 83 """ 84 cache_directory = os.environ.get("MICROSAM_CACHEDIR") or pooch.os_cache("micro_sam") 85 return cache_directory
Return the micro-sam cache directory.
Returns the top level cache directory for micro-sam models and sample data.
Every time this function is called, we check for any user updates made to the MICROSAM_CACHEDIR os environment variable since the last time.
88def models(): 89 """Return the segmentation models registry. 90 91 We recreate the model registry every time this function is called, 92 so any user changes to the default micro-sam cache directory location 93 are respected. 94 """ 95 96 # We use xxhash to compute the hash of the models, see 97 # https://github.com/computational-cell-analytics/micro-sam/issues/283 98 # (It is now a dependency, so we don't provide the sha256 fallback anymore.) 99 # To generate the xxh128 hash: 100 # xxh128sum filename 101 encoder_registry = { 102 # The default segment anything models: 103 "vit_l": "xxh128:a82beb3c660661e3dd38d999cc860e9a", 104 "vit_h": "xxh128:97698fac30bd929c2e6d8d8cc15933c2", 105 "vit_b": "xxh128:6923c33df3637b6a922d7682bfc9a86b", 106 # The model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM. 107 "vit_t": "xxh128:8eadbc88aeb9d8c7e0b4b60c3db48bd0", 108 # The current version of our models in the modelzoo. 109 # LM generalist models: 110 "vit_l_lm": "xxh128:ad3afe783b0d05a788eaf3cc24b308d2", 111 "vit_b_lm": "xxh128:61ce01ea731d89ae41a252480368f886", 112 "vit_t_lm": "xxh128:f90e2ba3dd3d5b935aa870cf2e48f689", 113 # EM models: 114 "vit_l_em_organelles": "xxh128:096c9695966803ca6fde24f4c1e3c3fb", 115 "vit_b_em_organelles": "xxh128:f6f6593aeecd0e15a07bdac86360b6cc", 116 "vit_t_em_organelles": "xxh128:253474720c497cce605e57c9b1d18fd9", 117 } 118 # Additional decoders for instance segmentation. 119 decoder_registry = { 120 # LM generalist models: 121 "vit_l_lm_decoder": "xxh128:40c1ae378cfdce24008b9be24889a5b1", 122 "vit_b_lm_decoder": "xxh128:1bac305195777ba7375634ca15a3c370", 123 "vit_t_lm_decoder": "xxh128:82d3604e64f289bb66ec46a5643da169", 124 # EM models: 125 "vit_l_em_organelles_decoder": "xxh128:d60fd96bd6060856f6430f29e42568fb", 126 "vit_b_em_organelles_decoder": "xxh128:b2d4dcffb99f76d83497d39ee500088f", 127 "vit_t_em_organelles_decoder": "xxh128:8f897c7bb93174a4d1638827c4dd6f44", 128 } 129 registry = {**encoder_registry, **decoder_registry} 130 131 encoder_urls = { 132 "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", 133 "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", 134 "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", 135 "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download", 136 "vit_l_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1/files/vit_l.pt", 137 "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1/files/vit_b.pt", 138 "vit_t_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1/files/vit_t.pt", 139 "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l.pt", # noqa 140 "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b.pt", 141 "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt", # noqa 142 } 143 144 decoder_urls = { 145 "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1/files/vit_l_decoder.pt", # noqa 146 "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1/files/vit_b_decoder.pt", # noqa 147 "vit_t_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1/files/vit_t_decoder.pt", # noqa 148 "vit_l_em_organelles_decoder": 
"https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l_decoder.pt", # noqa 149 "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b_decoder.pt", # noqa 150 "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt", # noqa 151 } 152 urls = {**encoder_urls, **decoder_urls} 153 154 models = pooch.create( 155 path=os.path.join(microsam_cachedir(), "models"), 156 base_url="", 157 registry=registry, 158 urls=urls, 159 ) 160 return models
Return the segmentation models registry.
We recreate the model registry every time this function is called, so any user changes to the default micro-sam cache directory location are respected.
182def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]: 183 """Get the torch device. 184 185 If no device is passed the default device for your system is used. 186 Else it will be checked if the device you have passed is supported. 187 188 Args: 189 device: The input device. 190 191 Returns: 192 The device. 193 """ 194 if device is None or device == "auto": 195 device = _get_default_device() 196 else: 197 device_type = device if isinstance(device, str) else device.type 198 if device_type.lower() == "cuda": 199 if not torch.cuda.is_available(): 200 raise RuntimeError("PyTorch CUDA backend is not available.") 201 elif device_type.lower() == "mps": 202 if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()): 203 raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.") 204 elif device_type.lower() == "cpu": 205 pass # cpu is always available 206 else: 207 raise RuntimeError(f"Unsupported device: {device}\n" 208 "Please choose from 'cpu', 'cuda', or 'mps'.") 209 return device
Get the torch device.
If no device is passed the default device for your system is used. Else it will be checked if the device you have passed is supported.
Arguments:
- device: The input device.
Returns:
The device.
271def get_sam_model( 272 model_type: str = _DEFAULT_MODEL, 273 device: Optional[Union[str, torch.device]] = None, 274 checkpoint_path: Optional[Union[str, os.PathLike]] = None, 275 return_sam: bool = False, 276 return_state: bool = False, 277 peft_kwargs: Optional[Dict] = None, 278 flexible_load_checkpoint: bool = False, 279 **model_kwargs, 280) -> SamPredictor: 281 r"""Get the SegmentAnything Predictor. 282 283 This function will download the required model or load it from the cached weight file. 284 This location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR. 285 The name of the requested model can be set via `model_type`. 286 See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models 287 for an overview of the available models 288 289 Alternatively this function can also load a model from weights stored in a local filepath. 290 The corresponding file path is given via `checkpoint_path`. In this case `model_type` 291 must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for 292 a SAM model with vit_b encoder. 293 294 By default the models are downloaded to a folder named 'micro_sam/models' 295 inside your default cache directory, eg: 296 * Mac: ~/Library/Caches/<AppName> 297 * Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined. 298 * Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache 299 See the pooch.os_cache() documentation for more details: 300 https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html 301 302 Args: 303 model_type: The SegmentAnything model to use. Will use the standard vit_h model by default. 304 To get a list of all available model names you can call `get_model_names`. 305 device: The device for the model. If none is given will use GPU if available. 306 checkpoint_path: The path to a file with weights that should be used instead of using the 307 weights corresponding to `model_type`. If given, `model_type` must match the architecture 308 corresponding to the weight file. E.g. if you use weights for SAM with vit_b encoder 309 then `model_type` must be given as "vit_b". 310 return_sam: Return the sam model object as well as the predictor. 311 return_state: Return the unpickled checkpoint state. 312 peft_kwargs: Keyword arguments for th PEFT wrapper class. 313 flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. 314 315 Returns: 316 The segment anything predictor. 317 """ 318 device = get_device(device) 319 320 # We support passing a local filepath to a checkpoint. 321 # In this case we do not download any weights but just use the local weight file, 322 # as it is, without copying it over anywhere or checking it's hashes. 323 324 # checkpoint_path has not been passed, we download a known model and derive the correct 325 # URL from the model_type. If the model_type is invalid pooch will raise an error. 326 if checkpoint_path is None: 327 model_registry = models() 328 checkpoint_path = model_registry.fetch(model_type, progressbar=True) 329 model_hash = model_registry.registry[model_type] 330 331 # If we have a custom model then we may also have a decoder checkpoint. 332 # Download it here, so that we can add it to the state. 
333 decoder_name = f"{model_type}_decoder" 334 decoder_path = model_registry.fetch( 335 decoder_name, progressbar=True 336 ) if decoder_name in model_registry.registry else None 337 338 # checkpoint_path has been passed, we use it instead of downloading a model. 339 else: 340 # Check if the file exists and raise an error otherwise. 341 # We can't check any hashes here, and we don't check if the file is actually a valid weight file. 342 # (If it isn't the model creation will fail below.) 343 if not os.path.exists(checkpoint_path): 344 raise ValueError(f"Checkpoint at {checkpoint_path} could not be found.") 345 model_hash = _compute_hash(checkpoint_path) 346 decoder_path = None 347 348 # Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped 349 # before calling sam_model_registry. 350 abbreviated_model_type = model_type[:5] 351 if abbreviated_model_type not in _MODEL_TYPES: 352 raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}") 353 if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT: 354 raise RuntimeError( 355 "mobile_sam is required for the vit-tiny." 356 "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'" 357 ) 358 359 state, model_state = _load_checkpoint(checkpoint_path) 360 361 # Whether to update parameters necessary to initialize the model 362 if model_kwargs: # Checks whether model_kwargs have been provided or not 363 if abbreviated_model_type == "vit_t": 364 raise ValueError("'micro-sam' does not support changing the model parameters for 'mobile-sam'.") 365 sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs) 366 367 else: 368 sam = sam_model_registry[abbreviated_model_type]() 369 370 # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything. 371 # Overwrites the SAM model by freezing the backbone and allow PEFT. 372 if peft_kwargs and isinstance(peft_kwargs, dict): 373 if abbreviated_model_type == "vit_t": 374 raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.") 375 376 sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam 377 378 # In case the model checkpoints have some issues when it is initialized with different parameters than default. 379 if flexible_load_checkpoint: 380 sam = _handle_checkpoint_loading(sam, model_state) 381 else: 382 sam.load_state_dict(model_state) 383 384 sam.to(device=device) 385 386 predictor = SamPredictor(sam) 387 predictor.model_type = abbreviated_model_type 388 predictor._hash = model_hash 389 predictor.model_name = model_type 390 391 # Add the decoder to the state if we have one and if the state is returned. 392 if decoder_path is not None and return_state: 393 state["decoder_state"] = torch.load(decoder_path, map_location=device, weights_only=False) 394 395 if return_sam and return_state: 396 return predictor, sam, state 397 if return_sam: 398 return predictor, sam 399 if return_state: 400 return predictor, state 401 return predictor
Get the SegmentAnything Predictor.
This function will download the required model or load it from the cached weight file.
This location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR.
The name of the requested model can be set via model_type
.
See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models
for an overview of the available models
Alternatively this function can also load a model from weights stored in a local filepath.
The corresponding file path is given via checkpoint_path
. In this case model_type
must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for
a SAM model with vit_b encoder.
By default the models are downloaded to a folder named 'micro_sam/models' inside your default cache directory, eg:
- Mac: ~/Library/Caches/
- Unix: ~/.cache/
or the value of the XDG_CACHE_HOME environment variable, if defined. - Windows: C:\Users<user>\AppData\Local<AppAuthor><AppName>\Cache See the pooch.os_cache() documentation for more details: https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html
Arguments:
- model_type: The SegmentAnything model to use. Will use the standard vit_h model by default.
To get a list of all available model names you can call
get_model_names
. - device: The device for the model. If none is given will use GPU if available.
- checkpoint_path: The path to a file with weights that should be used instead of using the
weights corresponding to
model_type
. If given,model_type
must match the architecture corresponding to the weight file. E.g. if you use weights for SAM with vit_b encoder thenmodel_type
must be given as "vit_b". - return_sam: Return the sam model object as well as the predictor.
- return_state: Return the unpickled checkpoint state.
- peft_kwargs: Keyword arguments for th PEFT wrapper class.
- flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
Returns:
The segment anything predictor.
437def export_custom_sam_model( 438 checkpoint_path: Union[str, os.PathLike], 439 model_type: str, 440 save_path: Union[str, os.PathLike], 441) -> None: 442 """Export a finetuned segment anything model to the standard model format. 443 444 The exported model can be used by the interactive annotation tools in `micro_sam.annotator`. 445 446 Args: 447 checkpoint_path: The path to the corresponding checkpoint if not in the default model folder. 448 model_type: The SegmentAnything model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t). 449 save_path: Where to save the exported model. 450 """ 451 _, state = get_sam_model( 452 model_type=model_type, checkpoint_path=checkpoint_path, return_state=True, device="cpu", 453 ) 454 model_state = state["model_state"] 455 prefix = "sam." 456 model_state = OrderedDict( 457 [(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()] 458 ) 459 torch.save(model_state, save_path)
Export a finetuned segment anything model to the standard model format.
The exported model can be used by the interactive annotation tools in micro_sam.annotator.
Arguments:
- checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
- model_type: The SegmentAnything model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
- save_path: Where to save the exported model.
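A usage sketch; the checkpoint and output paths are placeholders for your own files:

from micro_sam.util import export_custom_sam_model

export_custom_sam_model(
    checkpoint_path="checkpoints/vit_b_finetuned/best.pt",  # placeholder: your finetuned checkpoint
    model_type="vit_b",                                     # encoder architecture of the checkpoint
    save_path="vit_b_finetuned_sam.pth",                    # placeholder: where to write the exported weights
)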
801def precompute_image_embeddings( 802 predictor: SamPredictor, 803 input_: np.ndarray, 804 save_path: Optional[Union[str, os.PathLike]] = None, 805 lazy_loading: bool = False, 806 ndim: Optional[int] = None, 807 tile_shape: Optional[Tuple[int, int]] = None, 808 halo: Optional[Tuple[int, int]] = None, 809 verbose: bool = True, 810 pbar_init: Optional[callable] = None, 811 pbar_update: Optional[callable] = None, 812) -> ImageEmbeddings: 813 """Compute the image embeddings (output of the encoder) for the input. 814 815 If 'save_path' is given the embeddings will be loaded/saved in a zarr container. 816 817 Args: 818 predictor: The SegmentAnything predictor. 819 input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries. 820 save_path: Path to save the embeddings in a zarr container. 821 lazy_loading: Whether to load all embeddings into memory or return an 822 object to load them on demand when required. This only has an effect if 'save_path' is given 823 and if the input is 3 dimensional. 824 ndim: The dimensionality of the data. If not given will be deduced from the input data. 825 tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling. 826 halo: Overlap of the tiles for tiled prediction. 827 verbose: Whether to be verbose in the computation. 828 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 829 Can be used together with pbar_update to handle napari progress bar in other thread. 830 To enables using this function within a threadworker. 831 pbar_update: Callback to update an external progress bar. 832 833 Returns: 834 The image embeddings. 835 """ 836 ndim = input_.ndim if ndim is None else ndim 837 838 # Handle the embedding save_path. 839 # We don't have a save path, open in memory zarr file to hold tiled embeddings. 840 if save_path is None: 841 f = zarr.group() 842 843 # We have a save path and it already exists. Embeddings will be loaded from it, 844 # check that the saved embeddings in there match the parameters of the function call. 845 elif os.path.exists(save_path): 846 f = zarr.open(save_path, "a") 847 _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo) 848 849 # We have a save path and it does not exist yet. Create the zarr file to which the 850 # embeddings will then be saved. 851 else: 852 f = zarr.open(save_path, "a") 853 854 _, pbar_init, pbar_update, pbar_close = handle_pbar(verbose, pbar_init, pbar_update) 855 856 if ndim == 2 and tile_shape is None: 857 embeddings = _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update) 858 elif ndim == 2 and tile_shape is not None: 859 embeddings = _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update) 860 elif ndim == 3 and tile_shape is None: 861 embeddings = _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update) 862 elif ndim == 3 and tile_shape is not None: 863 embeddings = _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update) 864 else: 865 raise ValueError(f"Invalid dimesionality {input_.ndim}, expect 2 or 3 dim data.") 866 867 pbar_close() 868 return embeddings
Compute the image embeddings (output of the encoder) for the input.
If 'save_path' is given the embeddings will be loaded/saved in a zarr container.
Arguments:
- predictor: The SegmentAnything predictor.
- input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
- save_path: Path to save the embeddings in a zarr container.
- lazy_loading: Whether to load all embeddings into memory or return an object to load them on demand when required. This only has an effect if 'save_path' is given and if the input is 3 dimensional.
- ndim: The dimensionality of the data. If not given will be deduced from the input data.
- tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
- halo: Overlap of the tiles for tiled prediction.
- verbose: Whether to be verbose in the computation.
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle a napari progress bar in another thread, which enables using this function within a threadworker.
- pbar_update: Callback to update an external progress bar.
Returns:
The image embeddings.
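A minimal sketch for a single 2D image; the image path and the zarr save path are placeholders, and the tiling parameters are only needed for large images:

import imageio.v3 as imageio
from micro_sam.util import get_sam_model, precompute_image_embeddings

predictor = get_sam_model(model_type="vit_b")
image = imageio.imread("example_image.tif")  # placeholder path to a 2d image

# Embeddings are stored in the zarr container, so repeated calls with the same
# save_path reuse them instead of recomputing.
embeddings = precompute_image_embeddings(
    predictor, image, save_path="embeddings.zarr",
    # tile_shape=(1024, 1024), halo=(256, 256),  # optional: tiled prediction for large images
)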
871def set_precomputed( 872 predictor: SamPredictor, 873 image_embeddings: ImageEmbeddings, 874 i: Optional[int] = None, 875 tile_id: Optional[int] = None, 876) -> SamPredictor: 877 """Set the precomputed image embeddings for a predictor. 878 879 Args: 880 predictor: The SegmentAnything predictor. 881 image_embeddings: The precomputed image embeddings computed by `precompute_image_embeddings`. 882 i: Index for the image data. Required if `image` has three spatial dimensions 883 or a time dimension and two spatial dimensions. 884 tile_id: Index for the tile. This is required if the embeddings are tiled. 885 886 Returns: 887 The predictor with set features. 888 """ 889 if tile_id is not None: 890 tile_features = image_embeddings["features"][tile_id] 891 tile_image_embeddings = { 892 "features": tile_features, 893 "input_size": tile_features.attrs["input_size"], 894 "original_size": tile_features.attrs["original_size"] 895 } 896 return set_precomputed(predictor, tile_image_embeddings, i=i) 897 898 device = predictor.device 899 features = image_embeddings["features"] 900 assert features.ndim in (4, 5), f"{features.ndim}" 901 if features.ndim == 5 and i is None: 902 raise ValueError("The data is 3D so an index i is needed.") 903 elif features.ndim == 4 and i is not None: 904 raise ValueError("The data is 2D so an index is not needed.") 905 906 if i is None: 907 predictor.features = features.to(device) if torch.is_tensor(features) else \ 908 torch.from_numpy(features[:]).to(device) 909 else: 910 predictor.features = features[i].to(device) if torch.is_tensor(features) else \ 911 torch.from_numpy(features[i]).to(device) 912 predictor.original_size = image_embeddings["original_size"] 913 predictor.input_size = image_embeddings["input_size"] 914 predictor.is_image_set = True 915 916 return predictor
Set the precomputed image embeddings for a predictor.
Arguments:
- predictor: The SegmentAnything predictor.
- image_embeddings: The precomputed image embeddings computed by precompute_image_embeddings.
- i: Index for the image data. Required if image has three spatial dimensions or a time dimension and two spatial dimensions.
- tile_id: Index for the tile. This is required if the embeddings are tiled.
Returns:
The predictor with set features.
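A sketch for 3D data, using a small random volume as a stand-in for real data; after setting the embeddings of one slice the predictor can be prompted without re-running the encoder:

import numpy as np
from micro_sam.util import get_sam_model, precompute_image_embeddings, set_precomputed

predictor = get_sam_model(model_type="vit_b")
volume = (np.random.rand(4, 256, 256) * 255).astype("uint8")  # placeholder 3d data

embeddings = precompute_image_embeddings(predictor, volume, ndim=3)
predictor = set_precomputed(predictor, embeddings, i=0)  # use the embeddings of slice 0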
924def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float: 925 """Compute the intersection over union of two masks. 926 927 Args: 928 mask1: The first mask. 929 mask2: The second mask. 930 931 Returns: 932 The intersection over union of the two masks. 933 """ 934 overlap = np.logical_and(mask1 == 1, mask2 == 1).sum() 935 union = np.logical_or(mask1 == 1, mask2 == 1).sum() 936 eps = 1e-7 937 iou = float(overlap) / (float(union) + eps) 938 return iou
Compute the intersection over union of two masks.
Arguments:
- mask1: The first mask.
- mask2: The second mask.
Returns:
The intersection over union of the two masks.
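A small self-contained example with two overlapping binary masks:

import numpy as np
from micro_sam.util import compute_iou

mask1 = np.zeros((64, 64), dtype="uint8")
mask2 = np.zeros((64, 64), dtype="uint8")
mask1[10:30, 10:30] = 1
mask2[20:40, 20:40] = 1

# 100 overlapping pixels, 700 pixels in the union -> IoU of roughly 0.14.
print(compute_iou(mask1, mask2))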
941def get_centers_and_bounding_boxes( 942 segmentation: np.ndarray, 943 mode: str = "v" 944) -> Tuple[Dict[int, np.ndarray], Dict[int, tuple]]: 945 """Returns the center coordinates of the foreground instances in the ground-truth. 946 947 Args: 948 segmentation: The segmentation. 949 mode: Determines the functionality used for computing the centers. 950 If 'v', the object's eccentricity centers computed by vigra are used. 951 If 'p' the object's centroids computed by skimage are used. 952 953 Returns: 954 A dictionary that maps object ids to the corresponding centroid. 955 A dictionary that maps object_ids to the corresponding bounding box. 956 """ 957 assert mode in ["p", "v"], "Choose either 'p' for regionprops or 'v' for vigra" 958 959 properties = regionprops(segmentation) 960 961 if mode == "p": 962 center_coordinates = {prop.label: prop.centroid for prop in properties} 963 elif mode == "v": 964 center_coordinates = vigra.filters.eccentricityCenters(segmentation.astype('float32')) 965 center_coordinates = {i: coord for i, coord in enumerate(center_coordinates) if i > 0} 966 967 bbox_coordinates = {prop.label: prop.bbox for prop in properties} 968 969 assert len(bbox_coordinates) == len(center_coordinates), f"{len(bbox_coordinates)}, {len(center_coordinates)}" 970 return center_coordinates, bbox_coordinates
Returns the center coordinates of the foreground instances in the ground-truth.
Arguments:
- segmentation: The segmentation.
- mode: Determines the functionality used for computing the centers. If 'v', the object's eccentricity centers computed by vigra are used. If 'p', the object's centroids computed by skimage are used.
Returns:
A dictionary that maps object ids to the corresponding centroid, and a dictionary that maps object ids to the corresponding bounding box.
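A small example with a toy label image; mode='p' uses the skimage centroids:

import numpy as np
from micro_sam.util import get_centers_and_bounding_boxes

segmentation = np.zeros((64, 64), dtype="uint32")
segmentation[5:15, 5:15] = 1    # object 1
segmentation[30:50, 30:50] = 2  # object 2

centers, boxes = get_centers_and_bounding_boxes(segmentation, mode="p")
print(centers[1])  # (9.5, 9.5)
print(boxes[1])    # (5, 5, 15, 15)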
973def load_image_data( 974 path: str, 975 key: Optional[str] = None, 976 lazy_loading: bool = False 977) -> np.ndarray: 978 """Helper function to load image data from file. 979 980 Args: 981 path: The filepath to the image data. 982 key: The internal filepath for complex data formats like hdf5. 983 lazy_loading: Whether to lazyly load data. Only supported for n5 and zarr data. 984 985 Returns: 986 The image data. 987 """ 988 if key is None: 989 image_data = imageio.imread(path) 990 else: 991 with open_file(path, mode="r") as f: 992 image_data = f[key] 993 if not lazy_loading: 994 image_data = image_data[:] 995 return image_data
Helper function to load image data from file.
Arguments:
- path: The filepath to the image data.
- key: The internal filepath for complex data formats like hdf5.
- lazy_loading: Whether to lazily load the data. Only supported for n5 and zarr data.
Returns:
The image data.
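A sketch with placeholder paths; the internal key is only needed for container formats such as hdf5, n5 or zarr:

from micro_sam.util import load_image_data

# Plain image file, loaded fully into memory.
image = load_image_data("example_image.tif")  # placeholder path

# Container format: pass the internal dataset key; lazy loading returns a handle instead of an array.
volume = load_image_data("example_volume.h5", key="raw", lazy_loading=True)  # placeholder path and key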
998def segmentation_to_one_hot( 999 segmentation: np.ndarray, 1000 segmentation_ids: Optional[np.ndarray] = None, 1001) -> torch.Tensor: 1002 """Convert the segmentation to one-hot encoded masks. 1003 1004 Args: 1005 segmentation: The segmentation. 1006 segmentation_ids: Optional subset of ids that will be used to subsample the masks. 1007 1008 Returns: 1009 The one-hot encoded masks. 1010 """ 1011 masks = segmentation.copy() 1012 if segmentation_ids is None: 1013 n_ids = int(segmentation.max()) 1014 1015 else: 1016 assert segmentation_ids[0] != 0, "No objects were found." 1017 1018 # the segmentation ids have to be sorted 1019 segmentation_ids = np.sort(segmentation_ids) 1020 1021 # set the non selected objects to zero and relabel sequentially 1022 masks[~np.isin(masks, segmentation_ids)] = 0 1023 masks = relabel_sequential(masks)[0] 1024 n_ids = len(segmentation_ids) 1025 1026 masks = torch.from_numpy(masks) 1027 1028 one_hot_shape = (n_ids + 1,) + masks.shape 1029 masks = masks.unsqueeze(0) # add dimension to scatter 1030 masks = torch.zeros(one_hot_shape).scatter_(0, masks, 1)[1:] 1031 1032 # add the extra singleton dimenion to get shape NUM_OBJECTS x 1 x H x W 1033 masks = masks.unsqueeze(1) 1034 return masks
Convert the segmentation to one-hot encoded masks.
Arguments:
- segmentation: The segmentation.
- segmentation_ids: Optional subset of ids that will be used to subsample the masks.
Returns:
The one-hot encoded masks.
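A small self-contained example; the resulting tensor has shape NUM_OBJECTS x 1 x H x W:

import numpy as np
from micro_sam.util import segmentation_to_one_hot

segmentation = np.zeros((32, 32), dtype="int64")
segmentation[2:10, 2:10] = 1
segmentation[15:25, 15:25] = 2

masks = segmentation_to_one_hot(segmentation)
print(masks.shape)  # torch.Size([2, 1, 32, 32])

# Restrict the encoding to a subset of the object ids.
masks_subset = segmentation_to_one_hot(segmentation, segmentation_ids=np.array([2]))
print(masks_subset.shape)  # torch.Size([1, 1, 32, 32])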
1037def get_block_shape(shape: Tuple[int]) -> Tuple[int]: 1038 """Get a suitable block shape for chunking a given shape. 1039 1040 The primary use for this is determining chunk sizes for 1041 zarr arrays or block shapes for parallelization. 1042 1043 Args: 1044 shape: The image or volume shape. 1045 1046 Returns: 1047 The block shape. 1048 """ 1049 ndim = len(shape) 1050 if ndim == 2: 1051 block_shape = tuple(min(bs, sh) for bs, sh in zip((1024, 1024), shape)) 1052 elif ndim == 3: 1053 block_shape = tuple(min(bs, sh) for bs, sh in zip((32, 256, 256), shape)) 1054 else: 1055 raise ValueError(f"Only 2 or 3 dimensional shapes are supported, got {ndim}D.") 1056 1057 return block_shape
Get a suitable block shape for chunking a given shape.
The primary use for this is determining chunk sizes for zarr arrays or block shapes for parallelization.
Arguments:
- shape: The image or volume shape.
Returns:
The block shape.
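A quick example; block shapes are capped at the shape of the data:

from micro_sam.util import get_block_shape

print(get_block_shape((512, 512)))         # (512, 512)
print(get_block_shape((2048, 2048)))       # (1024, 1024)
print(get_block_shape((100, 2048, 2048)))  # (32, 256, 256)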