micro_sam.util
Helper functions for downloading Segment Anything models and predicting image embeddings.
1""" 2Helper functions for downloading Segment Anything models and predicting image embeddings. 3""" 4 5import os 6import pickle 7import hashlib 8import warnings 9from pathlib import Path 10from collections import OrderedDict 11from typing import Any, Dict, Iterable, Optional, Tuple, Union, Callable 12 13import zarr 14import vigra 15import torch 16import pooch 17import xxhash 18import numpy as np 19import imageio.v3 as imageio 20from skimage.measure import regionprops 21from skimage.segmentation import relabel_sequential 22 23from elf.io import open_file 24 25from nifty.tools import blocking 26 27from .__version__ import __version__ 28from . import models as custom_models 29 30try: 31 # Avoid import warnigns from mobile_sam 32 with warnings.catch_warnings(): 33 warnings.simplefilter("ignore") 34 from mobile_sam import sam_model_registry, SamPredictor 35 VIT_T_SUPPORT = True 36except ImportError: 37 from segment_anything import sam_model_registry, SamPredictor 38 VIT_T_SUPPORT = False 39 40try: 41 from napari.utils import progress as tqdm 42except ImportError: 43 from tqdm import tqdm 44 45# This is the default model used in micro_sam 46# Currently it is set to vit_b_lm 47_DEFAULT_MODEL = "vit_b_lm" 48 49# The valid model types. Each type corresponds to the architecture of the 50# vision transformer used within SAM. 51_MODEL_TYPES = ("vit_l", "vit_b", "vit_h", "vit_t") 52 53 54# TODO define the proper type for image embeddings 55ImageEmbeddings = Dict[str, Any] 56"""@private""" 57 58 59def get_cache_directory() -> None: 60 """Get micro-sam cache directory location. 61 62 Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory. 63 """ 64 default_cache_directory = os.path.expanduser(pooch.os_cache("micro_sam")) 65 cache_directory = Path(os.environ.get("MICROSAM_CACHEDIR", default_cache_directory)) 66 return cache_directory 67 68 69# 70# Functionality for model download and export 71# 72 73 74def microsam_cachedir() -> None: 75 """Return the micro-sam cache directory. 76 77 Returns the top level cache directory for micro-sam models and sample data. 78 79 Every time this function is called, we check for any user updates made to 80 the MICROSAM_CACHEDIR os environment variable since the last time. 81 """ 82 cache_directory = os.environ.get("MICROSAM_CACHEDIR") or pooch.os_cache("micro_sam") 83 return cache_directory 84 85 86def models(): 87 """Return the segmentation models registry. 88 89 We recreate the model registry every time this function is called, 90 so any user changes to the default micro-sam cache directory location 91 are respected. 92 """ 93 94 # We use xxhash to compute the hash of the models, see 95 # https://github.com/computational-cell-analytics/micro-sam/issues/283 96 # (It is now a dependency, so we don't provide the sha256 fallback anymore.) 97 # To generate the xxh128 hash: 98 # xxh128sum filename 99 encoder_registry = { 100 # The default segment anything models: 101 "vit_l": "xxh128:a82beb3c660661e3dd38d999cc860e9a", 102 "vit_h": "xxh128:97698fac30bd929c2e6d8d8cc15933c2", 103 "vit_b": "xxh128:6923c33df3637b6a922d7682bfc9a86b", 104 # The model with vit tiny backend fom https://github.com/ChaoningZhang/MobileSAM. 105 "vit_t": "xxh128:8eadbc88aeb9d8c7e0b4b60c3db48bd0", 106 # The current version of our models in the modelzoo. 
107 # LM generalist models: 108 "vit_l_lm": "xxh128:fc32ea6f7fcc7eb02737d1304f81f5f2", 109 "vit_b_lm": "xxh128:8fd5806be3c3ba213e19a709d6d1495f", 110 "vit_t_lm": "xxh128:72ec5074774761a6e5c05a08942f981e", 111 # EM models: 112 "vit_l_em_organelles": "xxh128:096c9695966803ca6fde24f4c1e3c3fb", 113 "vit_b_em_organelles": "xxh128:f6f6593aeecd0e15a07bdac86360b6cc", 114 "vit_t_em_organelles": "xxh128:253474720c497cce605e57c9b1d18fd9", 115 # Histopathology models: 116 "vit_b_histopathology": "xxh128:ffd1a2cd84570458b257bd95fdd8f974", 117 "vit_l_histopathology": "xxh128:b591833c89754271023e901281dee3f2", 118 "vit_h_histopathology": "xxh128:bd1856dafc156a43fb3aa705f1a6e92e", 119 # Medical Imaging models: 120 "vit_b_medical_imaging": "xxh128:5be672f1458263a9edc9fd40d7f56ac1", 121 } 122 # Additional decoders for instance segmentation. 123 decoder_registry = { 124 # LM generalist models: 125 "vit_l_lm_decoder": "xxh128:779b5a50ecc6d46d495753fba8717f2f", 126 "vit_b_lm_decoder": "xxh128:9f580a96984b3085389ced5d9a4ae75d", 127 "vit_t_lm_decoder": "xxh128:3e914a5f397b0312cdd36813031f8823", 128 # EM models: 129 "vit_l_em_organelles_decoder": "xxh128:d60fd96bd6060856f6430f29e42568fb", 130 "vit_b_em_organelles_decoder": "xxh128:b2d4dcffb99f76d83497d39ee500088f", 131 "vit_t_em_organelles_decoder": "xxh128:8f897c7bb93174a4d1638827c4dd6f44", 132 # Histopathology models: 133 "vit_b_histopathology_decoder": "xxh128:6a66194dcb6e36199cbee2214ecf7213", 134 "vit_l_histopathology_decoder": "xxh128:46aab7765d4400e039772d5a50b55c04", 135 "vit_h_histopathology_decoder": "xxh128:3ed9f87e46ad5e16935bd8d722c8dc47", 136 } 137 registry = {**encoder_registry, **decoder_registry} 138 139 encoder_urls = { 140 "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", 141 "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", 142 "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", 143 "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download", 144 "vit_l_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.1/files/vit_l.pt", 145 "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.1/files/vit_b.pt", 146 "vit_t_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t.pt", 147 "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l.pt", # noqa 148 "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b.pt", 149 "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt", # noqa 150 "vit_b_histopathology": "https://owncloud.gwdg.de/index.php/s/sBB4H8CTmIoBZsQ/download", 151 "vit_l_histopathology": "https://owncloud.gwdg.de/index.php/s/IZgnn1cpBq2PHod/download", 152 "vit_h_histopathology": "https://owncloud.gwdg.de/index.php/s/L7AcvVz7DoWJ2RZ/download", 153 "vit_b_medical_imaging": "https://owncloud.gwdg.de/index.php/s/AB69HGhj8wuozXQ/download", 154 } 155 156 decoder_urls = { 157 "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.1/files/vit_l_decoder.pt", # noqa 158 "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.1/files/vit_b_decoder.pt", # noqa 159 "vit_t_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t_decoder.pt", # noqa 
160 "vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l_decoder.pt", # noqa 161 "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b_decoder.pt", # noqa 162 "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt", # noqa 163 "vit_b_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/KO9AWqynI7SFOBj/download", 164 "vit_l_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/oIs6VSmkOp7XrKF/download", 165 "vit_h_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/1qAKxy5H0jgwZvM/download", 166 } 167 urls = {**encoder_urls, **decoder_urls} 168 169 models = pooch.create( 170 path=os.path.join(microsam_cachedir(), "models"), 171 base_url="", 172 registry=registry, 173 urls=urls, 174 ) 175 return models 176 177 178def _get_default_device(): 179 # check that we're in CI and use the CPU if we are 180 # otherwise the tests may run out of memory on MAC if MPS is used. 181 if os.getenv("GITHUB_ACTIONS") == "true": 182 return "cpu" 183 # Use cuda enabled gpu if it's available. 184 if torch.cuda.is_available(): 185 device = "cuda" 186 # As second priority use mps. 187 # See https://pytorch.org/docs/stable/notes/mps.html for details 188 elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): 189 print("Using apple MPS device.") 190 device = "mps" 191 # Use the CPU as fallback. 192 else: 193 device = "cpu" 194 return device 195 196 197def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]: 198 """Get the torch device. 199 200 If no device is passed the default device for your system is used. 201 Else it will be checked if the device you have passed is supported. 202 203 Args: 204 device: The input device. 205 206 Returns: 207 The device. 208 """ 209 if device is None or device == "auto": 210 device = _get_default_device() 211 else: 212 device_type = device if isinstance(device, str) else device.type 213 if device_type.lower() == "cuda": 214 if not torch.cuda.is_available(): 215 raise RuntimeError("PyTorch CUDA backend is not available.") 216 elif device_type.lower() == "mps": 217 if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()): 218 raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.") 219 elif device_type.lower() == "cpu": 220 pass # cpu is always available 221 else: 222 raise RuntimeError( 223 f"Unsupported device: {device}\n" 224 "Please choose from 'cpu', 'cuda', or 'mps'." 225 ) 226 227 return device 228 229 230def _available_devices(): 231 available_devices = [] 232 for i in ["cuda", "mps", "cpu"]: 233 try: 234 device = get_device(i) 235 except RuntimeError: 236 pass 237 else: 238 available_devices.append(device) 239 return available_devices 240 241 242# We write a custom unpickler that skips objects that cannot be found instead of 243# throwing an AttributeError or ModueNotFoundError. 244# NOTE: since we just want to unpickle the model to load its weights these errors don't matter. 
245# See also https://stackoverflow.com/questions/27732354/unable-to-load-files-using-pickle-and-multiple-modules 246class _CustomUnpickler(pickle.Unpickler): 247 def find_class(self, module, name): 248 try: 249 return super().find_class(module, name) 250 except (AttributeError, ModuleNotFoundError) as e: 251 warnings.warn(f"Did not find {module}:{name} and will skip it, due to error {e}") 252 return None 253 254 255def _compute_hash(path, chunk_size=8192): 256 hash_obj = xxhash.xxh128() 257 with open(path, "rb") as f: 258 chunk = f.read(chunk_size) 259 while chunk: 260 hash_obj.update(chunk) 261 chunk = f.read(chunk_size) 262 hash_val = hash_obj.hexdigest() 263 return f"xxh128:{hash_val}" 264 265 266# Load the state from a checkpoint. 267# The checkpoint can either contain a sam encoder state 268# or it can be a checkpoint for model finetuning. 269def _load_checkpoint(checkpoint_path): 270 # Over-ride the unpickler with our custom one. 271 # This enables imports from torch_em checkpoints even if it cannot be fully unpickled. 272 custom_pickle = pickle 273 custom_pickle.Unpickler = _CustomUnpickler 274 275 state = torch.load(checkpoint_path, map_location="cpu", pickle_module=custom_pickle) 276 if "model_state" in state: 277 # Copy the model weights from torch_em's training format. 278 model_state = state["model_state"] 279 sam_prefix = "sam." 280 model_state = OrderedDict( 281 [(k[len(sam_prefix):] if k.startswith(sam_prefix) else k, v) for k, v in model_state.items()] 282 ) 283 else: 284 model_state = state 285 286 return state, model_state 287 288 289def get_sam_model( 290 model_type: str = _DEFAULT_MODEL, 291 device: Optional[Union[str, torch.device]] = None, 292 checkpoint_path: Optional[Union[str, os.PathLike]] = None, 293 return_sam: bool = False, 294 return_state: bool = False, 295 peft_kwargs: Optional[Dict] = None, 296 flexible_load_checkpoint: bool = False, 297 progress_bar_factory: Optional[Callable] = None, 298 **model_kwargs, 299) -> SamPredictor: 300 r"""Get the Segment Anything Predictor. 301 302 This function will download the required model or load it from the cached weight file. 303 This location of the cache can be changed by setting the environment variable: MICROSAM_CACHEDIR. 304 The name of the requested model can be set via `model_type`. 305 See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models 306 for an overview of the available models 307 308 Alternatively this function can also load a model from weights stored in a local filepath. 309 The corresponding file path is given via `checkpoint_path`. In this case `model_type` 310 must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for 311 a SAM model with vit_b encoder. 312 313 By default the models are downloaded to a folder named 'micro_sam/models' 314 inside your default cache directory, eg: 315 * Mac: ~/Library/Caches/<AppName> 316 * Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined. 317 * Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache 318 See the pooch.os_cache() documentation for more details: 319 https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html 320 321 Args: 322 model_type: The Segment Anything model to use. Will use the standard `vit_l` model by default. 323 To get a list of all available model names you can call `get_model_names`. 324 device: The device for the model. If none is given will use GPU if available. 
325 checkpoint_path: The path to a file with weights that should be used instead of using the 326 weights corresponding to `model_type`. If given, `model_type` must match the architecture 327 corresponding to the weight file. e.g. if you use weights for SAM with `vit_b` encoder 328 then `model_type` must be given as "vit_b". 329 return_sam: Return the sam model object as well as the predictor. 330 return_state: Return the unpickled checkpoint state. 331 peft_kwargs: Keyword arguments for th PEFT wrapper class. 332 flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. 333 model_kwargs: Additional parameters necessary to initialize the Segment Anything model. 334 progress_bar_factory: A function to create a progress bar for the model download. 335 336 Returns: 337 The segment anything predictor. 338 """ 339 device = get_device(device) 340 341 # We support passing a local filepath to a checkpoint. 342 # In this case we do not download any weights but just use the local weight file, 343 # as it is, without copying it over anywhere or checking it's hashes. 344 345 # checkpoint_path has not been passed, we download a known model and derive the correct 346 # URL from the model_type. If the model_type is invalid pooch will raise an error. 347 _provided_checkpoint_path = checkpoint_path is not None 348 if checkpoint_path is None: 349 model_registry = models() 350 351 progress_bar = True 352 # Check if we have to download the model. 353 # If we do and have a progress bar factory, then we over-write the progress bar. 354 if not os.path.exists(os.path.join(get_cache_directory(), model_type)) and progress_bar_factory is not None: 355 progress_bar = progress_bar_factory(model_type) 356 357 checkpoint_path = model_registry.fetch(model_type, progressbar=progress_bar) 358 if not isinstance(progress_bar, bool): # Close the progress bar when the task finishes. 359 progress_bar.close() 360 361 model_hash = model_registry.registry[model_type] 362 363 # If we have a custom model then we may also have a decoder checkpoint. 364 # Download it here, so that we can add it to the state. 365 decoder_name = f"{model_type}_decoder" 366 decoder_path = model_registry.fetch( 367 decoder_name, progressbar=True 368 ) if decoder_name in model_registry.registry else None 369 370 # checkpoint_path has been passed, we use it instead of downloading a model. 371 else: 372 # Check if the file exists and raise an error otherwise. 373 # We can't check any hashes here, and we don't check if the file is actually a valid weight file. 374 # (If it isn't the model creation will fail below.) 375 if not os.path.exists(checkpoint_path): 376 raise ValueError(f"Checkpoint at {checkpoint_path} could not be found.") 377 model_hash = _compute_hash(checkpoint_path) 378 decoder_path = None 379 380 # Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped 381 # before calling sam_model_registry. 382 abbreviated_model_type = model_type[:5] 383 if abbreviated_model_type not in _MODEL_TYPES: 384 raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}") 385 if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT: 386 raise RuntimeError( 387 "'mobile_sam' is required for the vit-tiny. 
" 388 "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'" 389 ) 390 391 state, model_state = _load_checkpoint(checkpoint_path) 392 393 if _provided_checkpoint_path: 394 # To get the model weights, we prioritize having the correct 'checkpoint_path' over 'model_type' 395 # It is done to avoid strange parameter mismatch issues while incompatible model type and weights combination. 396 from micro_sam.models.build_sam import _validate_model_type 397 _provided_model_type = _validate_model_type(model_state) 398 399 # Verify whether the 'abbreviated_model_type' matches the '_provided_model_type' 400 # Otherwise replace 'abbreviated_model_type' with the later. 401 if abbreviated_model_type != _provided_model_type: 402 # Printing the message below to avoid any filtering of warnings on user's end. 403 print( 404 f"CRITICAL WARNING: The chosen 'model_type' is '{abbreviated_model_type}', " 405 f"however the model checkpoint provided correspond to '{_provided_model_type}', which does not match. " 406 f"We internally switch the model type to the expected value, i.e. '{_provided_model_type}'. " 407 "However, please avoid mismatching combination of 'model_type' and 'checkpoint_path' in future." 408 ) 409 410 # Replace the extracted 'abbreviated_model_type' subjected to the model weights. 411 abbreviated_model_type = _provided_model_type 412 413 # Whether to update parameters necessary to initialize the model 414 if model_kwargs: # Checks whether model_kwargs have been provided or not 415 if abbreviated_model_type == "vit_t": 416 raise ValueError("'micro-sam' does not support changing the model parameters for 'mobile-sam'.") 417 sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs) 418 419 else: 420 sam = sam_model_registry[abbreviated_model_type]() 421 422 # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything. 423 # Overwrites the SAM model by freezing the backbone and allow PEFT. 424 if peft_kwargs and isinstance(peft_kwargs, dict): 425 # NOTE: We bump out 'quantize' parameter, if found, as we do not quantize in inference. 426 peft_kwargs.pop("quantize", None) 427 428 if abbreviated_model_type == "vit_t": 429 raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.") 430 431 sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam 432 # In case the model checkpoints have some issues when it is initialized with different parameters than default. 433 if flexible_load_checkpoint: 434 sam = _handle_checkpoint_loading(sam, model_state) 435 else: 436 sam.load_state_dict(model_state) 437 sam.to(device=device) 438 439 predictor = SamPredictor(sam) 440 predictor.model_type = abbreviated_model_type 441 predictor._hash = model_hash 442 predictor.model_name = model_type 443 predictor.checkpoint_path = checkpoint_path 444 445 # Add the decoder to the state if we have one and if the state is returned. 446 if decoder_path is not None and return_state: 447 state["decoder_state"] = torch.load(decoder_path, map_location=device, weights_only=False) 448 449 if return_sam and return_state: 450 return predictor, sam, state 451 if return_sam: 452 return predictor, sam 453 if return_state: 454 return predictor, state 455 return predictor 456 457 458def _handle_checkpoint_loading(sam, model_state): 459 # Whether to handle the mismatch issues in a bit more elegant way. 460 # eg. 
while training for multi-class semantic segmentation in the mask encoder, 461 # parameters are updated - leading to "size mismatch" errors 462 463 new_state_dict = {} # for loading matching parameters 464 mismatched_layers = [] # for tracking mismatching parameters 465 466 reference_state = sam.state_dict() 467 468 for k, v in model_state.items(): 469 if k in reference_state: # This is done to get rid of unwanted layers from pretrained SAM. 470 if reference_state[k].size() == v.size(): 471 new_state_dict[k] = v 472 else: 473 mismatched_layers.append(k) 474 475 reference_state.update(new_state_dict) 476 477 if len(mismatched_layers) > 0: 478 warnings.warn(f"The layers with size mismatch: {mismatched_layers}") 479 480 for mlayer in mismatched_layers: 481 if 'weight' in mlayer: 482 torch.nn.init.kaiming_uniform_(reference_state[mlayer]) 483 elif 'bias' in mlayer: 484 reference_state[mlayer].zero_() 485 486 sam.load_state_dict(reference_state) 487 488 return sam 489 490 491def export_custom_sam_model( 492 checkpoint_path: Union[str, os.PathLike], 493 model_type: str, 494 save_path: Union[str, os.PathLike], 495 with_segmentation_decoder: bool = False, 496) -> None: 497 """Export a finetuned Segment Anything Model to the standard model format. 498 499 The exported model can be used by the interactive annotation tools in `micro_sam.annotator`. 500 501 Args: 502 checkpoint_path: The path to the corresponding checkpoint if not in the default model folder. 503 model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t). 504 save_path: Where to save the exported model. 505 with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well. 506 If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'. 507 """ 508 _, state = get_sam_model( 509 model_type=model_type, checkpoint_path=checkpoint_path, return_state=True, device="cpu", 510 ) 511 model_state = state["model_state"] 512 prefix = "sam." 513 model_state = OrderedDict( 514 [(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()] 515 ) 516 517 # Store the 'decoder_state' as well, if desired. 518 if with_segmentation_decoder: 519 if "decoder_state" not in state: 520 raise RuntimeError(f"'decoder_state' is not found in the model at '{checkpoint_path}'.") 521 decoder_state = state["decoder_state"] 522 save_state = {"model_state": model_state, "decoder_state": decoder_state} 523 else: 524 save_state = model_state 525 526 torch.save(save_state, save_path) 527 528 529def export_custom_qlora_model( 530 checkpoint_path: Optional[Union[str, os.PathLike]], 531 finetuned_path: Union[str, os.PathLike], 532 model_type: str, 533 save_path: Union[str, os.PathLike], 534) -> None: 535 """Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format. 536 537 The exported model can be used with the LoRA backbone by passing the relevant `peft_kwargs` to `get_sam_model`. 538 539 Args: 540 checkpoint_path: The path to the base foundation model from which the new model has been finetuned. 541 finetuned_path: The path to the new finetuned model, using QLoRA. 542 model_type: The Segment Anything Model type corresponding to the checkpoint. 543 save_path: Where to save the exported model. 544 """ 545 # Step 1: Get the base SAM model: used to start finetuning from. 
546 _, sam = get_sam_model( 547 model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True, 548 ) 549 550 # Step 2: Load the QLoRA-style finetuned model. 551 ft_state, ft_model_state = _load_checkpoint(finetuned_path) 552 553 # Step 3: Get LoRA weights from QLoRA and retain all original parameters from the base SAM model. 554 updated_model_state = {} 555 556 # - At first, we get all LoRA layers from the QLoRA-style finetuned model checkpoint. 557 for k, v in ft_model_state.items(): 558 if k.find("w_b_linear") != -1 or k.find("w_a_linear") != -1: 559 updated_model_state[k] = v 560 561 # - Next, we get all the remaining parameters from the base SAM model. 562 for k, v in sam.state_dict().items(): 563 if k.find("attn.qkv.") != -1: 564 k = k.replace("qkv", "qkv.qkv_proj") 565 updated_model_state[k] = v 566 else: 567 568 updated_model_state[k] = v 569 570 # - Finally, we replace the old model state with the new one (to retain other relevant stuff) 571 ft_state['model_state'] = updated_model_state 572 573 # Step 4: Store the new "state" to "save_path" 574 torch.save(ft_state, save_path) 575 576 577def get_model_names() -> Iterable: 578 model_registry = models() 579 model_names = model_registry.registry.keys() 580 return model_names 581 582 583# 584# Functionality for precomputing image embeddings. 585# 586 587 588def _to_image(input_): 589 # we require the input to be uint8 590 if input_.dtype != np.dtype("uint8"): 591 # first normalize the input to [0, 1] 592 input_ = input_.astype("float32") - input_.min() 593 input_ = input_ / input_.max() 594 # then bring to [0, 255] and cast to uint8 595 input_ = (input_ * 255).astype("uint8") 596 597 if input_.ndim == 2: 598 image = np.concatenate([input_[..., None]] * 3, axis=-1) 599 elif input_.ndim == 3 and input_.shape[-1] == 3: 600 image = input_ 601 else: 602 raise ValueError(f"Invalid input image of shape {input_.shape}. Expect either 2D grayscale or 3D RGB image.") 603 604 # explicitly return a numpy array for compatibility with torchvision 605 # because the input_ array could be something like dask array 606 return np.array(image) 607 608 609@torch.no_grad 610def _compute_embeddings_batched(predictor, batched_images): 611 predictor.reset_image() 612 batched_tensors, original_sizes, input_sizes = [], [], [] 613 614 # Apply proeprocessing to all images in the batch, and then stack them. 615 # Note: after the transformation the images are all of the same size, 616 # so they can be stacked and processed as a batch, even if the input images were of different size. 
617 for image in batched_images: 618 tensor = predictor.transform.apply_image(image) 619 tensor = torch.as_tensor(tensor, device=predictor.device) 620 tensor = tensor.permute(2, 0, 1).contiguous()[None, :, :, :] 621 622 original_sizes.append(image.shape[:2]) 623 input_sizes.append(tensor.shape[-2:]) 624 625 tensor = predictor.model.preprocess(tensor) 626 batched_tensors.append(tensor) 627 628 batched_tensors = torch.cat(batched_tensors) 629 features = predictor.model.image_encoder(batched_tensors) 630 631 predictor.original_size = original_sizes[-1] 632 predictor.input_size = input_sizes[-1] 633 predictor.features = features[-1] 634 predictor.is_image_set = True 635 636 return features, original_sizes, input_sizes 637 638 639def _compute_tiled_features_2d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size): 640 tiling = blocking([0, 0], input_.shape[:2], tile_shape) 641 n_tiles = tiling.numberOfBlocks 642 643 features = f.require_group("features") 644 features.attrs["shape"] = input_.shape[:2] 645 features.attrs["tile_shape"] = tile_shape 646 features.attrs["halo"] = halo 647 648 pbar_init(n_tiles, "Compute Image Embeddings 2D tiled") 649 650 n_batches = int(np.ceil(n_tiles / batch_size)) 651 for batch_id in range(n_batches): 652 tile_start = batch_id * batch_size 653 tile_stop = min(tile_start + batch_size, n_tiles) 654 655 batched_images = [] 656 for tile_id in range(tile_start, tile_stop): 657 tile = tiling.getBlockWithHalo(tile_id, list(halo)) 658 outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outerBlock.begin, tile.outerBlock.end)) 659 tile_input = _to_image(input_[outer_tile]) 660 batched_images.append(tile_input) 661 662 batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images) 663 for i, tile_id in enumerate(range(tile_start, tile_stop)): 664 tile_embeddings, original_size, input_size = batched_embeddings[i], original_sizes[i], input_sizes[i] 665 # Unsqueeze the channel axis of the tile embeddings. 666 tile_embeddings = tile_embeddings.unsqueeze(0) 667 ds = features.create_dataset( 668 str(tile_id), data=tile_embeddings.cpu().numpy(), compression="gzip", chunks=tile_embeddings.shape 669 ) 670 ds.attrs["original_size"] = original_size 671 ds.attrs["input_size"] = input_size 672 pbar_update(1) 673 674 _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size=None, original_size=None) 675 return features 676 677 678def _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size): 679 assert input_.ndim == 3 680 681 shape = input_.shape[1:] 682 tiling = blocking([0, 0], shape, tile_shape) 683 n_tiles = tiling.numberOfBlocks 684 685 features = f.require_group("features") 686 features.attrs["shape"] = shape 687 features.attrs["tile_shape"] = tile_shape 688 features.attrs["halo"] = halo 689 690 n_slices = input_.shape[0] 691 pbar_init(n_tiles * n_slices, "Compute Image Embeddings 3D tiled") 692 693 # We batch across the z axis. 
694 n_batches = int(np.ceil(n_slices / batch_size)) 695 696 for tile_id in range(n_tiles): 697 tile = tiling.getBlockWithHalo(tile_id, list(halo)) 698 outer_tile = tuple(slice(beg, end) for beg, end in zip(tile.outerBlock.begin, tile.outerBlock.end)) 699 700 ds = None 701 for batch_id in range(n_batches): 702 z_start = batch_id * batch_size 703 z_stop = min(z_start + batch_size, n_slices) 704 705 batched_images = [] 706 for z in range(z_start, z_stop): 707 tile_input = _to_image(input_[z][outer_tile]) 708 batched_images.append(tile_input) 709 710 batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images) 711 for i, z in enumerate(range(z_start, z_stop)): 712 tile_embeddings = batched_embeddings[i].unsqueeze(0) 713 if ds is None: 714 shape = (n_slices,) + tile_embeddings.shape 715 chunks = (1,) + tile_embeddings.shape 716 ds = features.create_dataset( 717 str(tile_id), shape=shape, dtype="float32", compression="gzip", chunks=chunks 718 ) 719 720 ds[z] = tile_embeddings.cpu().numpy() 721 pbar_update(1) 722 723 ds.attrs["original_size"] = original_sizes[-1] 724 ds.attrs["input_size"] = input_sizes[-1] 725 726 _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size=None, original_size=None) 727 728 return features 729 730 731def _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update): 732 # Check if the embeddings are already cached. 733 if save_path is not None and "input_size" in f.attrs: 734 # In this case we load the embeddings. 735 features = f["features"][:] 736 original_size, input_size = f.attrs["original_size"], f.attrs["input_size"] 737 image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size} 738 # Also set the embeddings. 739 set_precomputed(predictor, image_embeddings) 740 return image_embeddings 741 742 pbar_init(1, "Compute Image Embeddings 2D") 743 # Otherwise we have to compute the embeddings. 744 predictor.reset_image() 745 predictor.set_image(_to_image(input_)) 746 features = predictor.get_image_embedding().cpu().numpy() 747 original_size = predictor.original_size 748 input_size = predictor.input_size 749 pbar_update(1) 750 751 # Save the embeddings if we have a save_path. 752 if save_path is not None: 753 f.create_dataset("features", data=features, compression="gzip", chunks=features.shape) 754 _write_embedding_signature( 755 f, input_, predictor, tile_shape=None, halo=None, input_size=input_size, original_size=original_size, 756 ) 757 758 image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size} 759 return image_embeddings 760 761 762def _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size): 763 # Check if the features are already computed. 764 if "input_size" in f.attrs: 765 features = f["features"] 766 original_size, input_size = f.attrs["original_size"], f.attrs["input_size"] 767 image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size} 768 return image_embeddings 769 770 # Otherwise compute them. Note: saving happens automatically because we 771 # always write the features to zarr. If no save path is given we use an in-memory zarr. 
772 features = _compute_tiled_features_2d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size) 773 image_embeddings = {"features": features, "input_size": None, "original_size": None} 774 return image_embeddings 775 776 777def _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update, batch_size): 778 # Check if the embeddings are already fully cached. 779 if save_path is not None and "input_size" in f.attrs: 780 # In this case we load the embeddings. 781 features = f["features"] if lazy_loading else f["features"][:] 782 original_size, input_size = f.attrs["original_size"], f.attrs["input_size"] 783 image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size} 784 return image_embeddings 785 786 # Otherwise we have to compute the embeddings. 787 788 # First check if we have a save path or not and set things up accordingly. 789 if save_path is None: 790 features = [] 791 save_features = False 792 partial_features = False 793 else: 794 save_features = True 795 embed_shape = (1, 256, 64, 64) 796 shape = (input_.shape[0],) + embed_shape 797 chunks = (1,) + embed_shape 798 if "features" in f: 799 partial_features = True 800 features = f["features"] 801 if features.shape != shape or features.chunks != chunks: 802 raise RuntimeError("Invalid partial features") 803 else: 804 partial_features = False 805 features = f.create_dataset("features", shape=shape, chunks=chunks, dtype="float32") 806 807 # Initialize the pbar and batches. 808 n_slices = input_.shape[0] 809 pbar_init(n_slices, "Compute Image Embeddings 3D") 810 n_batches = int(np.ceil(n_slices / batch_size)) 811 812 for batch_id in range(n_batches): 813 z_start = batch_id * batch_size 814 z_stop = min(z_start + batch_size, n_slices) 815 816 batched_images, batched_z = [], [] 817 for z in range(z_start, z_stop): 818 # Skip feature computation in case of partial features in non-zero slice. 819 if partial_features and np.count_nonzero(features[z]) != 0: 820 continue 821 tile_input = _to_image(input_[z]) 822 batched_images.append(tile_input) 823 batched_z.append(z) 824 825 batched_embeddings, original_sizes, input_sizes = _compute_embeddings_batched(predictor, batched_images) 826 827 for z, embedding in zip(batched_z, batched_embeddings): 828 embedding = embedding.unsqueeze(0) 829 if save_features: 830 features[z] = embedding.cpu().numpy() 831 else: 832 features.append(embedding.unsqueeze(0)) 833 pbar_update(1) 834 835 if save_features: 836 _write_embedding_signature( 837 f, input_, predictor, tile_shape=None, halo=None, 838 input_size=input_sizes[-1], original_size=original_sizes[-1], 839 ) 840 else: 841 # Concatenate across the z axis. 842 features = torch.cat(features).cpu().numpy() 843 844 image_embeddings = {"features": features, "input_size": input_sizes[-1], "original_size": original_sizes[-1]} 845 return image_embeddings 846 847 848def _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size): 849 # Check if the features are already computed. 850 if "input_size" in f.attrs: 851 features = f["features"] 852 original_size, input_size = f.attrs["original_size"], f.attrs["input_size"] 853 image_embeddings = {"features": features, "input_size": input_size, "original_size": original_size} 854 return image_embeddings 855 856 # Otherwise compute them. Note: saving happens automatically because we 857 # always write the features to zarr. If no save path is given we use an in-memory zarr. 
858 features = _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init, pbar_update, batch_size) 859 image_embeddings = {"features": features, "input_size": None, "original_size": None} 860 return image_embeddings 861 862 863def _compute_data_signature(input_): 864 data_signature = hashlib.sha1(np.asarray(input_).tobytes()).hexdigest() 865 return data_signature 866 867 868# Create all metadata that is stored along with the embeddings. 869def _get_embedding_signature(input_, predictor, tile_shape, halo, data_signature=None): 870 if data_signature is None: 871 data_signature = _compute_data_signature(input_) 872 873 signature = { 874 "data_signature": data_signature, 875 "tile_shape": tile_shape if tile_shape is None else list(tile_shape), 876 "halo": halo if halo is None else list(halo), 877 "model_type": predictor.model_type, 878 "model_name": predictor.model_name, 879 "micro_sam_version": __version__, 880 "model_hash": getattr(predictor, "_hash", None), 881 } 882 return signature 883 884 885# Note: the input size and orginal size are different if embeddings are tiled or not. 886# That's why we do not include them in the main signature that is being checked 887# (_get_embedding_signature), but just add it for serialization here. 888def _write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size, original_size): 889 signature = _get_embedding_signature(input_, predictor, tile_shape, halo) 890 signature.update({"input_size": input_size, "original_size": original_size}) 891 for key, val in signature.items(): 892 f.attrs[key] = val 893 894 895def _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo): 896 # We may have an empty zarr file that was already created to save the embeddings in. 897 # In this case the embeddings will be computed and we don't need to perform any checks. 898 if "input_size" not in f.attrs: 899 return 900 901 signature = _get_embedding_signature(input_, predictor, tile_shape, halo) 902 for key, val in signature.items(): 903 # Check whether the key is missing from the attrs or if the value is not matching. 904 if key not in f.attrs or f.attrs[key] != val: 905 # These keys were recently added, so we don't want to fail yet if they don't 906 # match in order to not invalidate previous embedding files. 907 # Instead we just raise a warning. (For the version we probably also don't want to fail 908 # i the future since it should not invalidate the embeddings). 909 if key in ("micro_sam_version", "model_hash", "model_name"): 910 warnings.warn( 911 f"The signature for {key} in embeddings file {save_path} has a mismatch: " 912 f"{f.attrs.get(key)} != {val}. This key was recently added, so your embeddings are likely correct. " 913 "But please recompute them if model predictions don't look as expected." 914 ) 915 else: 916 raise RuntimeError( 917 f"Embeddings file {save_path} is invalid due to mismatch in {key}: " 918 f"{f.attrs.get(key)} != {val}. Please recompute embeddings in a new file." 919 ) 920 921 922# Helper function for optional external progress bars. 923def handle_pbar(verbose, pbar_init, pbar_update): 924 """@private""" 925 926 # Noop to provide dummy functions. 927 def noop(*args): 928 pass 929 930 if verbose and pbar_init is None: # we are verbose and don't have an external progress bar. 
931 assert pbar_update is None # avoid inconsistent state of callbacks 932 933 # Create our own progress bar and callbacks 934 pbar = tqdm() 935 936 def pbar_init(total, description): 937 pbar.total = total 938 pbar.set_description(description) 939 940 def pbar_update(update): 941 pbar.update(update) 942 943 def pbar_close(): 944 pbar.close() 945 946 elif verbose and pbar_init is not None: # external pbar -> we don't have to do anything 947 assert pbar_update is not None 948 pbar = None 949 pbar_close = noop 950 951 else: # we are not verbose, do nothing 952 pbar = None 953 pbar_init, pbar_update, pbar_close = noop, noop, noop 954 955 return pbar, pbar_init, pbar_update, pbar_close 956 957 958def precompute_image_embeddings( 959 predictor: SamPredictor, 960 input_: np.ndarray, 961 save_path: Optional[Union[str, os.PathLike]] = None, 962 lazy_loading: bool = False, 963 ndim: Optional[int] = None, 964 tile_shape: Optional[Tuple[int, int]] = None, 965 halo: Optional[Tuple[int, int]] = None, 966 verbose: bool = True, 967 batch_size: int = 1, 968 pbar_init: Optional[callable] = None, 969 pbar_update: Optional[callable] = None, 970) -> ImageEmbeddings: 971 """Compute the image embeddings (output of the encoder) for the input. 972 973 If 'save_path' is given the embeddings will be loaded/saved in a zarr container. 974 975 Args: 976 predictor: The Segment Anything predictor. 977 input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries. 978 save_path: Path to save the embeddings in a zarr container. 979 lazy_loading: Whether to load all embeddings into memory or return an 980 object to load them on demand when required. This only has an effect if 'save_path' is given 981 and if the input is 3 dimensional. 982 ndim: The dimensionality of the data. If not given will be deduced from the input data. 983 tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling. 984 halo: Overlap of the tiles for tiled prediction. 985 verbose: Whether to be verbose in the computation. 986 batch_size: The batch size for precomputing image embeddings over tiles (or planes). 987 pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description. 988 Can be used together with pbar_update to handle napari progress bar in other thread. 989 To enables using this function within a threadworker. 990 pbar_update: Callback to update an external progress bar. 991 992 Returns: 993 The image embeddings. 994 """ 995 ndim = input_.ndim if ndim is None else ndim 996 997 # Handle the embedding save_path. 998 # We don't have a save path, open in memory zarr file to hold tiled embeddings. 999 if save_path is None: 1000 f = zarr.group() 1001 1002 # We have a save path and it already exists. Embeddings will be loaded from it, 1003 # check that the saved embeddings in there match the parameters of the function call. 1004 elif os.path.exists(save_path): 1005 f = zarr.open(save_path, "a") 1006 _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo) 1007 1008 # We have a save path and it does not exist yet. Create the zarr file to which the 1009 # embeddings will then be saved. 
1010 else: 1011 f = zarr.open(save_path, "a") 1012 1013 _, pbar_init, pbar_update, pbar_close = handle_pbar(verbose, pbar_init, pbar_update) 1014 1015 if ndim == 2 and tile_shape is None: 1016 embeddings = _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update) 1017 elif ndim == 2 and tile_shape is not None: 1018 embeddings = _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size) 1019 elif ndim == 3 and tile_shape is None: 1020 embeddings = _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update, batch_size) 1021 elif ndim == 3 and tile_shape is not None: 1022 embeddings = _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size) 1023 else: 1024 raise ValueError(f"Invalid dimesionality {input_.ndim}, expect 2 or 3 dim data.") 1025 1026 pbar_close() 1027 return embeddings 1028 1029 1030def set_precomputed( 1031 predictor: SamPredictor, image_embeddings: ImageEmbeddings, i: Optional[int] = None, tile_id: Optional[int] = None, 1032) -> SamPredictor: 1033 """Set the precomputed image embeddings for a predictor. 1034 1035 Args: 1036 predictor: The Segment Anything predictor. 1037 image_embeddings: The precomputed image embeddings computed by `precompute_image_embeddings`. 1038 i: Index for the image data. Required if `image` has three spatial dimensions 1039 or a time dimension and two spatial dimensions. 1040 tile_id: Index for the tile. This is required if the embeddings are tiled. 1041 1042 Returns: 1043 The predictor with set features. 1044 """ 1045 if tile_id is not None: 1046 tile_features = image_embeddings["features"][tile_id] 1047 tile_image_embeddings = { 1048 "features": tile_features, 1049 "input_size": tile_features.attrs["input_size"], 1050 "original_size": tile_features.attrs["original_size"] 1051 } 1052 return set_precomputed(predictor, tile_image_embeddings, i=i) 1053 1054 device = predictor.device 1055 features = image_embeddings["features"] 1056 assert features.ndim in (4, 5), f"{features.ndim}" 1057 if features.ndim == 5 and i is None: 1058 raise ValueError("The data is 3D so an index i is needed.") 1059 elif features.ndim == 4 and i is not None: 1060 raise ValueError("The data is 2D so an index is not needed.") 1061 1062 if i is None: 1063 predictor.features = features.to(device) if torch.is_tensor(features) else \ 1064 torch.from_numpy(features[:]).to(device) 1065 else: 1066 predictor.features = features[i].to(device) if torch.is_tensor(features) else \ 1067 torch.from_numpy(features[i]).to(device) 1068 1069 predictor.original_size = image_embeddings["original_size"] 1070 predictor.input_size = image_embeddings["input_size"] 1071 predictor.is_image_set = True 1072 1073 return predictor 1074 1075 1076# 1077# Misc functionality 1078# 1079 1080 1081def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float: 1082 """Compute the intersection over union of two masks. 1083 1084 Args: 1085 mask1: The first mask. 1086 mask2: The second mask. 1087 1088 Returns: 1089 The intersection over union of the two masks. 1090 """ 1091 overlap = np.logical_and(mask1 == 1, mask2 == 1).sum() 1092 union = np.logical_or(mask1 == 1, mask2 == 1).sum() 1093 eps = 1e-7 1094 iou = float(overlap) / (float(union) + eps) 1095 return iou 1096 1097 1098def get_centers_and_bounding_boxes( 1099 segmentation: np.ndarray, mode: str = "v" 1100) -> Tuple[Dict[int, np.ndarray], Dict[int, tuple]]: 1101 """Returns the center coordinates of the foreground instances in the ground-truth. 
1102 1103 Args: 1104 segmentation: The segmentation. 1105 mode: Determines the functionality used for computing the centers. 1106 If 'v', the object's eccentricity centers computed by vigra are used. 1107 If 'p' the object's centroids computed by skimage are used. 1108 1109 Returns: 1110 A dictionary that maps object ids to the corresponding centroid. 1111 A dictionary that maps object_ids to the corresponding bounding box. 1112 """ 1113 assert mode in ["p", "v"], "Choose either 'p' for regionprops or 'v' for vigra" 1114 1115 properties = regionprops(segmentation) 1116 1117 if mode == "p": 1118 center_coordinates = {prop.label: prop.centroid for prop in properties} 1119 elif mode == "v": 1120 center_coordinates = vigra.filters.eccentricityCenters(segmentation.astype('float32')) 1121 center_coordinates = {i: coord for i, coord in enumerate(center_coordinates) if i > 0} 1122 1123 bbox_coordinates = {prop.label: prop.bbox for prop in properties} 1124 1125 assert len(bbox_coordinates) == len(center_coordinates), f"{len(bbox_coordinates)}, {len(center_coordinates)}" 1126 return center_coordinates, bbox_coordinates 1127 1128 1129def load_image_data(path: str, key: Optional[str] = None, lazy_loading: bool = False) -> np.ndarray: 1130 """Helper function to load image data from file. 1131 1132 Args: 1133 path: The filepath to the image data. 1134 key: The internal filepath for complex data formats like hdf5. 1135 lazy_loading: Whether to lazyly load data. Only supported for n5 and zarr data. 1136 1137 Returns: 1138 The image data. 1139 """ 1140 if key is None: 1141 image_data = imageio.imread(path) 1142 else: 1143 with open_file(path, mode="r") as f: 1144 image_data = f[key] 1145 if not lazy_loading: 1146 image_data = image_data[:] 1147 1148 return image_data 1149 1150 1151def segmentation_to_one_hot(segmentation: np.ndarray, segmentation_ids: Optional[np.ndarray] = None) -> torch.Tensor: 1152 """Convert the segmentation to one-hot encoded masks. 1153 1154 Args: 1155 segmentation: The segmentation. 1156 segmentation_ids: Optional subset of ids that will be used to subsample the masks. 1157 1158 Returns: 1159 The one-hot encoded masks. 1160 """ 1161 masks = segmentation.copy() 1162 if segmentation_ids is None: 1163 n_ids = int(segmentation.max()) 1164 1165 else: 1166 msg = "No foreground objects were found." 1167 if len(segmentation_ids) == 0: # The list should not be completely empty. 1168 raise RuntimeError(msg) 1169 1170 if 0 in segmentation_ids: # The list should not have 'zero' as a value. 1171 raise RuntimeError(msg) 1172 1173 # the segmentation ids have to be sorted 1174 segmentation_ids = np.sort(segmentation_ids) 1175 1176 # set the non selected objects to zero and relabel sequentially 1177 masks[~np.isin(masks, segmentation_ids)] = 0 1178 masks = relabel_sequential(masks)[0] 1179 n_ids = len(segmentation_ids) 1180 1181 masks = torch.from_numpy(masks) 1182 1183 one_hot_shape = (n_ids + 1,) + masks.shape 1184 masks = masks.unsqueeze(0) # add dimension to scatter 1185 masks = torch.zeros(one_hot_shape).scatter_(0, masks, 1)[1:] 1186 1187 # add the extra singleton dimenion to get shape NUM_OBJECTS x 1 x H x W 1188 masks = masks.unsqueeze(1) 1189 return masks 1190 1191 1192def get_block_shape(shape: Tuple[int]) -> Tuple[int]: 1193 """Get a suitable block shape for chunking a given shape. 1194 1195 The primary use for this is determining chunk sizes for 1196 zarr arrays or block shapes for parallelization. 1197 1198 Args: 1199 shape: The image or volume shape. 
1200 1201 Returns: 1202 The block shape. 1203 """ 1204 ndim = len(shape) 1205 if ndim == 2: 1206 block_shape = tuple(min(bs, sh) for bs, sh in zip((1024, 1024), shape)) 1207 elif ndim == 3: 1208 block_shape = tuple(min(bs, sh) for bs, sh in zip((32, 256, 256), shape)) 1209 else: 1210 raise ValueError(f"Only 2 or 3 dimensional shapes are supported, got {ndim}D.") 1211 1212 return block_shape 1213 1214 1215def micro_sam_info(): 1216 """Display μSAM information using a rich console.""" 1217 import psutil 1218 import platform 1219 from rich.panel import Panel 1220 from rich.table import Table 1221 from rich.console import Console 1222 1223 import torch 1224 import micro_sam 1225 1226 # Open up a new console. 1227 console = Console() 1228 1229 # The header for information CLI. 1230 console.print("[bold #0072B2]μSAM Information Booth[/bold #0072B2]", justify="center") 1231 console.print("-" * console.width) 1232 1233 # μSAM version panel. 1234 console.print( 1235 Panel(f"[bold #F0E442]Version:[/bold #F0E442] {micro_sam.__version__}", title="μSAM Version", expand=True) 1236 ) 1237 1238 # The documentation link panel. 1239 console.print( 1240 Panel( 1241 "[bold #CC79A7]Tools documented at:[/bold #CC79A7]\n" 1242 "https://computational-cell-analytics.github.io/micro-sam", title="Documentation" 1243 ) 1244 ) 1245 1246 # The publication panel. 1247 console.print( 1248 Panel( 1249 "[bold #E69F00]Published in Nature Methods:[/bold #E69F00]\n" 1250 "https://www.nature.com/articles/s41592-024-02580-4", title="Publication" 1251 ) 1252 ) 1253 1254 # The cache directory panel. 1255 console.print( 1256 Panel(f"[bold #009E73]Cache Directory:[/bold #009E73]\n{get_cache_directory()}", title="Cache Directory") 1257 ) 1258 1259 # The available models panel. 1260 available_models = list(get_model_names()) 1261 # We filter out the decoder models. 1262 available_models = [m for m in available_models if not m.endswith("_decoder")] 1263 model_list = "\n".join(available_models) 1264 console.print( 1265 Panel(f"[bold #D55E00]Available Models:[/bold #D55E00]\n{model_list}", title="List of Supported Models") 1266 ) 1267 1268 # The system information table. 1269 total_memory = psutil.virtual_memory().total / (1024 ** 3) 1270 table = Table(title="System Information", show_header=True, header_style="bold #0072B2", expand=True) 1271 table.add_column("Property") 1272 table.add_column("Value", style="bold #56B4E9") 1273 table.add_row("System", platform.system()) 1274 table.add_row("Node Name", platform.node()) 1275 table.add_row("Release", platform.release()) 1276 table.add_row("Version", platform.version()) 1277 table.add_row("Machine", platform.machine()) 1278 table.add_row("Processor", platform.processor()) 1279 table.add_row("Platform", platform.platform()) 1280 table.add_row("Total RAM (GB)", f"{total_memory:.2f}") 1281 console.print(table) 1282 1283 # The device information and check for available GPU acceleration. 1284 default_device = _get_default_device() 1285 1286 if default_device == "cuda": 1287 device_index = torch.cuda.current_device() 1288 device_name = torch.cuda.get_device_name(device_index) 1289 console.print(Panel(f"[bold #000000]CUDA Device:[/bold #000000] {device_name}", title="GPU Information")) 1290 elif default_device == "mps": 1291 console.print(Panel("[bold #000000]MPS Device is available[/bold #000000]", title="GPU Information")) 1292 else: 1293 console.print( 1294 Panel( 1295 "[bold #000000]No GPU acceleration device detected. 
Running on CPU.[/bold #000000]", 1296 title="Device Information" 1297 ) 1298 )
60def get_cache_directory() -> None: 61 """Get micro-sam cache directory location. 62 63 Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory. 64 """ 65 default_cache_directory = os.path.expanduser(pooch.os_cache("micro_sam")) 66 cache_directory = Path(os.environ.get("MICROSAM_CACHEDIR", default_cache_directory)) 67 return cache_directory
Get micro-sam cache directory location.
Users can set the MICROSAM_CACHEDIR environment variable for a custom cache directory.
75def microsam_cachedir() -> None: 76 """Return the micro-sam cache directory. 77 78 Returns the top level cache directory for micro-sam models and sample data. 79 80 Every time this function is called, we check for any user updates made to 81 the MICROSAM_CACHEDIR os environment variable since the last time. 82 """ 83 cache_directory = os.environ.get("MICROSAM_CACHEDIR") or pooch.os_cache("micro_sam") 84 return cache_directory
Return the micro-sam cache directory.
Returns the top level cache directory for micro-sam models and sample data.
Every time this function is called, we check for any user updates made to the MICROSAM_CACHEDIR os environment variable since the last time.
def models():
    """Return the segmentation models registry.

    We recreate the model registry every time this function is called,
    so any user changes to the default micro-sam cache directory location
    are respected.
    """

    # We use xxhash to compute the hash of the models, see
    # https://github.com/computational-cell-analytics/micro-sam/issues/283
    # (It is now a dependency, so we don't provide the sha256 fallback anymore.)
    # To generate the xxh128 hash:
    #     xxh128sum filename
    encoder_registry = {
        # The default segment anything models:
        "vit_l": "xxh128:a82beb3c660661e3dd38d999cc860e9a",
        "vit_h": "xxh128:97698fac30bd929c2e6d8d8cc15933c2",
        "vit_b": "xxh128:6923c33df3637b6a922d7682bfc9a86b",
        # The model with vit tiny backend from https://github.com/ChaoningZhang/MobileSAM.
        "vit_t": "xxh128:8eadbc88aeb9d8c7e0b4b60c3db48bd0",
        # The current version of our models in the modelzoo.
        # LM generalist models:
        "vit_l_lm": "xxh128:fc32ea6f7fcc7eb02737d1304f81f5f2",
        "vit_b_lm": "xxh128:8fd5806be3c3ba213e19a709d6d1495f",
        "vit_t_lm": "xxh128:72ec5074774761a6e5c05a08942f981e",
        # EM models:
        "vit_l_em_organelles": "xxh128:096c9695966803ca6fde24f4c1e3c3fb",
        "vit_b_em_organelles": "xxh128:f6f6593aeecd0e15a07bdac86360b6cc",
        "vit_t_em_organelles": "xxh128:253474720c497cce605e57c9b1d18fd9",
        # Histopathology models:
        "vit_b_histopathology": "xxh128:ffd1a2cd84570458b257bd95fdd8f974",
        "vit_l_histopathology": "xxh128:b591833c89754271023e901281dee3f2",
        "vit_h_histopathology": "xxh128:bd1856dafc156a43fb3aa705f1a6e92e",
        # Medical Imaging models:
        "vit_b_medical_imaging": "xxh128:5be672f1458263a9edc9fd40d7f56ac1",
    }
    # Additional decoders for instance segmentation.
    decoder_registry = {
        # LM generalist models:
        "vit_l_lm_decoder": "xxh128:779b5a50ecc6d46d495753fba8717f2f",
        "vit_b_lm_decoder": "xxh128:9f580a96984b3085389ced5d9a4ae75d",
        "vit_t_lm_decoder": "xxh128:3e914a5f397b0312cdd36813031f8823",
        # EM models:
        "vit_l_em_organelles_decoder": "xxh128:d60fd96bd6060856f6430f29e42568fb",
        "vit_b_em_organelles_decoder": "xxh128:b2d4dcffb99f76d83497d39ee500088f",
        "vit_t_em_organelles_decoder": "xxh128:8f897c7bb93174a4d1638827c4dd6f44",
        # Histopathology models:
        "vit_b_histopathology_decoder": "xxh128:6a66194dcb6e36199cbee2214ecf7213",
        "vit_l_histopathology_decoder": "xxh128:46aab7765d4400e039772d5a50b55c04",
        "vit_h_histopathology_decoder": "xxh128:3ed9f87e46ad5e16935bd8d722c8dc47",
    }
    registry = {**encoder_registry, **decoder_registry}

    encoder_urls = {
        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
        "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download",
        "vit_l_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.1/files/vit_l.pt",
        "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.1/files/vit_b.pt",
        "vit_t_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t.pt",
        "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l.pt",  # noqa
        "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b.pt",
        "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt",  # noqa
        "vit_b_histopathology": "https://owncloud.gwdg.de/index.php/s/sBB4H8CTmIoBZsQ/download",
        "vit_l_histopathology": "https://owncloud.gwdg.de/index.php/s/IZgnn1cpBq2PHod/download",
        "vit_h_histopathology": "https://owncloud.gwdg.de/index.php/s/L7AcvVz7DoWJ2RZ/download",
        "vit_b_medical_imaging": "https://owncloud.gwdg.de/index.php/s/AB69HGhj8wuozXQ/download",
    }

    decoder_urls = {
        "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1.1/files/vit_l_decoder.pt",  # noqa
        "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1.1/files/vit_b_decoder.pt",  # noqa
        "vit_t_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1.1/files/vit_t_decoder.pt",  # noqa
        "vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l_decoder.pt",  # noqa
        "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b_decoder.pt",  # noqa
        "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt",  # noqa
        "vit_b_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/KO9AWqynI7SFOBj/download",
        "vit_l_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/oIs6VSmkOp7XrKF/download",
        "vit_h_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/1qAKxy5H0jgwZvM/download",
    }
    urls = {**encoder_urls, **decoder_urls}

    models = pooch.create(
        path=os.path.join(microsam_cachedir(), "models"),
        base_url="",
        registry=registry,
        urls=urls,
    )
    return models
Return the segmentation models registry.
We recreate the model registry every time this function is called, so any user changes to the default micro-sam cache directory location are respected.
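For illustration, a minimal sketch of using the registry to download a checkpoint (the model name is just an example; fetch is the standard pooch method and reuses the cached file after the first download):

from micro_sam.util import models

registry = models()
checkpoint_path = registry.fetch("vit_b_lm", progressbar=True)
print(checkpoint_path)  # path to the cached checkpoint inside the micro-sam cache directory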
def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]:
    """Get the torch device.

    If no device is passed the default device for your system is used.
    Else it will be checked if the device you have passed is supported.

    Args:
        device: The input device.

    Returns:
        The device.
    """
    if device is None or device == "auto":
        device = _get_default_device()
    else:
        device_type = device if isinstance(device, str) else device.type
        if device_type.lower() == "cuda":
            if not torch.cuda.is_available():
                raise RuntimeError("PyTorch CUDA backend is not available.")
        elif device_type.lower() == "mps":
            if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
                raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.")
        elif device_type.lower() == "cpu":
            pass  # cpu is always available
        else:
            raise RuntimeError(
                f"Unsupported device: {device}\n"
                "Please choose from 'cpu', 'cuda', or 'mps'."
            )

    return device
Get the torch device.
If no device is passed, the default device for your system is used. Otherwise, it is checked whether the passed device is supported.
Arguments:
- device: The input device.
Returns:
The device.
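For illustration, a minimal usage sketch:

from micro_sam.util import get_device

device = get_device()       # selects cuda, mps or cpu, depending on what is available
device = get_device("cpu")  # explicitly request the cpu
# get_device("cuda") raises a RuntimeError if the CUDA backend is not available.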
def get_sam_model(
    model_type: str = _DEFAULT_MODEL,
    device: Optional[Union[str, torch.device]] = None,
    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
    return_sam: bool = False,
    return_state: bool = False,
    peft_kwargs: Optional[Dict] = None,
    flexible_load_checkpoint: bool = False,
    progress_bar_factory: Optional[Callable] = None,
    **model_kwargs,
) -> SamPredictor:
    r"""Get the Segment Anything Predictor.

    This function will download the required model or load it from the cached weight file.
    The location of the cache can be changed by setting the environment variable MICROSAM_CACHEDIR.
    The name of the requested model can be set via `model_type`.
    See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models
    for an overview of the available models.

    Alternatively this function can also load a model from weights stored in a local filepath.
    The corresponding file path is given via `checkpoint_path`. In this case `model_type`
    must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for
    a SAM model with vit_b encoder.

    By default the models are downloaded to a folder named 'micro_sam/models'
    inside your default cache directory, e.g.:
    * Mac: ~/Library/Caches/<AppName>
    * Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined.
    * Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache
    See the pooch.os_cache() documentation for more details:
    https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html

    Args:
        model_type: The Segment Anything model to use. Will use the `vit_b_lm` model by default.
            To get a list of all available model names you can call `get_model_names`.
        device: The device for the model. If none is given will use GPU if available.
        checkpoint_path: The path to a file with weights that should be used instead of using the
            weights corresponding to `model_type`. If given, `model_type` must match the architecture
            corresponding to the weight file. e.g. if you use weights for SAM with `vit_b` encoder
            then `model_type` must be given as "vit_b".
        return_sam: Return the sam model object as well as the predictor.
        return_state: Return the unpickled checkpoint state.
        peft_kwargs: Keyword arguments for the PEFT wrapper class.
        flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
        model_kwargs: Additional parameters necessary to initialize the Segment Anything model.
        progress_bar_factory: A function to create a progress bar for the model download.

    Returns:
        The segment anything predictor.
    """
    device = get_device(device)

    # We support passing a local filepath to a checkpoint.
    # In this case we do not download any weights but just use the local weight file,
    # as it is, without copying it over anywhere or checking its hashes.

    # checkpoint_path has not been passed, we download a known model and derive the correct
    # URL from the model_type. If the model_type is invalid pooch will raise an error.
    _provided_checkpoint_path = checkpoint_path is not None
    if checkpoint_path is None:
        model_registry = models()

        progress_bar = True
        # Check if we have to download the model.
        # If we do and have a progress bar factory, then we over-write the progress bar.
        if not os.path.exists(os.path.join(get_cache_directory(), model_type)) and progress_bar_factory is not None:
            progress_bar = progress_bar_factory(model_type)

        checkpoint_path = model_registry.fetch(model_type, progressbar=progress_bar)
        if not isinstance(progress_bar, bool):  # Close the progress bar when the task finishes.
            progress_bar.close()

        model_hash = model_registry.registry[model_type]

        # If we have a custom model then we may also have a decoder checkpoint.
        # Download it here, so that we can add it to the state.
        decoder_name = f"{model_type}_decoder"
        decoder_path = model_registry.fetch(
            decoder_name, progressbar=True
        ) if decoder_name in model_registry.registry else None

    # checkpoint_path has been passed, we use it instead of downloading a model.
    else:
        # Check if the file exists and raise an error otherwise.
        # We can't check any hashes here, and we don't check if the file is actually a valid weight file.
        # (If it isn't the model creation will fail below.)
        if not os.path.exists(checkpoint_path):
            raise ValueError(f"Checkpoint at {checkpoint_path} could not be found.")
        model_hash = _compute_hash(checkpoint_path)
        decoder_path = None

    # Our fine-tuned model types have a suffix "_...". This suffix needs to be stripped
    # before calling sam_model_registry.
    abbreviated_model_type = model_type[:5]
    if abbreviated_model_type not in _MODEL_TYPES:
        raise ValueError(f"Invalid model_type: {abbreviated_model_type}. Expect one of {_MODEL_TYPES}")
    if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT:
        raise RuntimeError(
            "'mobile_sam' is required for the vit-tiny. "
            "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'"
        )

    state, model_state = _load_checkpoint(checkpoint_path)

    if _provided_checkpoint_path:
        # To get the model weights, we prioritize having the correct 'checkpoint_path' over 'model_type'.
        # This is done to avoid parameter mismatch issues caused by an incompatible combination of
        # model type and weights.
        from micro_sam.models.build_sam import _validate_model_type
        _provided_model_type = _validate_model_type(model_state)

        # Verify whether the 'abbreviated_model_type' matches the '_provided_model_type'.
        # Otherwise replace 'abbreviated_model_type' with the latter.
        if abbreviated_model_type != _provided_model_type:
            # Printing the message below to avoid any filtering of warnings on user's end.
            print(
                f"CRITICAL WARNING: The chosen 'model_type' is '{abbreviated_model_type}', "
                f"however the model checkpoint provided corresponds to '{_provided_model_type}', which does not match. "
                f"We internally switch the model type to the expected value, i.e. '{_provided_model_type}'. "
                "However, please avoid mismatching combinations of 'model_type' and 'checkpoint_path' in the future."
            )

            # Replace the extracted 'abbreviated_model_type' with the one matching the model weights.
            abbreviated_model_type = _provided_model_type

    # Whether to update parameters necessary to initialize the model.
    if model_kwargs:  # Checks whether model_kwargs have been provided or not.
        if abbreviated_model_type == "vit_t":
            raise ValueError("'micro-sam' does not support changing the model parameters for 'mobile-sam'.")
        sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs)
    else:
        sam = sam_model_registry[abbreviated_model_type]()

    # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
    # Overwrites the SAM model by freezing the backbone and allowing PEFT.
    if peft_kwargs and isinstance(peft_kwargs, dict):
        # NOTE: We bump out the 'quantize' parameter, if found, as we do not quantize in inference.
        peft_kwargs.pop("quantize", None)

        if abbreviated_model_type == "vit_t":
            raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")

        sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam

    # In case the model checkpoint has some issues when it is initialized with different parameters than the default.
    if flexible_load_checkpoint:
        sam = _handle_checkpoint_loading(sam, model_state)
    else:
        sam.load_state_dict(model_state)
    sam.to(device=device)

    predictor = SamPredictor(sam)
    predictor.model_type = abbreviated_model_type
    predictor._hash = model_hash
    predictor.model_name = model_type
    predictor.checkpoint_path = checkpoint_path

    # Add the decoder to the state if we have one and if the state is returned.
    if decoder_path is not None and return_state:
        state["decoder_state"] = torch.load(decoder_path, map_location=device, weights_only=False)

    if return_sam and return_state:
        return predictor, sam, state
    if return_sam:
        return predictor, sam
    if return_state:
        return predictor, state
    return predictor
Get the Segment Anything Predictor.
This function will download the required model or load it from the cached weight file. The location of the cache can be changed by setting the environment variable MICROSAM_CACHEDIR. The name of the requested model can be set via model_type. See https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#finetuned-models for an overview of the available models.
Alternatively this function can also load a model from weights stored in a local filepath. The corresponding file path is given via checkpoint_path. In this case model_type must be given as the matching encoder architecture, e.g. "vit_b" if the weights are for a SAM model with vit_b encoder.
By default the models are downloaded to a folder named 'micro_sam/models' inside your default cache directory, e.g.:
- Mac: ~/Library/Caches/<AppName>
- Unix: ~/.cache/<AppName> or the value of the XDG_CACHE_HOME environment variable, if defined.
- Windows: C:\Users\<user>\AppData\Local\<AppAuthor>\<AppName>\Cache
See the pooch.os_cache() documentation for more details: https://www.fatiando.org/pooch/latest/api/generated/pooch.os_cache.html
Arguments:
- model_type: The Segment Anything model to use. Will use the vit_b_lm model by default. To get a list of all available model names you can call get_model_names.
- device: The device for the model. If none is given will use GPU if available.
- checkpoint_path: The path to a file with weights that should be used instead of using the weights corresponding to model_type. If given, model_type must match the architecture corresponding to the weight file, e.g. if you use weights for SAM with vit_b encoder then model_type must be given as "vit_b".
- return_sam: Return the sam model object as well as the predictor.
- return_state: Return the unpickled checkpoint state.
- peft_kwargs: Keyword arguments for the PEFT wrapper class.
- flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
- model_kwargs: Additional parameters necessary to initialize the Segment Anything model.
- progress_bar_factory: A function to create a progress bar for the model download.
Returns:
The segment anything predictor.
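For illustration, a minimal usage sketch (the local checkpoint path is a placeholder):

from micro_sam.util import get_sam_model

# Download (or load from the cache) the default model and wrap it in a predictor.
predictor = get_sam_model(model_type="vit_b_lm")

# Load custom weights instead; model_type must match the encoder architecture of the weights.
predictor = get_sam_model(model_type="vit_b", checkpoint_path="/path/to/finetuned_vit_b.pt")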
def export_custom_sam_model(
    checkpoint_path: Union[str, os.PathLike],
    model_type: str,
    save_path: Union[str, os.PathLike],
    with_segmentation_decoder: bool = False,
) -> None:
    """Export a finetuned Segment Anything Model to the standard model format.

    The exported model can be used by the interactive annotation tools in `micro_sam.annotator`.

    Args:
        checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
        model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
        save_path: Where to save the exported model.
        with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well.
            If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'.
    """
    _, state = get_sam_model(
        model_type=model_type, checkpoint_path=checkpoint_path, return_state=True, device="cpu",
    )
    model_state = state["model_state"]
    prefix = "sam."
    model_state = OrderedDict(
        [(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()]
    )

    # Store the 'decoder_state' as well, if desired.
    if with_segmentation_decoder:
        if "decoder_state" not in state:
            raise RuntimeError(f"'decoder_state' is not found in the model at '{checkpoint_path}'.")
        decoder_state = state["decoder_state"]
        save_state = {"model_state": model_state, "decoder_state": decoder_state}
    else:
        save_state = model_state

    torch.save(save_state, save_path)
Export a finetuned Segment Anything Model to the standard model format.
The exported model can be used by the interactive annotation tools in micro_sam.annotator.
Arguments:
- checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
- model_type: The Segment Anything Model type corresponding to the checkpoint (vit_h, vit_b, vit_l or vit_t).
- save_path: Where to save the exported model.
- with_segmentation_decoder: Whether to store the decoder state in the model checkpoint as well. If set to 'True', the model checkpoint will not be compatible with other tools besides 'micro-sam'.
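For illustration, a minimal usage sketch (all paths are placeholders):

from micro_sam.util import export_custom_sam_model

export_custom_sam_model(
    checkpoint_path="/path/to/checkpoints/vit_b_custom/best.pt",  # checkpoint written during finetuning
    model_type="vit_b",
    save_path="/path/to/exported_vit_b_custom.pth",
    with_segmentation_decoder=True,  # also export the instance segmentation decoder, if it was trained
)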
def export_custom_qlora_model(
    checkpoint_path: Optional[Union[str, os.PathLike]],
    finetuned_path: Union[str, os.PathLike],
    model_type: str,
    save_path: Union[str, os.PathLike],
) -> None:
    """Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format.

    The exported model can be used with the LoRA backbone by passing the relevant `peft_kwargs` to `get_sam_model`.

    Args:
        checkpoint_path: The path to the base foundation model from which the new model has been finetuned.
        finetuned_path: The path to the new finetuned model, using QLoRA.
        model_type: The Segment Anything Model type corresponding to the checkpoint.
        save_path: Where to save the exported model.
    """
    # Step 1: Get the base SAM model: used to start finetuning from.
    _, sam = get_sam_model(
        model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True,
    )

    # Step 2: Load the QLoRA-style finetuned model.
    ft_state, ft_model_state = _load_checkpoint(finetuned_path)

    # Step 3: Get LoRA weights from QLoRA and retain all original parameters from the base SAM model.
    updated_model_state = {}

    # - At first, we get all LoRA layers from the QLoRA-style finetuned model checkpoint.
    for k, v in ft_model_state.items():
        if k.find("w_b_linear") != -1 or k.find("w_a_linear") != -1:
            updated_model_state[k] = v

    # - Next, we get all the remaining parameters from the base SAM model.
    for k, v in sam.state_dict().items():
        if k.find("attn.qkv.") != -1:
            k = k.replace("qkv", "qkv.qkv_proj")
            updated_model_state[k] = v
        else:
            updated_model_state[k] = v

    # - Finally, we replace the old model state with the new one (to retain other relevant stuff).
    ft_state['model_state'] = updated_model_state

    # Step 4: Store the new "state" to "save_path".
    torch.save(ft_state, save_path)
Export a finetuned Segment Anything Model, in QLoRA style, to LoRA-style checkpoint format.
The exported model can be used with the LoRA backbone by passing the relevant peft_kwargs to get_sam_model.
Arguments:
- checkpoint_path: The path to the base foundation model from which the new model has been finetuned.
- finetuned_path: The path to the new finetuned model, using QLoRA.
- model_type: The Segment Anything Model type corresponding to the checkpoint.
- save_path: Where to save the exported model.
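For illustration, a minimal usage sketch (all paths are placeholders; passing checkpoint_path=None starts from the default weights for the given model_type):

from micro_sam.util import export_custom_qlora_model

export_custom_qlora_model(
    checkpoint_path=None,                          # base model to take the non-LoRA parameters from
    finetuned_path="/path/to/qlora_finetuned.pt",  # QLoRA-style finetuned checkpoint
    model_type="vit_b",
    save_path="/path/to/lora_style_model.pt",
)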
def precompute_image_embeddings(
    predictor: SamPredictor,
    input_: np.ndarray,
    save_path: Optional[Union[str, os.PathLike]] = None,
    lazy_loading: bool = False,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
    batch_size: int = 1,
    pbar_init: Optional[callable] = None,
    pbar_update: Optional[callable] = None,
) -> ImageEmbeddings:
    """Compute the image embeddings (output of the encoder) for the input.

    If 'save_path' is given the embeddings will be loaded/saved in a zarr container.

    Args:
        predictor: The Segment Anything predictor.
        input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
        save_path: Path to save the embeddings in a zarr container.
        lazy_loading: Whether to load all embeddings into memory or return an
            object to load them on demand when required. This only has an effect if 'save_path' is given
            and if the input is 3 dimensional.
        ndim: The dimensionality of the data. If not given will be deduced from the input data.
        tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction.
        verbose: Whether to be verbose in the computation.
        batch_size: The batch size for precomputing image embeddings over tiles (or planes).
        pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description.
            Can be used together with pbar_update to handle a napari progress bar in another thread.
            This enables using this function within a threadworker.
        pbar_update: Callback to update an external progress bar.

    Returns:
        The image embeddings.
    """
    ndim = input_.ndim if ndim is None else ndim

    # Handle the embedding save_path.
    # We don't have a save path, open in memory zarr file to hold tiled embeddings.
    if save_path is None:
        f = zarr.group()

    # We have a save path and it already exists. Embeddings will be loaded from it,
    # check that the saved embeddings in there match the parameters of the function call.
    elif os.path.exists(save_path):
        f = zarr.open(save_path, "a")
        _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo)

    # We have a save path and it does not exist yet. Create the zarr file to which the
    # embeddings will then be saved.
    else:
        f = zarr.open(save_path, "a")

    _, pbar_init, pbar_update, pbar_close = handle_pbar(verbose, pbar_init, pbar_update)

    if ndim == 2 and tile_shape is None:
        embeddings = _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update)
    elif ndim == 2 and tile_shape is not None:
        embeddings = _compute_tiled_2d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size)
    elif ndim == 3 and tile_shape is None:
        embeddings = _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update, batch_size)
    elif ndim == 3 and tile_shape is not None:
        embeddings = _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_update, batch_size)
    else:
        raise ValueError(f"Invalid dimensionality {input_.ndim}, expect 2 or 3 dim data.")

    pbar_close()
    return embeddings
Compute the image embeddings (output of the encoder) for the input.
If 'save_path' is given the embeddings will be loaded/saved in a zarr container.
Arguments:
- predictor: The Segment Anything predictor.
- input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
- save_path: Path to save the embeddings in a zarr container.
- lazy_loading: Whether to load all embeddings into memory or return an object to load them on demand when required. This only has an effect if 'save_path' is given and if the input is 3 dimensional.
- ndim: The dimensionality of the data. If not given will be deduced from the input data.
- tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
- halo: Overlap of the tiles for tiled prediction.
- verbose: Whether to be verbose in the computation.
- batch_size: The batch size for precomputing image embeddings over tiles (or planes).
- pbar_init: Callback to initialize an external progress bar. Must accept the number of steps and a description. Can be used together with pbar_update to handle a napari progress bar in another thread. This enables using this function within a threadworker.
- pbar_update: Callback to update an external progress bar.
Returns:
The image embeddings.
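For illustration, a minimal usage sketch (random data stands in for a real image; the zarr path is a placeholder):

import numpy as np
from micro_sam.util import get_sam_model, precompute_image_embeddings

predictor = get_sam_model(model_type="vit_b_lm")
image = np.random.randint(0, 255, (2048, 2048), dtype="uint8")

# Compute tiled embeddings for the large 2d image and cache them in a zarr container.
embeddings = precompute_image_embeddings(
    predictor, image, save_path="embeddings.zarr", ndim=2, tile_shape=(1024, 1024), halo=(256, 256),
)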
def set_precomputed(
    predictor: SamPredictor, image_embeddings: ImageEmbeddings, i: Optional[int] = None, tile_id: Optional[int] = None,
) -> SamPredictor:
    """Set the precomputed image embeddings for a predictor.

    Args:
        predictor: The Segment Anything predictor.
        image_embeddings: The precomputed image embeddings computed by `precompute_image_embeddings`.
        i: Index for the image data. Required if `image` has three spatial dimensions
            or a time dimension and two spatial dimensions.
        tile_id: Index for the tile. This is required if the embeddings are tiled.

    Returns:
        The predictor with set features.
    """
    if tile_id is not None:
        tile_features = image_embeddings["features"][tile_id]
        tile_image_embeddings = {
            "features": tile_features,
            "input_size": tile_features.attrs["input_size"],
            "original_size": tile_features.attrs["original_size"]
        }
        return set_precomputed(predictor, tile_image_embeddings, i=i)

    device = predictor.device
    features = image_embeddings["features"]
    assert features.ndim in (4, 5), f"{features.ndim}"
    if features.ndim == 5 and i is None:
        raise ValueError("The data is 3D so an index i is needed.")
    elif features.ndim == 4 and i is not None:
        raise ValueError("The data is 2D so an index is not needed.")

    if i is None:
        predictor.features = features.to(device) if torch.is_tensor(features) else \
            torch.from_numpy(features[:]).to(device)
    else:
        predictor.features = features[i].to(device) if torch.is_tensor(features) else \
            torch.from_numpy(features[i]).to(device)

    predictor.original_size = image_embeddings["original_size"]
    predictor.input_size = image_embeddings["input_size"]
    predictor.is_image_set = True

    return predictor
Set the precomputed image embeddings for a predictor.
Arguments:
- predictor: The Segment Anything predictor.
- image_embeddings: The precomputed image embeddings computed by precompute_image_embeddings.
- i: Index for the image data. Required if the image has three spatial dimensions or a time dimension and two spatial dimensions.
- tile_id: Index for the tile. This is required if the embeddings are tiled.
Returns:
The predictor with set features.
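For illustration, a minimal usage sketch (random data stands in for a real volume):

import numpy as np
from micro_sam.util import get_sam_model, precompute_image_embeddings, set_precomputed

predictor = get_sam_model(model_type="vit_b_lm")
volume = np.random.randint(0, 255, (8, 512, 512), dtype="uint8")
embeddings = precompute_image_embeddings(predictor, volume, ndim=3)

# Set the embeddings of plane 5, so that the predictor can be prompted on this plane.
predictor = set_precomputed(predictor, embeddings, i=5)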
def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
    """Compute the intersection over union of two masks.

    Args:
        mask1: The first mask.
        mask2: The second mask.

    Returns:
        The intersection over union of the two masks.
    """
    overlap = np.logical_and(mask1 == 1, mask2 == 1).sum()
    union = np.logical_or(mask1 == 1, mask2 == 1).sum()
    eps = 1e-7
    iou = float(overlap) / (float(union) + eps)
    return iou
Compute the intersection over union of two masks.
Arguments:
- mask1: The first mask.
- mask2: The second mask.
Returns:
The intersection over union of the two masks.
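For illustration, a small worked example:

import numpy as np
from micro_sam.util import compute_iou

mask1 = np.zeros((4, 4), dtype="uint8")
mask2 = np.zeros((4, 4), dtype="uint8")
mask1[1:3, 1:3] = 1  # 4 foreground pixels
mask2[2:4, 2:4] = 1  # 4 foreground pixels, 1 of them shared with mask1
print(compute_iou(mask1, mask2))  # 1 overlapping pixel / 7 pixels in the union, i.e. approx. 0.14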
def get_centers_and_bounding_boxes(
    segmentation: np.ndarray, mode: str = "v"
) -> Tuple[Dict[int, np.ndarray], Dict[int, tuple]]:
    """Returns the center coordinates of the foreground instances in the ground-truth.

    Args:
        segmentation: The segmentation.
        mode: Determines the functionality used for computing the centers.
            If 'v', the object's eccentricity centers computed by vigra are used.
            If 'p' the object's centroids computed by skimage are used.

    Returns:
        A dictionary that maps object ids to the corresponding centroid.
        A dictionary that maps object_ids to the corresponding bounding box.
    """
    assert mode in ["p", "v"], "Choose either 'p' for regionprops or 'v' for vigra"

    properties = regionprops(segmentation)

    if mode == "p":
        center_coordinates = {prop.label: prop.centroid for prop in properties}
    elif mode == "v":
        center_coordinates = vigra.filters.eccentricityCenters(segmentation.astype('float32'))
        center_coordinates = {i: coord for i, coord in enumerate(center_coordinates) if i > 0}

    bbox_coordinates = {prop.label: prop.bbox for prop in properties}

    assert len(bbox_coordinates) == len(center_coordinates), f"{len(bbox_coordinates)}, {len(center_coordinates)}"
    return center_coordinates, bbox_coordinates
Returns the center coordinates of the foreground instances in the ground-truth.
Arguments:
- segmentation: The segmentation.
- mode: Determines the functionality used for computing the centers. If 'v', the object's eccentricity centers computed by vigra are used. If 'p' the object's centroids computed by skimage are used.
Returns:
A dictionary that maps object ids to the corresponding centroid. A dictionary that maps object_ids to the corresponding bounding box.
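For illustration, a small worked example using the skimage-based mode:

import numpy as np
from micro_sam.util import get_centers_and_bounding_boxes

segmentation = np.zeros((6, 6), dtype="uint32")
segmentation[1:3, 1:3] = 1
segmentation[4:6, 4:6] = 2

centers, boxes = get_centers_and_bounding_boxes(segmentation, mode="p")
print(centers[1])  # (1.5, 1.5) - centroid of object 1
print(boxes[2])    # (4, 4, 6, 6) - bounding box of object 2 as (min_row, min_col, max_row, max_col)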
def load_image_data(path: str, key: Optional[str] = None, lazy_loading: bool = False) -> np.ndarray:
    """Helper function to load image data from file.

    Args:
        path: The filepath to the image data.
        key: The internal filepath for complex data formats like hdf5.
        lazy_loading: Whether to lazily load data. Only supported for n5 and zarr data.

    Returns:
        The image data.
    """
    if key is None:
        image_data = imageio.imread(path)
    else:
        with open_file(path, mode="r") as f:
            image_data = f[key]
            if not lazy_loading:
                image_data = image_data[:]

    return image_data
Helper function to load image data from file.
Arguments:
- path: The filepath to the image data.
- key: The internal filepath for complex data formats like hdf5.
- lazy_loading: Whether to lazily load data. Only supported for n5 and zarr data.
Returns:
The image data.
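For illustration, a minimal usage sketch (the file paths and the internal key are placeholders):

from micro_sam.util import load_image_data

# Load a regular image file into memory.
image = load_image_data("/path/to/image.tif")

# Load a dataset from a zarr container via its internal key, keeping it lazily loaded.
volume = load_image_data("/path/to/data.zarr", key="raw", lazy_loading=True)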
def segmentation_to_one_hot(segmentation: np.ndarray, segmentation_ids: Optional[np.ndarray] = None) -> torch.Tensor:
    """Convert the segmentation to one-hot encoded masks.

    Args:
        segmentation: The segmentation.
        segmentation_ids: Optional subset of ids that will be used to subsample the masks.

    Returns:
        The one-hot encoded masks.
    """
    masks = segmentation.copy()
    if segmentation_ids is None:
        n_ids = int(segmentation.max())
    else:
        msg = "No foreground objects were found."
        if len(segmentation_ids) == 0:  # The list should not be completely empty.
            raise RuntimeError(msg)

        if 0 in segmentation_ids:  # The list should not have 'zero' as a value.
            raise RuntimeError(msg)

        # The segmentation ids have to be sorted.
        segmentation_ids = np.sort(segmentation_ids)

        # Set the non-selected objects to zero and relabel sequentially.
        masks[~np.isin(masks, segmentation_ids)] = 0
        masks = relabel_sequential(masks)[0]
        n_ids = len(segmentation_ids)

    masks = torch.from_numpy(masks)

    one_hot_shape = (n_ids + 1,) + masks.shape
    masks = masks.unsqueeze(0)  # Add dimension to scatter.
    masks = torch.zeros(one_hot_shape).scatter_(0, masks, 1)[1:]

    # Add the extra singleton dimension to get shape NUM_OBJECTS x 1 x H x W.
    masks = masks.unsqueeze(1)
    return masks
Convert the segmentation to one-hot encoded masks.
Arguments:
- segmentation: The segmentation.
- segmentation_ids: Optional subset of ids that will be used to subsample the masks.
Returns:
The one-hot encoded masks.
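For illustration, a small worked example:

import numpy as np
from micro_sam.util import segmentation_to_one_hot

segmentation = np.zeros((8, 8), dtype="int64")
segmentation[1:4, 1:4] = 1
segmentation[5:8, 5:8] = 3

# Only keep object 3; the result has shape NUM_OBJECTS x 1 x H x W.
masks = segmentation_to_one_hot(segmentation, segmentation_ids=np.array([3]))
print(masks.shape)  # torch.Size([1, 1, 8, 8])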
def get_block_shape(shape: Tuple[int]) -> Tuple[int]:
    """Get a suitable block shape for chunking a given shape.

    The primary use for this is determining chunk sizes for
    zarr arrays or block shapes for parallelization.

    Args:
        shape: The image or volume shape.

    Returns:
        The block shape.
    """
    ndim = len(shape)
    if ndim == 2:
        block_shape = tuple(min(bs, sh) for bs, sh in zip((1024, 1024), shape))
    elif ndim == 3:
        block_shape = tuple(min(bs, sh) for bs, sh in zip((32, 256, 256), shape))
    else:
        raise ValueError(f"Only 2 or 3 dimensional shapes are supported, got {ndim}D.")

    return block_shape
Get a suitable block shape for chunking a given shape.
The primary use for this is determining chunk sizes for zarr arrays or block shapes for parallelization.
Arguments:
- shape: The image or volume shape.
Returns:
The block shape.
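For illustration, a few example inputs and the resulting block shapes:

from micro_sam.util import get_block_shape

print(get_block_shape((512, 512)))         # (512, 512) - capped at the image shape
print(get_block_shape((4096, 4096)))       # (1024, 1024)
print(get_block_shape((100, 2048, 2048)))  # (32, 256, 256)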
def micro_sam_info():
    """Display μSAM information using a rich console."""
    import psutil
    import platform
    from rich.panel import Panel
    from rich.table import Table
    from rich.console import Console

    import torch
    import micro_sam

    # Open up a new console.
    console = Console()

    # The header for information CLI.
    console.print("[bold #0072B2]μSAM Information Booth[/bold #0072B2]", justify="center")
    console.print("-" * console.width)

    # μSAM version panel.
    console.print(
        Panel(f"[bold #F0E442]Version:[/bold #F0E442] {micro_sam.__version__}", title="μSAM Version", expand=True)
    )

    # The documentation link panel.
    console.print(
        Panel(
            "[bold #CC79A7]Tools documented at:[/bold #CC79A7]\n"
            "https://computational-cell-analytics.github.io/micro-sam", title="Documentation"
        )
    )

    # The publication panel.
    console.print(
        Panel(
            "[bold #E69F00]Published in Nature Methods:[/bold #E69F00]\n"
            "https://www.nature.com/articles/s41592-024-02580-4", title="Publication"
        )
    )

    # The cache directory panel.
    console.print(
        Panel(f"[bold #009E73]Cache Directory:[/bold #009E73]\n{get_cache_directory()}", title="Cache Directory")
    )

    # The available models panel.
    available_models = list(get_model_names())
    # We filter out the decoder models.
    available_models = [m for m in available_models if not m.endswith("_decoder")]
    model_list = "\n".join(available_models)
    console.print(
        Panel(f"[bold #D55E00]Available Models:[/bold #D55E00]\n{model_list}", title="List of Supported Models")
    )

    # The system information table.
    total_memory = psutil.virtual_memory().total / (1024 ** 3)
    table = Table(title="System Information", show_header=True, header_style="bold #0072B2", expand=True)
    table.add_column("Property")
    table.add_column("Value", style="bold #56B4E9")
    table.add_row("System", platform.system())
    table.add_row("Node Name", platform.node())
    table.add_row("Release", platform.release())
    table.add_row("Version", platform.version())
    table.add_row("Machine", platform.machine())
    table.add_row("Processor", platform.processor())
    table.add_row("Platform", platform.platform())
    table.add_row("Total RAM (GB)", f"{total_memory:.2f}")
    console.print(table)

    # The device information and check for available GPU acceleration.
    default_device = _get_default_device()

    if default_device == "cuda":
        device_index = torch.cuda.current_device()
        device_name = torch.cuda.get_device_name(device_index)
        console.print(Panel(f"[bold #000000]CUDA Device:[/bold #000000] {device_name}", title="GPU Information"))
    elif default_device == "mps":
        console.print(Panel("[bold #000000]MPS Device is available[/bold #000000]", title="GPU Information"))
    else:
        console.print(
            Panel(
                "[bold #000000]No GPU acceleration device detected. Running on CPU.[/bold #000000]",
                title="Device Information"
            )
        )
Display μSAM information using a rich console.
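For illustration, the function is simply called without arguments (it requires the rich and psutil packages to be installed):

from micro_sam.util import micro_sam_info

micro_sam_info()  # prints version, documentation and publication links, cache directory, models and system info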